OneHotCategorical¶
- class torchrl.modules.OneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)[原始碼]¶
One-hot 分類分佈。
此類別的行為與 torch.distributions.Categorical 完全相同,除了它讀取並產生離散張量的 one-hot 編碼。
- 參數:
logits (torch.Tensor) – 事件的 log 機率 (未正規化)
probs (torch.Tensor) – 事件機率
grad_method (ReparamGradientStrategy, optional) –
收集重新參數化樣本的策略。
ReparamGradientStrategy.PassThrough
將透過使用 softmax 值 log 機率作為樣本梯度的代理來計算樣本梯度。ReparamGradientStrategy.RelaxedOneHot
將使用torch.distributions.RelaxedOneHot
從分佈中取樣。
範例
>>> torch.manual_seed(0) >>> logits = torch.randn(4) >>> dist = OneHotCategorical(logits=logits) >>> print(dist.rsample((3,))) tensor([[1., 0., 0., 0.], [0., 0., 0., 1.], [1., 0., 0., 0.]])