BurnInTransform¶
- class torchrl.envs.transforms.BurnInTransform(modules: Sequence[TensorDictModuleBase], burn_in: int, out_keys: Optional[Sequence[NestedKey]] = None)[source]¶
轉換為部分燃燒資料序列。
當遞迴狀態不可用時,此轉換對於取得最新的遞迴狀態很有用。 它會沿著時間維度從採樣的循序資料片段中燃燒掉一些步驟,並將剩餘的資料序列傳回,並將燃燒的資料置於其初始時間步長中。 此轉換旨在用作重播緩衝區轉換,而不是環境轉換。
- 參數:
modules (TensorDictModule 的序列) – 用於燃燒資料序列的模組清單。
burn_in (int) – 要燃燒的時間步長數。
out_keys (NestedKey 的序列, 選擇性) – 目的地鍵。 預設為
` (指向下一個時間步長的所有模組 out_keys (例如,如果) –
("next" –
module). ("hidden")` 屬於 a 的 out_keys 的一部分) –
注意
此轉換預期輸入的 TensorDicts 的最後一個維度是時間維度。 它也假設所有提供的模組都可以處理循序資料。
範例
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.envs.transforms import BurnInTransform >>> from torchrl.modules import GRUModule >>> gru_module = GRUModule( ... input_size=10, ... hidden_size=10, ... in_keys=["observation", "hidden"], ... out_keys=["intermediate", ("next", "hidden")], ... ).set_recurrent_mode(True) >>> burn_in_transform = BurnInTransform( ... modules=[gru_module], ... burn_in=5, ... ) >>> td = TensorDict({ ... "observation": torch.randn(2, 10, 10), ... "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10), ... "is_init": torch.zeros(2, 10, 1), ... }, batch_size=[2, 10]) >>> td = burn_in_transform(td) >>> td.shape torch.Size([2, 5]) >>> td.get("hidden").abs().sum() tensor(86.3008)
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer >>> buffer = TensorDictReplayBuffer( ... storage=LazyMemmapStorage(2), ... batch_size=1, ... ) >>> buffer.append_transform(burn_in_transform) >>> td = TensorDict({ ... "observation": torch.randn(2, 10, 10), ... "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10), ... "is_init": torch.zeros(2, 10, 1), ... }, batch_size=[2, 10]) >>> buffer.extend(td) >>> td = buffer.sample(1) >>> td.shape torch.Size([1, 5]) >>> td.get("hidden").abs().sum() tensor(37.0344)