LSTMModule¶
- class torchrl.modules.LSTMModule(*args, **kwargs)[來源]¶
LSTM 模組的嵌入器。
此類別為
torch.nn.LSTM
新增以下功能與 TensorDict 的相容性:隱藏狀態會被重新塑形以符合 tensordict 批次大小。
可選的多步執行:使用 torch.nn,必須在
torch.nn.LSTMCell
和torch.nn.LSTM
之間選擇,前者與單步輸入相容,後者與多步相容。此類別同時支援這兩種用法。
建構後,模組未設定為遞迴模式,即它將預期單步輸入。
如果在遞迴模式下,則預期 tensordict 的最後一個維度標記步數。對 tensordict 的維度沒有限制(除非對於時間輸入,它必須大於 1)。
注意
此類別可以處理時間維度上的多個連續軌跡,但在這些情況下,不應信任最終的隱藏值(即,不應將它們重新用於連續軌跡)。原因是 LSTM 僅傳回最後一個隱藏值,對於我們提供的填充輸入,該值可能對應於 0 填充的輸入。
- 參數:
input_size – 輸入 x 中預期的特徵數量
hidden_size – 隱藏狀態 h 中的特徵數量
num_layers – 遞迴層的數量。例如,設定
num_layers=2
表示將兩個 LSTM 堆疊在一起以形成一個 堆疊 LSTM,第二個 LSTM 接收第一個 LSTM 的輸出並計算最終結果。預設值:1bias – 如果
False
,則該層不使用偏差權重 b_ih 和 b_hh。預設值:True
dropout – 若不為零,則在除了最後一層以外的每個 LSTM 層的輸出上引入一個 Dropout 層,其 dropout 機率等於
dropout
。預設值:0python_based – 若為
True
,則會使用 LSTM cell 的完整 Python 實作。預設值:False
- 關鍵字參數:
in_key (str 或 str 的 tuple) – 模組的輸入鍵。與
in_keys
互斥使用。如果提供,則遞迴鍵將被假定為 [“recurrent_state_h”, “recurrent_state_c”],且in_key
會被附加在這些鍵的前面。in_keys (str 的 list) – 對應於輸入值、第一個和第二個隱藏鍵的三個字串。與
in_key
互斥使用。out_key (str 或 str 的 tuple) – 模組的輸出鍵。與
out_keys
互斥使用。如果提供,則遞迴鍵將被假定為 [(“next”, “recurrent_state_h”), (“next”, “recurrent_state_c”)],且out_key
會被附加在這些鍵的前面。out_keys (str 的 list) –
對應於輸出值、第一個和第二個隱藏鍵的三個字串。.. note
For a better integration with TorchRL's environments, the best naming for the output hidden key is ``("next", <custom_key>)``, such that the hidden values are passed from step to step during a rollout.
device (torch.device 或 compatible) – 模組的裝置。
lstm (torch.nn.LSTM, optional) – 要包裝的 LSTM 實例。與其他 nn.LSTM 參數互斥使用。
- 變數:
recurrent_mode – 返回模組的遞迴模式。
注意
此模組依賴於輸入 TensorDicts 中存在特定的
recurrent_state
鍵。要產生一個TensorDictPrimer
轉換,該轉換將自動將隱藏狀態添加到環境 TensorDicts 中,請使用方法make_tensordict_primer()
。如果此類別是較大模組中的子模組,則可以在父模組上呼叫方法get_primers_from_module()
,以自動產生所有子模組(包括此模組)所需的 primer 轉換。範例
>>> from torchrl.envs import TransformedEnv, InitTracker >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm_module = LSTMModule( ... input_size=env.observation_spec["observation"].shape[-1], ... hidden_size=64, ... in_keys=["observation", "rs_h", "rs_c"], ... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy(env.reset()) TensorDict( fields={ action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ rs_c: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False), rs_h: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)
- forward(tensordict: TensorDictBase)[source]¶
定義每次呼叫時執行的計算。
應由所有子類別覆寫。
注意
雖然 forward pass 的配方需要在這個函式中定義,但應該在之後呼叫
Module
實例,而不是呼叫此函式,因為前者會處理已註冊的 hooks,而後者會默默地忽略它們。
- make_tensordict_primer()[source]¶
為環境建立 tensordict primer。
一個
TensorDictPrimer
物件將確保 policy 在 rollout 執行期間能感知到補充輸入和輸出 (遞迴狀態)。這樣,資料就可以在不同程序之間共享並妥善處理。在環境中不包含
TensorDictPrimer
可能會導致行為定義不佳,例如在平行設定中,一個步驟涉及將新的遞迴狀態從"next"
複製到根 tensordict,而 meth:~torchrl.EnvBase.step_mdp 方法將無法做到這一點,因為遞迴狀態未在環境規格中註冊。請參閱
torchrl.modules.utils.get_primers_from_module()
以了解產生給定模組的所有 primers 的方法。範例
>>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP, LSTMModule >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm_module = LSTMModule( ... input_size=env.observation_spec["observation"].shape[-1], ... hidden_size=64, ... in_keys=["observation", "rs_h", "rs_c"], ... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy(env.reset()) >>> env = env.append_transform(lstm_module.make_tensordict_primer()) >>> data_collector = SyncDataCollector( ... env, ... policy, ... frames_per_batch=10 ... ) >>> for data in data_collector: ... print(data) ... break
- set_recurrent_mode(mode: bool = True)[source]¶
返回模組的新副本,該副本共享相同的 lstm 模型,但具有不同的
recurrent_mode
屬性(如果它不同)。創建副本是為了使模組可以在代碼的不同部分(推論與訓練)中具有不同的行為。
範例
>>> from torchrl.envs import TransformedEnv, InitTracker, step_mdp >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP >>> from tensordict import TensorDict >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) >>> lstm_module = LSTMModule(lstm=lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> # building two policies with different behaviors: >>> policy_inference = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy_training = Seq(lstm_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> traj_td = env.rollout(3) # some random temporal data >>> traj_td = policy_training(traj_td) >>> # let's check that both return the same results >>> td_inf = TensorDict({}, traj_td.shape[:-1]) >>> for td in traj_td.unbind(-1): ... td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation"))) ... td_inf = policy_inference(td_inf) ... td_inf = step_mdp(td_inf) ... >>> torch.testing.assert_close(td_inf["hidden0"], traj_td[..., -1]["next", "hidden0"])