捷徑

GRUModule

class torchrl.modules.GRUModule(*args, **kwargs)[來源]

GRU 模組的嵌入器。

此類別將以下功能新增至 torch.nn.GRU

  • 與 TensorDict 的相容性:隱藏狀態會被重新塑形以匹配 tensordict 批次大小。

  • 可選的多步執行:使用 torch.nn,必須在 torch.nn.GRUCelltorch.nn.GRU 之間進行選擇,前者與單步輸入相容,而後者與多步輸入相容。 此類別可同時實現這兩種用法。

建構後,模組不會設定為遞迴模式,即它會預期單步輸入。

如果在遞迴模式下,預期 tensordict 的最後一個維度標記步數。 對 tensordict 的維度沒有限制(除非對於時間輸入,它必須大於 1)。

參數:
  • input_size – 輸入 x 中預期的特徵數量

  • hidden_size – 隱藏狀態 h 中的特徵數量

  • num_layers – 遞迴層數。 例如,設定 num_layers=2 表示將兩個 GRU 堆疊在一起以形成一個 堆疊 GRU,第二個 GRU 接收第一個 GRU 的輸出並計算最終結果。 預設值:1

  • bias – 如果 False,則該層不使用偏差權重。 預設值:True

  • dropout – 如果非零,則在每個 GRU 層的輸出(除了最後一層)上引入一個 Dropout 層,dropout 概率等於 dropout。 預設值:0

  • python_based – 如果 True,將使用 GRU cell 的完整 Python 實作。 預設值:False

關鍵字參數:
  • in_key (strstr 的 tuple) – 模組的輸入鍵。與 in_keys 互斥。如果提供,則遞迴鍵假定為 [“recurrent_state”],並且 in_key 將附加在此之前。

  • in_keys (str 的 list) – 一對字串,分別對應於輸入值和遞迴條目。與 in_key 互斥。

  • out_key (strstr 的 tuple) – 模組的輸出鍵。與 out_keys 互斥。如果提供,則遞迴鍵假定為 [(“recurrent_state”)],並且 out_key 將附加在這些之前。

  • out_keys (str 的 list) –

    一對字串,分別對應於輸出值、第一個和第二個隱藏鍵。 .. note

    For a better integration with TorchRL's environments, the best naming
    for the output hidden key is ``("next", <custom_key>)``, such
    that the hidden values are passed from step to step during a rollout.
    

  • device (torch.devicecompatible) – 模組的裝置。

  • gru (torch.nn.GRU, optional) – 要包裝的 GRU 實例。與其他 nn.GRU 參數互斥。

變數:

recurrent_mode – 返回模組的遞迴模式。

set_recurrent_mode()[source]

控制模組是否應在遞迴模式下執行。

make_tensordict_primer()[source]

為環境創建 TensorDictPrimer 轉換,以使環境能夠感知 RNN 的遞迴狀態。

注意

此模組依賴於輸入 TensorDicts 中存在的特定 recurrent_state 鍵。要生成一個 TensorDictPrimer 轉換,該轉換將自動將隱藏狀態添加到環境 TensorDicts,請使用方法 make_tensordict_primer()。如果此類是較大模組中的子模組,則可以對父模組調用方法 get_primers_from_module() 以自動生成所有子模組(包括此模組)所需的 primer 轉換。

範例

>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
>>> gru_module_training = gru_module.set_recurrent_mode()
>>> policy_training = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3) # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> print(traj_td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([3, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                rs: Tensor(shape=torch.Size([3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
forward(tensordict: TensorDictBase)[source]

定義每次呼叫時執行的計算。

應由所有子類別覆寫。

注意

儘管 forward pass 的配方需要在這個函式中定義,但之後應該呼叫 Module 實例,而不是這個,因為前者負責運行已註冊的 hooks,而後者會默默地忽略它們。

make_tensordict_primer()[source]

為環境建立 tensordict primer。

一個 TensorDictPrimer 物件將確保策略在 rollout 執行期間了解補充輸入和輸出(遞迴狀態)。這樣,數據可以在各個進程之間共享並得到妥善處理。

在環境中不包含 TensorDictPrimer 可能會導致定義不佳的行為,例如在並行設置中,步驟涉及將新的遞迴狀態從 "next" 複製到根 tensordict,meth:~torchrl.EnvBase.step_mdp 方法將無法執行,因為遞迴狀態未在環境規格中註冊。

請參閱 torchrl.modules.utils.get_primers_from_module() 了解為給定模組生成所有 primers 的方法。

範例

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, LSTMModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(gru_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=10
... )
>>> for data in data_collector:
...     print(data)
...     break
set_recurrent_mode(mode: bool = True)[source]

返回模組的新副本,該副本共享相同的 gru 模型,但具有不同的 recurrent_mode 屬性(如果不同)。

創建一個副本,以便可以在程式碼的不同部分(推論與訓練)中使用具有不同行為的模組

範例

>>> from torchrl.envs import GymEnv, TransformedEnv, InitTracker, step_mdp
>>> from torchrl.modules import MLP
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru = nn.GRU(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> gru_module = GRUModule(gru=gru, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> # building two policies with different behaviors:
>>> policy_inference = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy_training = Seq(gru_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3) # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> # let's check that both return the same results
>>> td_inf = TensorDict({}, traj_td.shape[:-1])
>>> for td in traj_td.unbind(-1):
...     td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation")))
...     td_inf = policy_inference(td_inf)
...     td_inf = step_mdp(td_inf)
...
>>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"])

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學課程

取得初學者和高級開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並取得您問題的解答

檢視資源