ProbabilisticTensorDictModule¶
- class tensordict.nn.ProbabilisticTensorDictModule(*args, **kwargs)¶
一個機率 TD 模組。
ProbabilisticTensorDictModule 是一個非參數模組,代表一個機率分佈。它使用指定的 in_keys 從輸入 TensorDict 讀取分佈參數。輸出會根據某些規則進行採樣,這些規則由輸入的
default_interaction_type
參數和interaction_type()
全域函式指定。ProbabilisticTensorDictModule
可用於建構分佈(透過get_dist()
方法)和/或從此分佈取樣(透過對模組進行常規__call__()
)。一個
ProbabilisticTensorDictModule
實例有兩個主要特點: - 它讀取和寫入 TensorDict 物件 - 它使用真實映射 R^n -> R^m 來建立 R^d 中的分佈,可以從中取樣或計算值。當
__call__
/forward
方法被呼叫時,會建立一個分佈,並計算一個值(使用 'mean'、'mode'、'median' 屬性或 'rsample'、'sample' 方法)。如果提供的 TensorDict 已經擁有所有需要的鍵值對,則會跳過取樣步驟。預設情況下,ProbabilisticTensorDictModule 分佈類別是一個 Delta 分佈,這使得 ProbabilisticTensorDictModule 成為一個簡單的確定性映射函式包裝器。
- 參數:
in_keys (NestedKey 或 NestedKey 列表 或 dict) – 將從輸入 TensorDict 讀取並用於建構分佈的鍵。重要的是,如果它是一個 NestedKey 列表或 NestedKey,則這些鍵的葉節點(最後一個元素)必須與感興趣的分佈類別使用的關鍵字匹配,例如,Normal 分佈的
"loc"
和"scale"
等。如果 in_keys 是一個字典,則鍵是分佈的鍵,值是在 tensordict 中與相應分佈鍵匹配的鍵。out_keys (NestedKey 或 NestedKey 列表) – 樣本值將被寫入的鍵。重要的是,如果在輸入 TensorDict 中找到這些鍵,則會跳過取樣步驟。
default_interaction_mode (str, 可選) – 已棄用 僅限關鍵字的參數。請改用 default_interaction_type。
default_interaction_type (InteractionType, 可選) –
僅限關鍵字的參數。用於檢索輸出值的預設方法。應該是 InteractionType 之一:MODE、MEDIAN、MEAN 或 RANDOM(在這種情況下,值是從分佈中隨機取樣)。預設值為 MODE。
注意
當繪製樣本時,
ProbabilisticTensorDictModule
實例將首先尋找由interaction_type()
全域函式指定的交互模式。如果這返回 None(其預設值),則將使用 ProbabilisticTDModule 實例的 default_interaction_type。請注意,DataCollectorBase
實例將預設使用 set_interaction_type 設為tensordict.nn.InteractionType.RANDOM
。注意
在某些情況下,可能無法透過對應的屬性直接取得眾數、中位數或平均數值。為了改善這種情況,
ProbabilisticTensorDictModule
會先嘗試呼叫get_mode()
、get_median()
或get_mean()
來取得數值(如果該方法存在)。distribution_class (Type, optional) –
僅限關鍵字參數。用於抽樣的
torch.distributions.Distribution
類別。預設值為Delta
。注意
如果 distribution 類別屬於
CompositeDistribution
類型,則可以從此類別的distribution_kwargs
關鍵字參數提供的"distribution_map"
或"name_map"
關鍵字參數中直接推斷出out_keys
,在這種情況下,out_keys
是可選的。distribution_kwargs (dict, optional) – 僅限關鍵字參數。要傳遞給 distribution 的關鍵字參數對。
return_log_prob (bool, optional) – 僅限關鍵字參數。如果
True
,則 distribution 樣本的 log-probability 將會使用索引鍵 log_prob_key 寫入 tensordict 中。預設值為False
。log_prob_key (NestedKey, optional) – 如果 return_log_prob = True,則 log_prob 的寫入索引鍵。預設值為 ‘sample_log_prob’。
cache_dist (bool, optional) – 僅限關鍵字參數。實驗性功能:如果
True
,distribution 的參數(即 module 的輸出)將會與樣本一起寫入 tensordict 中。這些參數可用於稍後重新計算原始 distribution(例如,計算用於抽樣動作的 distribution 與 PPO 中更新的 distribution 之間的 divergence)。預設值為False
。n_empirical_estimate (int, optional) – 僅限關鍵字參數。計算經驗平均值時使用的樣本數,當無法直接取得平均值時使用。預設值為 1000。
範例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import ( ... ProbabilisticTensorDictModule, ... ProbabilisticTensorDictSequential, ... TensorDictModule, ... ) >>> from tensordict.nn.distributions import NormalParamExtractor >>> from tensordict.nn.functional_modules import make_functional >>> from torch.distributions import Normal, Independent >>> td = TensorDict( ... {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3] ... ) >>> net = torch.nn.GRUCell(4, 8) >>> module = TensorDictModule( ... net, in_keys=["input", "hidden"], out_keys=["params"] ... ) >>> normal_params = TensorDictModule( ... NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"] ... ) >>> def IndepNormal(**kwargs): ... return Independent(Normal(**kwargs), 1) >>> prob_module = ProbabilisticTensorDictModule( ... in_keys=["loc", "scale"], ... out_keys=["action"], ... distribution_class=IndepNormal, ... return_log_prob=True, ... ) >>> td_module = ProbabilisticTensorDictSequential( ... module, normal_params, prob_module ... ) >>> params = TensorDict.from_module(td_module) >>> with params.to_module(td_module): ... _ = td_module(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) >>> with params.to_module(td_module): ... dist = td_module.get_dist(td) >>> print(dist) Independent(Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])), 1) >>> # we can also apply the module to the TensorDict with vmap >>> from torch import vmap >>> params = params.expand(4) >>> def func(td, params): ... with params.to_module(td_module): ... return td_module(td) >>> td_vmap = vmap(func, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ action: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), hidden: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False), scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 3]), device=None, is_shared=False)
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, _requires_sample: bool = True) TensorDictBase ¶
定義每次呼叫時執行的計算。
應該由所有子類別覆寫。
注意
雖然正向傳遞的配方需要在這個函式中定義,但應該在這個函式之後呼叫
Module
實例,而不是這個函式,因為前者會處理執行註冊的 hooks,而後者會靜默地忽略它們。
- get_dist(tensordict: TensorDictBase) Distribution ¶
使用輸入 tensordict 中提供的參數建立
torch.distribution.Distribution
實例。
- log_prob(tensordict)¶
寫入 distribution 樣本的 log-probability。