捷徑

terminated_or_truncated

torchrl.envs.utils.terminated_or_truncated(data: TensorDictBase, full_done_spec: Optional[TensorSpec] = None, key: str = '_reset', write_full_false: bool = False) bool[source]

讀取 tensordict 內的 done / terminated / truncated 鍵,並寫入一個新的 tensor,其中匯總了兩個訊號的值。

修改會在提供的 TensorDict 實例中就地發生。此函式可用於計算批次或多代理設定中的 “_reset” 訊號,因此輸出鍵的預設名稱。

參數:
  • data (TensorDictBase) – 輸入資料,通常是呼叫 step() 的結果。

  • full_done_spec (TensorSpec, optional) – 來自 env 的 done_spec,指示必須在哪裡找到 done leaves。如果未提供,則會在資料中搜尋預設的 "done""terminated""truncated" 條目。

  • key (NestedKey, optional) –

    應寫入彙總結果的位置。如果 None,則此函式不會寫入任何鍵,而只會輸出是否有任何完成值為 true。.. note:: 如果 key 條目已經存在值,

    則會沿用先前的值,並且不會進行更新。

  • write_full_false (bool, optional) – 如果 True,即使輸出為 False(即,在提供的資料結構中沒有 done 為 True),重置鍵也會被寫入。預設值為 False

傳回值:一個布林值,指示在資料中找到的任何完成狀態是否

包含 True

範例

>>> from torchrl.data.tensor_specs import Categorical
>>> from tensordict import TensorDict
>>> spec = Composite(
...     done=Categorical(2, dtype=torch.bool),
...     truncated=Categorical(2, dtype=torch.bool),
...     nested=Composite(
...         done=Categorical(2, dtype=torch.bool),
...         truncated=Categorical(2, dtype=torch.bool),
...     )
... )
>>> data = TensorDict({
...     "done": True, "truncated": False,
...     "nested": {"done": False, "truncated": True}},
...     batch_size=[]
... )
>>> data = _terminated_or_truncated(data, spec)
>>> print(data["_reset"])
tensor(True)
>>> print(data["nested", "_reset"])
tensor(True)

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適合初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您的問題解答

檢視資源