捷徑

MultiStepActorWrapper

class torchrl.modules.tensordict_module.MultiStepActorWrapper(*args, **kwargs)[source]

多動作執行器的包裝器。

此類別可讓巨集在環境中執行。 執行器的 action(s) 項目必須具有額外的時間維度才能使用。 它必須放置在輸入 tensordict 的最後一個維度旁邊(即在 tensordict.ndim)。

如果未使用簡單的啟發式演算法提供動作項目鍵,則會自動從執行器中擷取(任何以 "action" 字串結尾的巢狀鍵)。

輸入 tensordict 中也必須存在 "is_init" 項目,以追蹤由於遇到「完成」狀態而應中斷當前收集的時間和原因。 與 action_keys 不同,此鍵必須是唯一的。

參數:
  • actor (TensorDictModuleBase) – 執行器。

  • n_steps (int) – 執行器一次輸出的動作數量(lookahead window)。

關鍵字引數:
  • action_keys (NestedKeys 的清單, optional) – 環境中的動作鍵。 可以從 env.action_keys 擷取。 預設為 actor 所有以 "action" 字串結尾的 out_keys

  • init_key (NestedKey, optional) – 指示環境何時重設的項目的鍵。 預設為 "is_init",這是來自 InitTracker 轉換的 out_key

範例

>>> import torch.nn
>>> from torchrl.modules.tensordict_module.actors import MultiStepActorWrapper, Actor
>>> from torchrl.envs import CatFrames, GymEnv, TransformedEnv, SerialEnv, InitTracker, Compose
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> time_steps = 6
>>> n_obs = 4
>>> n_action = 2
>>> batch = 5
>>>
>>> # Transforms a CatFrames in a stack of frames
>>> def reshape_cat(data: torch.Tensor):
...     return data.unflatten(-1, (time_steps, n_obs))
>>> # an actor that reads `time_steps` frames and outputs one action per frame
>>> # (actions are conditioned on the observation of `time_steps` in the past)
>>> actor_base = Seq(
...     Mod(reshape_cat, in_keys=["obs_cat"], out_keys=["obs_cat_reshape"]),
...     Mod(torch.nn.Linear(n_obs, n_action), in_keys=["obs_cat_reshape"], out_keys=["action"])
... )
>>> # Wrap the actor to dispatch the actions
>>> actor = MultiStepActorWrapper(actor_base, n_steps=time_steps)
>>>
>>> env = TransformedEnv(
...     SerialEnv(batch, lambda: GymEnv("CartPole-v1")),
...     Compose(
...         InitTracker(),
...         CatFrames(N=time_steps, in_keys=["observation"], out_keys=["obs_cat"], dim=-1)
...     )
... )
>>>
>>> print(env.rollout(100, policy=actor, break_when_any_done=False))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 100, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        action_orig: Tensor(shape=torch.Size([5, 100, 6, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        counter: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.int32, is_shared=False),
        done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5, 100]),
            device=cpu,
            is_shared=False),
        obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5, 100]),
    device=cpu,
    is_shared=False)
forward(tensordict: TensorDictBase) TensorDictBase[source]

定義每次呼叫時執行的計算。

應由所有子類別覆寫。

注意

雖然 forward pass 的配方需要在這個函式中定義,但應呼叫 Module 實例,而不是此函式,因為前者會處理執行已註冊的 hooks,而後者會靜默忽略它們。

property init_key: NestedKey

批次中給定元素的初始步驟的指標。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源