MultiStepActorWrapper
- class torchrl.modules.tensordict_module.MultiStepActorWrapper(*args, **kwargs)
A wrapper around a multi-action actor.
This class allows macros to be executed in an environment. The actor's action(s) entry must carry an extra time dimension to be consumed. That dimension must be placed next to the last dimension of the input tensordict (i.e., at tensordict.ndim). If the action entry keys are not provided, they are retrieved automatically from the actor using a simple heuristic (any nested key ending with the "action" string). An "is_init" entry must also be present in the input tensordict to track when and why the current collection of actions should be interrupted because a "done" state was encountered. Unlike the action_keys, this key must be unique.
- Parameters:
actor (TensorDictModuleBase) – the actor.
n_steps (int) – the number of actions the actor outputs at a time (lookahead window).
- Keyword Arguments:
action_keys (list of NestedKeys, optional) – the action keys of the environment. Can be retrieved from env.action_keys. Defaults to all of the actor's out_keys that end with the "action" string.
init_key (NestedKey, optional) – the key of the entry indicating when the environment has gone through a reset. Defaults to "is_init", which is the out_key of the InitTracker transform. Both keys can also be passed explicitly, as sketched after the example below.
Examples
>>> import torch.nn
>>> from torchrl.modules.tensordict_module.actors import MultiStepActorWrapper, Actor
>>> from torchrl.envs import CatFrames, GymEnv, TransformedEnv, SerialEnv, InitTracker, Compose
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> time_steps = 6
>>> n_obs = 4
>>> n_action = 2
>>> batch = 5
>>>
>>> # Transforms a CatFrames in a stack of frames
>>> def reshape_cat(data: torch.Tensor):
...     return data.unflatten(-1, (time_steps, n_obs))
>>> # an actor that reads `time_steps` frames and outputs one action per frame
>>> # (actions are conditioned on the observation of `time_steps` in the past)
>>> actor_base = Seq(
...     Mod(reshape_cat, in_keys=["obs_cat"], out_keys=["obs_cat_reshape"]),
...     Mod(torch.nn.Linear(n_obs, n_action), in_keys=["obs_cat_reshape"], out_keys=["action"])
... )
>>> # Wrap the actor to dispatch the actions
>>> actor = MultiStepActorWrapper(actor_base, n_steps=time_steps)
>>>
>>> env = TransformedEnv(
...     SerialEnv(batch, lambda: GymEnv("CartPole-v1")),
...     Compose(
...         InitTracker(),
...         CatFrames(N=time_steps, in_keys=["observation"], out_keys=["obs_cat"], dim=-1)
...     )
... )
>>>
>>> print(env.rollout(100, policy=actor, break_when_any_done=False))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 100, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        action_orig: Tensor(shape=torch.Size([5, 100, 6, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        counter: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.int32, is_shared=False),
        done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5, 100]),
            device=cpu,
            is_shared=False),
        obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5, 100]),
    device=cpu,
    is_shared=False)
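The example above relies on the key heuristic described earlier (any out_key ending with "action", and the default "is_init" entry). As a point of comparison, the minimal sketch below reuses actor_base and time_steps from the example and passes action_keys and init_key explicitly; the values shown are assumptions that simply mirror the defaults, not a required configuration.
>>> # Equivalent construction with the keys spelled out explicitly
>>> actor = MultiStepActorWrapper(
...     actor_base,
...     n_steps=time_steps,
...     action_keys=["action"],   # could also be taken from env.action_keys
...     init_key="is_init",       # the out_key written by the InitTracker transform
... )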
- forward(tensordict: TensorDictBase) → TensorDictBase
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.
- property init_key: NestedKey
The indicator of the initial step for a given element of the batch.
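As a complement, the short sketch below (an illustration assuming a Gym CartPole environment is available; it is not part of the original example) shows where the default "is_init" entry read through init_key comes from: the InitTracker transform writes it into the tensordict, and it is True right after a reset.
>>> from torchrl.envs import GymEnv, InitTracker, TransformedEnv
>>> env = TransformedEnv(GymEnv("CartPole-v1"), InitTracker())
>>> td = env.reset()
>>> assert td["is_init"].all()  # "is_init" is True immediately after a reset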