快捷方式

EGreedyModule

class torchrl.modules.EGreedyModule(*args, **kwargs)[source]

Epsilon-Greedy 探索模組。

此模組使用 epsilon greedy 探索策略隨機更新 tensordict 中的動作。 每次呼叫時,都會根據某個機率閾值執行隨機抽取(每個動作一個)。 如果成功,則相應的動作將被從提供的動作規格中抽取的隨機樣本替換。 其他保持不變。

參數:
  • spec (TensorSpec) – 用於抽樣動作的規格。

  • eps_init (scalar, optional) – 初始 epsilon 值。預設值:1.0

  • eps_end (scalar, optional) – 最終 epsilon 值。預設值:0.1

  • annealing_num_steps (int, optional) – epsilon 達到 eps_end 值所需的步數。預設為 1000

關鍵字參數:
  • action_key (NestedKey, optional) – 可以在輸入 tensordict 中找到動作的鍵。預設值為 "action"

  • action_mask_key (NestedKey, optional) – 可以在輸入 tensordict 中找到動作遮罩的鍵。預設值為 None (對應於沒有遮罩)。

注意

在訓練迴圈中包含對 step() 的呼叫以更新探索因子至關重要。由於不容易捕捉到這種遺漏,如果省略了它,將不會引發任何警告或異常!

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential
>>> from torchrl.modules import EGreedyModule, Actor
>>> from torchrl.data import Bounded
>>> torch.manual_seed(0)
>>> spec = Bounded(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(spec=spec, module=module)
>>> explorative_policy = TensorDictSequential(policy,  EGreedyModule(eps_init=0.2))
>>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
>>> print(explorative_policy(td).get("action"))
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9055, -0.9277, -0.6295, -0.2532],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<AddBackward0>)
forward(tensordict: TensorDictBase) TensorDictBase[source]

定義每次呼叫時執行的計算。

應由所有子類別覆寫。

注意

雖然 forward pass 的配方需要在這個函數中定義,但應該在之後呼叫 Module 實例,而不是這個函數,因為前者會處理執行註冊的 hooks,而後者會靜默地忽略它們。

step(frames: int = 1) None[source]

epsilon 衰減的一個步驟。

在呼叫此方法 self.annealing_num_steps 次之後,後續的呼叫將不會執行任何操作。

參數:

frames (int, optional) – 自上次步驟以來的影格數。預設值為 1

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源