Reward2GoTransform¶
class torchrl.envs.transforms.Reward2GoTransform(gamma: Optional[Union[float, Tensor]] = 1.0, in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, done_key: Optional[NestedKey] = 'done')[source]¶
Calculates the reward to go based on the episode reward and a discount factor, as sketched after the parameter list below.

Since Reward2GoTransform is only an inverse transform, its in_keys will be directly used for in_keys_inv. The reward-to-go can only be computed once the episode is complete; therefore, this transform should be applied to a replay buffer rather than to a collector or an environment.

Parameters:
- gamma (float or torch.Tensor) – the discount factor. Defaults to 1.0.
- in_keys (sequence of NestedKey) – the sequence of entries to rename. Defaults to ("next", "reward") if none is provided.
- out_keys (sequence of NestedKey) – the sequence of entries to rename. Defaults to the values of in_keys if none is provided.
- done_key (NestedKey) – the done entry. Defaults to "done".
- truncated_key (NestedKey) – the truncated entry. Defaults to "truncated". If no truncated entry is found, only the "done" one will be used.
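The values written to out_keys follow the standard discounted reward-to-go definition (a sketch of the underlying math, not a quote of the implementation): for episode rewards r_t, ..., r_T ending at step T and discount factor gamma,

    R_t = \sum_{k=t}^{T} \gamma^{k-t}\, r_k = r_t + \gamma R_{t+1}, \qquad R_{T+1} = 0.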
Examples:

>>> # Using this transform as part of a replay buffer
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torchrl.envs.transforms import Reward2GoTransform
>>> torch.manual_seed(0)
>>> r2g = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"])
>>> rb = ReplayBuffer(storage=LazyTensorStorage(100), transform=r2g)
>>> batch, timesteps = 4, 5
>>> done = torch.zeros(batch, timesteps, 1, dtype=torch.bool)
>>> for i in range(batch):
...     while not done[i].any():
...         done[i] = done[i].bernoulli_(0.1)
>>> reward = torch.ones(batch, timesteps, 1)
>>> td = TensorDict(
...     {"next": {"done": done, "reward": reward}},
...     [batch, timesteps],
... )
>>> rb.extend(td)
>>> sample = rb.sample(1)
>>> print(sample["next", "reward"])
tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.]]])
>>> print(sample["reward_to_go"])
tensor([[[4.9010],
         [3.9404],
         [2.9701],
         [1.9900],
         [1.0000]]])
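As a sanity check, the reward-to-go values printed above can be reproduced by hand with the backward recursion R_t = r_t + gamma * R_{t+1} (a minimal sketch; gamma, rtg and running are illustrative local names, not part of the transform's API):

>>> gamma = 0.99
>>> rtg = torch.zeros(5)
>>> running = 0.0
>>> for t in reversed(range(5)):
...     # every reward in the example above is 1., so r_t = 1.0
...     running = 1.0 + gamma * running
...     rtg[t] = running
>>> print(rtg)
tensor([4.9010, 3.9404, 2.9701, 1.9900, 1.0000])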
This transform can also be used directly with a collector: make sure to append the inv method of the transform.
Examples:

>>> from torchrl.envs.utils import RandomPolicy
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs.libs.gym import GymEnv
>>> t = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"])
>>> env = GymEnv("Pendulum-v1")
>>> collector = SyncDataCollector(
...     env,
...     RandomPolicy(env.action_spec),
...     frames_per_batch=200,
...     total_frames=-1,
...     postproc=t.inv
... )
>>> for data in collector:
...     break
>>> print(data)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        reward_to_go: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)
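Passing postproc=t.inv means the inverse call is applied to every collected batch; the same result could be obtained by applying it manually (a sketch, assuming a batch data with the same layout as above):

>>> # hypothetical manual application of the inverse transform to a collected batch
>>> data_with_rtg = t.inv(data)
>>> data_with_rtg["reward_to_go"].shape
torch.Size([200, 1])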
Using this transform as part of an environment will raise an exception:
Examples:

>>> t = Reward2GoTransform(gamma=0.99)
>>> TransformedEnv(GymEnv("Pendulum-v1"), t)  # crashes
Note

In settings where multiple done entries are present, a separate Reward2GoTransform should be created for each done-reward pair, as sketched below.
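For instance, in a multi-agent buffer where each agent has its own reward and done entries, one transform per pair can be composed (a sketch; the nested agent0/agent1 keys are hypothetical and the exact key layout depends on the environment):

>>> from torchrl.envs.transforms import Compose, Reward2GoTransform
>>> # one Reward2GoTransform per done-reward pair (keys below are illustrative)
>>> r2g_agent0 = Reward2GoTransform(
...     gamma=0.99,
...     in_keys=[("next", "agent0", "reward")],
...     out_keys=[("agent0", "reward_to_go")],
...     done_key=("agent0", "done"),
... )
>>> r2g_agent1 = Reward2GoTransform(
...     gamma=0.99,
...     in_keys=[("next", "agent1", "reward")],
...     out_keys=[("agent1", "reward_to_go")],
...     done_key=("agent1", "done"),
... )
>>> rb = ReplayBuffer(
...     storage=LazyTensorStorage(100),
...     transform=Compose(r2g_agent0, r2g_agent1),
... )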