捷徑

OneHotCategorical

class torchrl.modules.OneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)[原始碼]

One-hot 分類分佈。

此類別的行為與 torch.distributions.Categorical 完全相同,除了它讀取並產生離散張量的 one-hot 編碼。

參數:
  • logits (torch.Tensor) – 事件的 log 機率 (未正規化)

  • probs (torch.Tensor) – 事件機率

  • grad_method (ReparamGradientStrategy, optional) –

    收集重新參數化樣本的策略。ReparamGradientStrategy.PassThrough 將透過使用 softmax 值 log 機率作為樣本梯度的代理來計算樣本梯度。

    ReparamGradientStrategy.RelaxedOneHot 將使用 torch.distributions.RelaxedOneHot 從分佈中取樣。

範例

>>> torch.manual_seed(0)
>>> logits = torch.randn(4)
>>> dist = OneHotCategorical(logits=logits)
>>> print(dist.rsample((3,)))
tensor([[1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.]])
log_prob(value: Tensor) Tensor[原始碼]

傳回在 value 評估的機率密度/質量函數的對數。

參數:

value (Tensor) –

property mode: Tensor

傳回分配的眾數 (mode)。

rsample(sample_shape: Optional[Union[Size, Sequence]] = None) Tensor[原始碼]

產生一個 sample_shape 形狀的重新參數化樣本,或者如果分配參數是批次的,則產生一個 sample_shape 形狀的重新參數化樣本批次。

sample(sample_shape: Optional[Union[Size, Sequence]] = None) Tensor[原始碼]

產生一個 sample_shape 形狀的樣本,或者如果分配參數是批次的,則產生一個 sample_shape 形狀的樣本批次。

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源