torch.multinomial¶
- torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor ¶
傳回一個 tensor,其中每一列都包含從多項式分佈(更嚴格的定義將是多元的,請參考
torch.distributions.multinomial.Multinomial
以獲取更多詳細資訊)取樣的num_samples
索引,該分佈位於 tensorinput
的相應列中。注意
input
的各行總和不一定要等於 1 (在這種情況下,我們會將這些值用作權重),但必須是非負數、有限且總和不為零。索引會依照取樣順序由左至右排列(首先取樣的會放在第一欄)。
如果
input
是一個向量,則out
是一個大小為num_samples
的向量。如果
input
是一個具有 m 列的矩陣,則out
是一個形狀為 的矩陣。如果 replacement 為
True
,則取樣時會放回。否則,取樣時不會放回,這表示針對某一行取樣到一個索引後,該索引就不能再針對該行取樣。
注意
當不放回取樣時,
num_samples
必須小於input
中非零元素的數量(如果input
是一個矩陣,則為每行中非零元素的最小數量)。- 參數
- 關鍵字引數
generator (
torch.Generator
, optional) – 用於取樣的偽隨機數產生器out (Tensor, optional) – 輸出張量。
範例
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights >>> torch.multinomial(weights, 2) tensor([1, 2]) >>> torch.multinomial(weights, 5) # ERROR! RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement >>> torch.multinomial(weights, 4, replacement=True) tensor([ 2, 1, 1, 1])