MaskedOneHotCategorical¶
- class torchrl.modules.MaskedOneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, mask: Optional[Tensor] = None, indices: Optional[Tensor] = None, neg_inf: float = - inf, padding_value: Optional[int] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough)[source]¶
MaskedCategorical 分佈。
- 參數:
logits (torch.Tensor) – 事件對數機率(未正規化)
probs (torch.Tensor) – 事件機率。如果提供,對應於被遮罩 (masked) 項目的機率將會歸零,並且機率會沿著其最後一個維度重新正規化。
- 關鍵字引數:
mask (torch.Tensor) – 一個布林遮罩,其形狀與
logits
/probs
相同,其中False
條目是要被遮罩的項目。或者,如果sparse_mask
為 True,則它表示分佈中的有效索引列表。與indices
互斥。indices (torch.Tensor) – 一個密集的索引張量,表示必須考慮哪些動作。與
mask
互斥。neg_inf (float, optional) – 分配給無效(超出遮罩範圍)索引的對數機率值。預設值為 -inf。
padding_value – 當 sparse_mask == True 時,遮罩張量中的填充值,padding_value 將被忽略。
grad_method (ReparamGradientStrategy, optional) –
用於收集重新參數化樣本的策略。
ReparamGradientStrategy.PassThrough
將通過使用 softmax 值的對數機率作為樣本梯度的代理來計算樣本梯度。通過使用 softmax 值的對數機率作為樣本梯度的代理來計算樣本梯度。
ReparamGradientStrategy.RelaxedOneHot
將使用torch.distributions.RelaxedOneHot
從分佈中取樣。torch.manual_seed (>>>) –
torch.randn (>>> logits =) –
torch.tensor (>>> mask =) –
MaskedOneHotCategorical (>>> dist =) –
dist.sample (>>> sample =) –
print (>>>) –
0], (tensor([[0, 0, 1,) – [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]])
print –
-1.0831, (tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203,) – -1.1203, -1.1203])
torch.zeros_like (>>> sample_non_valid =) –
1 (>>> sample_non_valid[..., 1] =) –
print –
tensor ([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]) –
probabilities (>>> # with) –
torch.ones (>>> prob =) –
prob.sum() (>>> prob = prob /) –
torch.tensor –
MaskedOneHotCategorical –
torch.arange (>>> s =) –
torch.nn.functional.one_hot (>>> s =) –
print –
-2.1972, (tensor([ -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,) – -2.1972, -2.1972])