快捷方式

torch.bernoulli

torch.bernoulli(input: Tensor, *, generator: Optional[Generator], out: Optional[Tensor]) Tensor

從白努利分佈中抽取二元隨機數(0 或 1)。

input 張量應該是一個包含機率值的張量,用於繪製二元隨機數。 因此,input 中的所有值都必須在以下範圍內: 0inputi10 \leq \text{input}_i \leq 1

輸出張量的第 ith\text{i}^{th} 個元素將根據 input 中給定的第 ith\text{i}^{th} 個機率值繪製一個值 11

outiBernoulli(p=inputi)\text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i})

返回的 out 張量只有 0 或 1 的值,並且形狀與 input 相同。

out 可以具有整數 dtype,但 input 必須具有浮點數 dtype

參數

input (Tensor) – Bernoulli 分佈的機率值的輸入張量

關鍵字參數
  • generator (torch.Generator, optional) – 用於採樣的偽隨機數生成器(可選)

  • out (Tensor, optional) – 輸出張量(可選)。

範例

>>> a = torch.empty(3, 3).uniform_(0, 1)  # generate a uniform random matrix with range [0, 1]
>>> a
tensor([[ 0.1737,  0.0950,  0.3609],
        [ 0.7148,  0.0289,  0.2676],
        [ 0.9456,  0.8937,  0.7202]])
>>> torch.bernoulli(a)
tensor([[ 1.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 1.,  1.,  1.]])

>>> a = torch.ones(3, 3) # probability of drawing "1" is 1
>>> torch.bernoulli(a)
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> a = torch.zeros(3, 3) # probability of drawing "1" is 0
>>> torch.bernoulli(a)
tensor([[ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])

文件

獲取 PyTorch 的全面開發者文檔

查看文檔

教程

獲取針對初學者和高級開發者的深入教程

查看教程

資源

尋找開發資源並獲得問題解答

查看資源