快捷方式

tensordict.nn.distributions.CompositeDistribution

class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: Optional[dict] = None, extra_kwargs=None, aggregate_probabilities: Optional[bool] = None, log_prob_key: NestedKey = 'sample_log_prob', entropy_key: NestedKey = 'entropy')

分佈的組合。

將分佈與 TensorDict 介面組合在一起。方法 (log_prob_composite, entropy_composite, cdf, icdf, rsample, sample 等) 將傳回一個 tensordict,如果輸入是 tensordict,則可能會就地修改。

參數:
  • params (TensorDictBase) – 一個巢狀的鍵-張量映射,其中根條目指向樣本名稱,而葉子是分佈參數。 條目名稱必須與 distribution_map 的名稱相符。

  • distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]) – 指示要使用的分佈類型。 分佈的名稱將與 tensordict 中樣本的名稱相符。

關鍵字參數:
  • name_map (Dict[NestedKey, NestedKey]]) – 一個字典,表示應在哪裡寫入每個樣本。 如果未提供,將使用 distribution_map 中的鍵名稱。

  • extra_kwargs (Dict[NestedKey, Dict]) – 可能不完整的字典,包含要建構的分佈的額外關鍵字引數。

  • aggregate_probabilities ( bool) – 如果為 True,則 log_prob()entropy() 方法會將各別分佈的機率和熵加總,並回傳單一的 tensor。如果為 False,則單一的 log-probabilities 會註冊到輸入的 tensordict 中(針對 log_prob()),或作為輸出 tensordict 的 leaves 回傳(針對 entropy())。這個參數可以在執行時被覆寫,只要將 aggregate_probabilities 引數傳給 log_probentropy 即可。預設值為 False

  • log_prob_key (NestedKey, optional) – 寫入 log_prob 的鍵。預設值為 ‘sample_log_prob’

  • entropy_key (NestedKey, optional) – 寫入 entropy 的鍵。預設值為 ‘entropy’

注意

在這個分佈類別中,包含參數的輸入 tensordict ( params) 的 batch-size 表示分佈的 batch_shape。 例如,呼叫 log_prob 所產生的 "sample_log_prob" 條目的形狀將會是 params 的形狀 (+ 任何額外的 batch 維度)。

範例

>>> params = TensorDict({
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> sample = dist.log_prob(sample)
>>> print(sample)
TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源