快捷方式

torch.multinomial

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor

傳回一個 tensor,其中每一列都包含從多項式分佈(更嚴格的定義將是多元的,請參考 torch.distributions.multinomial.Multinomial 以獲取更多詳細資訊)取樣的 num_samples 索引,該分佈位於 tensor input 的相應列中。

注意

input 的各行總和不一定要等於 1 (在這種情況下,我們會將這些值用作權重),但必須是非負數、有限且總和不為零。

索引會依照取樣順序由左至右排列(首先取樣的會放在第一欄)。

如果 input 是一個向量,則 out 是一個大小為 num_samples 的向量。

如果 input 是一個具有 m 列的矩陣,則 out 是一個形狀為 (m×num_samples)(m \times \text{num\_samples}) 的矩陣。

如果 replacement 為 True,則取樣時會放回。

否則,取樣時不會放回,這表示針對某一行取樣到一個索引後,該索引就不能再針對該行取樣。

注意

當不放回取樣時,num_samples 必須小於 input 中非零元素的數量(如果 input 是一個矩陣,則為每行中非零元素的最小數量)。

參數
  • input (Tensor) – 包含機率的輸入張量

  • num_samples (int) – 要取樣的數量

  • replacement (bool, optional) – 是否要放回取樣

關鍵字引數
  • generator (torch.Generator, optional) – 用於取樣的偽隨機數產生器

  • out (Tensor, optional) – 輸出張量。

範例

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 5) # ERROR!
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])

文件

取得 PyTorch 的全面開發人員文件

檢視文件

教學課程

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源