快速鍵

SafeProbabilisticTensorDictSequential

class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[來源]

tensordict.nn.ProbabilisticTensorDictSequential 子類別,它接受 TensorSpec 作為控制輸出域的引數。

TensorDictSequential 類似,但強制序列中的最終模組是 ProbabilisticTensorDictModule,並且還公開 get_dist 方法,以從 ProbabilisticTensorDictModule 中恢復分佈物件

參數:
  • modules (TensorDictModules 的可迭代物件) – TensorDictModule 執行個體的排序序列,以 ProbabilisticTensorDictModule 終止,以依序執行。

  • partial_tolerant (bool, 可選) – 如果 True,則輸入 tensordict 可能會遺漏一些輸入鍵。 如果是這樣,則只會執行那些可以在給定存在的鍵的情況下執行的模組。 此外,如果輸入 tensordict 是 tensordict 的延遲堆疊,並且如果 partial_tolerant 為 True,並且如果堆疊沒有所需的鍵,則 TensorDictSequential 將掃描子 tensordict 以尋找具有所需鍵的 tensordict(如果有的話)。

文件

存取 PyTorch 的全面開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源