快速鍵

BurnInTransform

class torchrl.envs.transforms.BurnInTransform(modules: Sequence[TensorDictModuleBase], burn_in: int, out_keys: Optional[Sequence[NestedKey]] = None)[source]

轉換為部分燃燒資料序列。

當遞迴狀態不可用時,此轉換對於取得最新的遞迴狀態很有用。 它會沿著時間維度從採樣的循序資料片段中燃燒掉一些步驟,並將剩餘的資料序列傳回,並將燃燒的資料置於其初始時間步長中。 此轉換旨在用作重播緩衝區轉換,而不是環境轉換。

參數:
  • modules (TensorDictModule 的序列) – 用於燃燒資料序列的模組清單。

  • burn_in (int) – 要燃燒的時間步長數。

  • out_keys (NestedKey 的序列, 選擇性) – 目的地鍵。 預設為

  • ` (指向下一個時間步長的所有模組 out_keys (例如,如果) –

  • ("next"

  • module). ("hidden")` 屬於 a 的 out_keys 的一部分) –

注意

此轉換預期輸入的 TensorDicts 的最後一個維度是時間維度。 它也假設所有提供的模組都可以處理循序資料。

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import BurnInTransform
>>> from torchrl.modules import GRUModule
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     in_keys=["observation", "hidden"],
...     out_keys=["intermediate", ("next", "hidden")],
... ).set_recurrent_mode(True)
>>> burn_in_transform = BurnInTransform(
...     modules=[gru_module],
...     burn_in=5,
... )
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...      "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...      "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> td = burn_in_transform(td)
>>> td.shape
torch.Size([2, 5])
>>> td.get("hidden").abs().sum()
tensor(86.3008)
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> buffer = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(2),
...     batch_size=1,
... )
>>> buffer.append_transform(burn_in_transform)
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...      "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...      "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> buffer.extend(td)
>>> td = buffer.sample(1)
>>> td.shape
torch.Size([1, 5])
>>> td.get("hidden").abs().sum()
tensor(37.0344)
forward(tensordict: TensorDictBase) TensorDictBase[source]

讀取輸入的 tensordict,並針對選定的鍵,套用轉換。

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和進階開發者的深度教學

查看教學

資源

尋找開發資源並取得您的問題解答

查看資源