terminated_or_truncated¶
- torchrl.envs.utils.terminated_or_truncated(data: TensorDictBase, full_done_spec: Optional[TensorSpec] = None, key: str = '_reset', write_full_false: bool = False) bool [source]¶
讀取 tensordict 內的 done / terminated / truncated 鍵,並寫入一個新的 tensor,其中匯總了兩個訊號的值。
修改會在提供的 TensorDict 實例中就地發生。此函式可用於計算批次或多代理設定中的 “_reset” 訊號,因此輸出鍵的預設名稱。
- 參數:
data (TensorDictBase) – 輸入資料,通常是呼叫
step()
的結果。full_done_spec (TensorSpec, optional) – 來自 env 的 done_spec,指示必須在哪裡找到 done leaves。如果未提供,則會在資料中搜尋預設的
"done"
、"terminated"
和"truncated"
條目。key (NestedKey, optional) –
應寫入彙總結果的位置。如果
None
,則此函式不會寫入任何鍵,而只會輸出是否有任何完成值為 true。.. note:: 如果key
條目已經存在值,則會沿用先前的值,並且不會進行更新。
write_full_false (bool, optional) – 如果
True
,即使輸出為False
(即,在提供的資料結構中沒有 done 為True
),重置鍵也會被寫入。預設值為False
。
- 傳回值:一個布林值,指示在資料中找到的任何完成狀態是否
包含
True
。
範例
>>> from torchrl.data.tensor_specs import Categorical >>> from tensordict import TensorDict >>> spec = Composite( ... done=Categorical(2, dtype=torch.bool), ... truncated=Categorical(2, dtype=torch.bool), ... nested=Composite( ... done=Categorical(2, dtype=torch.bool), ... truncated=Categorical(2, dtype=torch.bool), ... ) ... ) >>> data = TensorDict({ ... "done": True, "truncated": False, ... "nested": {"done": False, "truncated": True}}, ... batch_size=[] ... ) >>> data = _terminated_or_truncated(data, spec) >>> print(data["_reset"]) tensor(True) >>> print(data["nested", "_reset"]) tensor(True)