SafeProbabilisticModule¶
- class torchrl.modules.tensordict_module.SafeProbabilisticModule(*args, **kwargs)[來源]¶
tensordict.nn.ProbabilisticTensorDictModule
子類別,它接受TensorSpec
作為參數來控制輸出域。SafeProbabilisticModule 是一個非參數模組,表示機率分佈。它使用指定的 in_keys 從輸入 TensorDict 中讀取分佈參數。輸出根據某些規則進行採樣,這些規則由輸入
default_interaction_type
參數和interaction_type()
全域函數指定。SafeProbabilisticModule
可用於建構分佈(透過get_dist()
方法),和/或從這個分佈中採樣(透過對模組的常規__call__()
)。SafeProbabilisticModule
實例有兩個主要特點:- 它讀取和寫入 TensorDict 物件 - 它使用真實映射 R^n -> R^m 來建立 R^d 中的分佈,可以從中採樣或計算值。當呼叫
__call__
/forward
方法時,會建立一個分佈,並計算一個值(使用 'mean'、'mode'、'median' 屬性或 'rsample'、'sample' 方法)。如果提供的 TensorDict 已經具有所有期望的鍵值對,則會跳過採樣步驟。預設情況下,SafeProbabilisticModule 分佈類別是 Delta 分佈,這使得 SafeProbabilisticModule 成為確定性映射函數周圍的簡單包裝器。
- 參數:
in_keys (NestedKey 或 NestedKey 的 list 或 dict) – 將從輸入 TensorDict 讀取並用於建構分佈的鍵。重要的是,如果它是 NestedKey 的 list 或 NestedKey,這些鍵的葉節點(最後一個元素)必須與感興趣的分佈類別使用的關鍵字相符,例如 Normal 分佈的
"loc"
和"scale"
等等。如果 in_keys 是一個字典,則鍵是分佈的鍵,而值是 TensorDict 中將與相應分佈鍵匹配的鍵。out_keys (NestedKey 或 NestedKey 的 list) – 將寫入採樣值的鍵。重要的是,如果在輸入 TensorDict 中找到這些鍵,則將跳過採樣步驟。
spec (TensorSpec) – 第一個輸出張量的規格。在呼叫 td_module.random() 以在目標空間中產生隨機值時使用。
safe (bool, optional) – 如果
True
,則會根據輸入規格檢查樣本的值。由於探索策略或數值下溢/溢位問題,可能會發生超出範圍的採樣。對於spec
參數,此檢查僅會針對分佈樣本發生,而不會針對輸入模組傳回的其他張量發生。如果樣本超出範圍,則會使用 TensorSpec.project 方法將其投影回所需的空間。預設值為False
。default_interaction_type (str, optional) – 用於檢索輸出值的預設方法。應為以下其中之一:'mode'、'median'、'mean' 或 'random' (在這種情況下,該值會從分佈中隨機採樣)。預設值為 'mode'。注意:當提取樣本時,
ProbabilisticTDModule
實例將首先尋找由 interaction_typ() 全域函數決定的互動模式。如果這傳回 None(其預設值),則將使用ProbabilisticTDModule
實例的 default_interaction_type。請注意,DataCollector 實例預設會使用tensordict.nn.set_interaction_type()
設定為tensordict.nn.InteractionType.RANDOM
。distribution_class (Type, optional) – 用於採樣的 torch.distributions.Distribution 類別。預設值為 Delta。
distribution_kwargs (dict, optional) – 要傳遞給分佈的 kwargs。
return_log_prob (bool, optional) – 如果
True
,則分佈樣本的對數機率將寫入具有鍵 ‘sample_log_prob’ 的 tensordict 中。預設值為False
。log_prob_key (NestedKey, optional) – 如果 return_log_prob = True,則寫入 log_prob 的鍵。預設為 ‘sample_log_prob’。
cache_dist (bool, optional) – 實驗性:如果
True
,則分佈的參數(即模組的輸出)將與樣本一起寫入 tensordict。這些參數可用於稍後重新計算原始分佈(例如,計算用於對動作進行採樣的分佈與 PPO 中更新的分佈之間的分歧)。預設值為False
。n_empirical_estimate (int, optional) – 在經驗平均值不可用時,用於計算經驗平均值的樣本數。預設值為 1000