torch.bernoulli¶
- torch.bernoulli(input: Tensor, *, generator: Optional[Generator], out: Optional[Tensor]) Tensor ¶
從白努利分佈中抽取二元隨機數(0 或 1)。
input
張量應該是一個包含機率值的張量,用於繪製二元隨機數。 因此,input
中的所有值都必須在以下範圍內: 。輸出張量的第 個元素將根據
input
中給定的第 個機率值繪製一個值 。返回的
out
張量只有 0 或 1 的值,並且形狀與input
相同。out
可以具有整數dtype
,但input
必須具有浮點數dtype
。- 參數
input (Tensor) – Bernoulli 分佈的機率值的輸入張量
- 關鍵字參數
generator (
torch.Generator
, optional) – 用於採樣的偽隨機數生成器(可選)out (Tensor, optional) – 輸出張量(可選)。
範例
>>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] >>> a tensor([[ 0.1737, 0.0950, 0.3609], [ 0.7148, 0.0289, 0.2676], [ 0.9456, 0.8937, 0.7202]]) >>> torch.bernoulli(a) tensor([[ 1., 0., 0.], [ 0., 0., 0.], [ 1., 1., 1.]]) >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 >>> torch.bernoulli(a) tensor([[ 1., 1., 1.], [ 1., 1., 1.], [ 1., 1., 1.]]) >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 >>> torch.bernoulli(a) tensor([[ 0., 0., 0.], [ 0., 0., 0.], [ 0., 0., 0.]])