快捷方式

step_mdp

torchrl.envs.utils.step_mdp(tensordict: TensorDictBase, next_tensordict: Optional[TensorDictBase] = None, keep_other: bool = True, exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, reward_keys: Union[NestedKey, List[NestedKey]] = 'reward', done_keys: Union[NestedKey, List[NestedKey]] = 'done', action_keys: Union[NestedKey, List[NestedKey]] = 'action') TensorDictBase[source]

建立一個新的 tensordict,反映輸入 tensordict 的時間步進。

給定一個步驟後檢索到的 tensordict,傳回 "next" 索引的 tensordict。 這些參數允許精確控制應保留哪些內容,以及應從 "next" 條目複製哪些內容。 預設行為是:將觀測條目、獎勵和完成狀態移動到根目錄,排除當前動作並保留所有額外的鍵(非動作、非完成、非獎勵)。

參數:
  • tensordict (TensorDictBase) – 要重新命名的 tensordict

  • next_tensordict (TensorDictBase, optional) – 目標 tensordict

  • keep_other (bool, optional) – 如果 True,將保留所有不以 'next_' 開頭的鍵。 預設值為 True

  • exclude_reward (bool, optional) – 如果 True,則將從產生的 tensordict 中丟棄 "reward" 鍵。 如果 False,它將從 "next" 條目複製(並替換)(如果存在)。 預設值為 True

  • exclude_done (bool, optional) – 如果 True,則將從產生的 tensordict 中丟棄 "done" 鍵。 如果 False,它將從 "next" 條目複製(並替換)(如果存在)。 預設值為 False

  • exclude_action (bool, optional) – 如果 True,則將從產生的 tensordict 中丟棄 "action" 鍵。 如果 False,它將保留在根 tensordict 中(因為它不應存在於 "next" 條目中)。 預設值為 True

  • reward_keys (NestedKeyNestedKey 的列表, optional) – 獎勵寫入的鍵。 預設為 “reward”。

  • done_keys (NestedKeyNestedKey 的列表, optional) – 完成狀態寫入的鍵。 預設為 “done”。

  • action_keys (NestedKeyNestedKey 的列表, optional) – 動作寫入的鍵。 預設為 “action”。

傳回:

一個新的 tensordict (或 next_tensordict) 包含 t+1 步驟的張量。

範例:此函數允許使用這種迴圈

>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({
...     "done": torch.zeros((), dtype=torch.bool),
...     "reward": torch.zeros(()),
...     "extra": torch.zeros(()),
...     "next": TensorDict({
...         "done": torch.zeros((), dtype=torch.bool),
...         "reward": torch.zeros(()),
...         "obs": torch.zeros(()),
...     }, []),
...     "obs": torch.zeros(()),
...     "action": torch.zeros(()),
... }, [])
>>> print(step_mdp(td))
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_done=True))  # "done" is dropped
TensorDict(
    fields={
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_reward=False))  # "reward" is kept
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_action=False))  # "action" persists at the root
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, keep_other=False))  # "extra" is missing
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

警告

如果當排除獎勵鍵時,獎勵鍵也是輸入鍵的一部分,此函數將無法正常運作。 這就是為什麼 RewardSum 轉換預設會在觀測中註冊 episode 獎勵,而不是獎勵規格。 當使用此函數的快速、快取版本 (_StepMDP) 時,不應觀察到此問題。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得問題解答

檢視資源