SafeProbabilisticTensorDictSequential¶
- class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[來源]¶
tensordict.nn.ProbabilisticTensorDictSequential
子類別,它接受TensorSpec
作為控制輸出域的引數。與
TensorDictSequential
類似,但強制序列中的最終模組是ProbabilisticTensorDictModule
,並且還公開get_dist
方法,以從ProbabilisticTensorDictModule
中恢復分佈物件- 參數:
modules (TensorDictModules 的可迭代物件) – TensorDictModule 執行個體的排序序列,以 ProbabilisticTensorDictModule 終止,以依序執行。
partial_tolerant (bool, 可選) – 如果
True
,則輸入 tensordict 可能會遺漏一些輸入鍵。 如果是這樣,則只會執行那些可以在給定存在的鍵的情況下執行的模組。 此外,如果輸入 tensordict 是 tensordict 的延遲堆疊,並且如果 partial_tolerant 為True
,並且如果堆疊沒有所需的鍵,則 TensorDictSequential 將掃描子 tensordict 以尋找具有所需鍵的 tensordict(如果有的話)。