捷徑

SafeProbabilisticModule

class torchrl.modules.tensordict_module.SafeProbabilisticModule(*args, **kwargs)[來源]

tensordict.nn.ProbabilisticTensorDictModule 子類別,它接受 TensorSpec 作為參數來控制輸出域。

SafeProbabilisticModule 是一個非參數模組,表示機率分佈。它使用指定的 in_keys 從輸入 TensorDict 中讀取分佈參數。輸出根據某些規則進行採樣,這些規則由輸入 default_interaction_type 參數和 interaction_type() 全域函數指定。

SafeProbabilisticModule 可用於建構分佈(透過 get_dist() 方法),和/或從這個分佈中採樣(透過對模組的常規 __call__())。

SafeProbabilisticModule 實例有兩個主要特點:- 它讀取和寫入 TensorDict 物件 - 它使用真實映射 R^n -> R^m 來建立 R^d 中的分佈,可以從中採樣或計算值。

當呼叫 __call__ / forward 方法時,會建立一個分佈,並計算一個值(使用 'mean'、'mode'、'median' 屬性或 'rsample'、'sample' 方法)。如果提供的 TensorDict 已經具有所有期望的鍵值對,則會跳過採樣步驟。

預設情況下,SafeProbabilisticModule 分佈類別是 Delta 分佈,這使得 SafeProbabilisticModule 成為確定性映射函數周圍的簡單包裝器。

參數:
  • in_keys (NestedKeyNestedKey 的 listdict) – 將從輸入 TensorDict 讀取並用於建構分佈的鍵。重要的是,如果它是 NestedKey 的 list 或 NestedKey,這些鍵的葉節點(最後一個元素)必須與感興趣的分佈類別使用的關鍵字相符,例如 Normal 分佈的 "loc""scale" 等等。如果 in_keys 是一個字典,則鍵是分佈的鍵,而值是 TensorDict 中將與相應分佈鍵匹配的鍵。

  • out_keys (NestedKeyNestedKey 的 list) – 將寫入採樣值的鍵。重要的是,如果在輸入 TensorDict 中找到這些鍵,則將跳過採樣步驟。

  • spec (TensorSpec) – 第一個輸出張量的規格。在呼叫 td_module.random() 以在目標空間中產生隨機值時使用。

  • safe (bool, optional) – 如果 True,則會根據輸入規格檢查樣本的值。由於探索策略或數值下溢/溢位問題,可能會發生超出範圍的採樣。對於 spec 參數,此檢查僅會針對分佈樣本發生,而不會針對輸入模組傳回的其他張量發生。如果樣本超出範圍,則會使用 TensorSpec.project 方法將其投影回所需的空間。預設值為 False

  • default_interaction_type (str, optional) – 用於檢索輸出值的預設方法。應為以下其中之一:'mode'、'median'、'mean' 或 'random' (在這種情況下,該值會從分佈中隨機採樣)。預設值為 'mode'。注意:當提取樣本時,ProbabilisticTDModule 實例將首先尋找由 interaction_typ() 全域函數決定的互動模式。如果這傳回 None(其預設值),則將使用 ProbabilisticTDModule 實例的 default_interaction_type。請注意,DataCollector 實例預設會使用 tensordict.nn.set_interaction_type() 設定為 tensordict.nn.InteractionType.RANDOM

  • distribution_class (Type, optional) – 用於採樣的 torch.distributions.Distribution 類別。預設值為 Delta。

  • distribution_kwargs (dict, optional) – 要傳遞給分佈的 kwargs。

  • return_log_prob (bool, optional) – 如果 True,則分佈樣本的對數機率將寫入具有鍵 ‘sample_log_prob’ 的 tensordict 中。預設值為 False

  • log_prob_key (NestedKey, optional) – 如果 return_log_prob = True,則寫入 log_prob 的鍵。預設為 ‘sample_log_prob’

  • cache_dist (bool, optional) – 實驗性:如果 True,則分佈的參數(即模組的輸出)將與樣本一起寫入 tensordict。這些參數可用於稍後重新計算原始分佈(例如,計算用於對動作進行採樣的分佈與 PPO 中更新的分佈之間的分歧)。預設值為 False

  • n_empirical_estimate (int, optional) – 在經驗平均值不可用時,用於計算經驗平均值的樣本數。預設值為 1000

random(tensordict: TensorDictBase) TensorDictBase[source]

在目標空間中採樣一個隨機元素,而不考慮任何輸入。

如果存在多個輸出鍵,則只會將第一個鍵寫入輸入 tensordict

參數:

tensordict (TensorDictBase) – 應該寫入輸出值的 tensordict。

傳回值:

具有新的/更新的輸出鍵值的原始 tensordict。

random_sample(tensordict: TensorDictBase) TensorDictBase[source]

請參閱 SafeModule.random(...)

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適合初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得您問題的解答

檢視資源