tensordict.nn.distributions.CompositeDistribution¶
- class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: Optional[dict] = None, extra_kwargs=None, aggregate_probabilities: Optional[bool] = None, log_prob_key: NestedKey = 'sample_log_prob', entropy_key: NestedKey = 'entropy')¶
分佈的組合。
將分佈與 TensorDict 介面組合在一起。方法 (
log_prob_composite
,entropy_composite
,cdf
,icdf
,rsample
,sample
等) 將傳回一個 tensordict,如果輸入是 tensordict,則可能會就地修改。- 參數:
params (TensorDictBase) – 一個巢狀的鍵-張量映射,其中根條目指向樣本名稱,而葉子是分佈參數。 條目名稱必須與
distribution_map
的名稱相符。distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]) – 指示要使用的分佈類型。 分佈的名稱將與 tensordict 中樣本的名稱相符。
- 關鍵字參數:
name_map (Dict[NestedKey, NestedKey]]) – 一個字典,表示應在哪裡寫入每個樣本。 如果未提供,將使用
distribution_map
中的鍵名稱。extra_kwargs (Dict[NestedKey, Dict]) – 可能不完整的字典,包含要建構的分佈的額外關鍵字引數。
aggregate_probabilities ( bool) – 如果為
True
,則log_prob()
和entropy()
方法會將各別分佈的機率和熵加總,並回傳單一的 tensor。如果為False
,則單一的 log-probabilities 會註冊到輸入的 tensordict 中(針對log_prob()
),或作為輸出 tensordict 的 leaves 回傳(針對entropy()
)。這個參數可以在執行時被覆寫,只要將aggregate_probabilities
引數傳給log_prob
和entropy
即可。預設值為False
。log_prob_key (NestedKey, optional) – 寫入 log_prob 的鍵。預設值為 ‘sample_log_prob’。
entropy_key (NestedKey, optional) – 寫入 entropy 的鍵。預設值為 ‘entropy’。
注意
在這個分佈類別中,包含參數的輸入 tensordict (
params
) 的 batch-size 表示分佈的 batch_shape。 例如,呼叫log_prob
所產生的"sample_log_prob"
條目的形狀將會是 params 的形狀 (+ 任何額外的 batch 維度)。範例
>>> params = TensorDict({ ... "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)}, ... ("nested", "disc"): {"logits": torch.randn(3, 10)} ... }, [3]) >>> dist = CompositeDistribution(params, ... distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical}) >>> sample = dist.sample((4,)) >>> sample = dist.log_prob(sample) >>> print(sample) TensorDict( fields={ cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), nested: TensorDict( fields={ disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False), disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False)