有关于torch multinomial采样
创始人
2024-03-28 05:32:19
0

解释

多项式采样,核心思想就是从一个概率分布中,采样n个样本的index,概率高的优先会被采样到。

以官方例子作为解释:

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])

如上述,对于一个概率分布[0, 10, 3, 0],从中采样两个样本。由于10最大,3其次,所以采样出来的结果是1,2。这个很好理解。

>>> torch.multinomial(weights, 4) # ERROR!
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,
not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])

但是由于剩下两个样本的概率为0,因此不可能被采样到。即,对于[0, 10, 3, 0]这个概率分布,只有10和3两个样本能够被采样到。所以当采样数量设置为4的时候,就会报错。

这个时候需要设置参数replacement=True,也就是放回采样,可以重复采样到同一个样本。因此输出结果为[ 2, 1, 1, 1],可以看到第一个采样到的是3,随后采样到的一直都是10,因为10最大。


一些补充

  1. multinomial是一种非常常见的采样策略,在NLP生成任务中进行decoding的时候会经常用到。相似的采样方法还有greedy search、beam search。具体可以参见huggingface:text generation
  2. multinomial的输入不仅可以是一个一维的tensor,笔者尝试二维的输入也照样可以。采样的维度为1。比方说:
>>> softmax(i) ## [2,20]
tensor([[0.0478, 0.0138, 0.0332, 0.0305, 0.0653, 0.0146, 0.1430, 0.1126, 0.0552,0.0183, 0.0737, 0.0337, 0.0433, 0.0323, 0.0106, 0.0604, 0.0992, 0.0377,0.0722, 0.0027],[0.0183, 0.0085, 0.0519, 0.0160, 0.0411, 0.0356, 0.0217, 0.1620, 0.0172,0.0469, 0.0071, 0.0161, 0.3427, 0.0042, 0.0125, 0.0379, 0.0437, 0.0064,0.0298, 0.0804]])  
>>> torch.multinomial(softmax(i),1)  ## [2,1] 
tensor([[ 6],[19]])
>>> torch.multinomial(softmax(i),2)  ## [2,2]
tensor([[ 4,  7],[15,  0]])

参考:

torch multinomial官方文档

相关内容

热门资讯

股价走高触发强赎 7月将有两只... 近日,银行正股股价表现强势,多只银行转债触发强制赎回条款。7月1日是杭银转债最后一个交易日,其最后转...
300548,“改名”,股价历... 科技股和顺周期板块再现“跷跷板”走势。今天上午,顺周期板块走强,银行、有色金属、白酒、新能源等板块上...
明阳电路在昆山投资成立集成电路... 人民财讯7月2日电,企查查APP显示,近日,昆山华芯微测集成电路有限公司成立,法定代表人为窦旭才,经...
民生银行“易创E贷”获“202... 在近日举行的第七届数字普惠金融大会上,民生银行“易创E贷”产品凭借在服务科技型中小微企业的创新与实践...
机器人大军逼近,很快,亚马逊的...   炒股就看金麒麟分析师研报,权威,专业,及时,全面,助您挖掘潜力主题机会! 亚马逊正迅速接近仓库...
广东省教育考试院通报:不存在中... 7月2日,广东省教育考试院发布通报称,7月1日下午广东省初中学业水平考试(以下简称“中考”)数学科目...
国际油价承压 中东油企拟放缓全... 财联社7月2日讯(编辑 秦嘉禾)在国际油价下行压力加剧背景下,中东两大国有能源巨头——沙特阿美(Sa...
光伏50ETF、光伏龙头ETF... 光伏设备板块走强,亿晶光电、欧晶科技涨停。银华光伏50ETF、汇添富光伏龙头ETF、浦银安盛光伏龙头...
世界银行发布重磅预测:黄金今年... 错过上半年,别再错过下半年!世界银行继续看好贵金属前景,黄金、白银、铂金或将继续延续强势……世界银行...
洪灝称高分红投资策略应继续有所... 【#洪灝称高分红投资策略应继续有所表现#】中国股市近期一个显著趋势是分红和回购活动显著增加。其中,许...