DiscreteActionProjection¶
- class torchrl.envs.transforms.DiscreteActionProjection(num_actions_effective: int, max_actions: int, action_key: NestedKey = 'action', include_forward: bool = True)[原始碼]¶
將離散動作從高維空間投影到低維空間。
給定一個編碼為 one-hot 向量的離散動作(從 1 到 N)和一個最大動作索引 num_actions(其中 num_actions < N),轉換動作,使得 action_out 最多為 num_actions。
如果輸入動作 > num_actions,它將被替換為 0 到 num_actions-1 之間的隨機值。 否則,保留相同的動作。 這旨在與應用於具有不同動作空間的多個離散控制環境的策略一起使用。
呼叫 DiscreteActionProjection.forward(例如從重播緩衝區或在 nn.Modules 序列中)將在
"in_keys"
上呼叫轉換 num_actions_effective -> max_actions,而對 _call 的呼叫將被忽略。 實際上,轉換後的 env 被指示僅更新內部 base_env 的輸入鍵,但原始輸入鍵將保持不變。- 參數:
num_actions_effective (int) – 考慮的最大動作數。
max_actions (int) – 此模組可以讀取的最大動作數。
action_key (NestedKey, optional) – 動作的鍵名。 預設為 “action”。
include_forward (bool, optional) – 如果
True
,當模組被重播緩衝區或 nn.Module 鏈呼叫時,對 forward 的呼叫也會將動作從一個域映射到另一個域。 預設為 True。
範例
>>> torch.manual_seed(0) >>> N = 3 >>> M = 2 >>> action = torch.zeros(N, dtype=torch.long) >>> action[-1] = 1 >>> td = TensorDict({"action": action}, []) >>> transform = DiscreteActionProjection(num_actions_effective=M, max_actions=N) >>> _ = transform.inv(td) >>> print(td.get("action")) tensor([1])
- transform_input_spec(input_spec: Composite)[原始碼]¶
轉換輸入規範,使得結果規範符合轉換映射。
- 參數:
input_spec (TensorSpec) – 轉換前的規範
- 傳回:
轉換後預期的規範