捷徑

MaskedOneHotCategorical

class torchrl.modules.MaskedOneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, mask: Optional[Tensor] = None, indices: Optional[Tensor] = None, neg_inf: float = - inf, padding_value: Optional[int] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough)[source]

MaskedCategorical 分佈。

參考:https://tensorflow.dev.org.tw/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

參數:
  • logits (torch.Tensor) – 事件對數機率(未正規化)

  • probs (torch.Tensor) – 事件機率。如果提供,對應於被遮罩 (masked) 項目的機率將會歸零,並且機率會沿著其最後一個維度重新正規化。

關鍵字引數:
  • mask (torch.Tensor) – 一個布林遮罩,其形狀與 logits/probs 相同,其中 False 條目是要被遮罩的項目。或者,如果 sparse_mask 為 True,則它表示分佈中的有效索引列表。與 indices 互斥。

  • indices (torch.Tensor) – 一個密集的索引張量,表示必須考慮哪些動作。與 mask 互斥。

  • neg_inf (float, optional) – 分配給無效(超出遮罩範圍)索引的對數機率值。預設值為 -inf。

  • padding_value – 當 sparse_mask == True 時,遮罩張量中的填充值,padding_value 將被忽略。

  • grad_method (ReparamGradientStrategy, optional) –

    用於收集重新參數化樣本的策略。ReparamGradientStrategy.PassThrough 將通過使用 softmax 值的對數機率作為樣本梯度的代理來計算樣本梯度。

    通過使用 softmax 值的對數機率作為樣本梯度的代理來計算樣本梯度。

    ReparamGradientStrategy.RelaxedOneHot 將使用 torch.distributions.RelaxedOneHot 從分佈中取樣。

  • torch.manual_seed (>>>) –

  • torch.randn (>>> logits =) –

  • torch.tensor (>>> mask =) –

  • MaskedOneHotCategorical (>>> dist =) –

  • dist.sample (>>> sample =) –

  • print (>>>) –

  • 0], (tensor([[0, 0, 1,) – [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]])

  • print

  • -1.0831, (tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203,) – -1.1203, -1.1203])

  • torch.zeros_like (>>> sample_non_valid =) –

  • 1 (>>> sample_non_valid[..., 1] =) –

  • print

  • tensor ([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]) –

  • probabilities (>>> # with) –

  • torch.ones (>>> prob =) –

  • prob.sum() (>>> prob = prob /) –

  • torch.tensor

  • MaskedOneHotCategorical

  • torch.arange (>>> s =) –

  • torch.nn.functional.one_hot (>>> s =) –

  • print

  • -2.1972, (tensor([ -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,) – -2.1972, -2.1972])

log_prob(value: Tensor) Tensor[source]

傳回在 value 處評估的機率密度/質量函數的對數。

參數:

value (Tensor) –

property mode: Tensor

傳回分佈的眾數 (mode)。

rsample(sample_shape: Optional[Union[Size, Sequence]] = None) Tensor[source]

如果分佈參數是批次的,則產生一個 sample_shape 形狀的重新參數化樣本或 sample_shape 形狀的一批重新參數化樣本。

sample(sample_shape: Optional[Union[Size, Sequence[int]]] = None) Tensor[source]

如果分佈參數是批次的,則產生一個 sample_shape 形狀的樣本或 sample_shape 形狀的一批樣本。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者與進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源