快速鍵

DTActor

class torchrl.modules.DTActor(state_dim: int, action_dim: int, transformer_config: Optional[Union[Dict, DTConfig]] = None, device: Optional[Union[device, str, int]] = None)[原始碼]

決策轉換器 Actor 類別。

決策轉換器的 Actor 類別,用於輸出確定性動作,如 “Decision Transformer” <https://arxiv.org/abs/2202.05607.pdf> 中所示。傳回確定性動作。

參數:
  • state_dim (int) – 狀態維度。

  • action_dim (int) – 動作維度。

  • transformer_config (字典或 DecisionTransformer.DTConfig, optional) – GPT2 transformer 的配置。預設為 default_config()

  • device (torch.device, optional) – 要使用的裝置。預設為 None。

範例

>>> model = DTActor(state_dim=4, action_dim=2,
...     transformer_config=DTActor.default_config())
>>> observation = torch.randn(32, 10, 4)
>>> action = torch.randn(32, 10, 2)
>>> return_to_go = torch.randn(32, 10, 1)
>>> output = model(observation, action, return_to_go)
>>> output.shape
torch.Size([32, 10, 2])
classmethod default_config()[來源]

DTActor 的預設配置。

forward(observation: Tensor, action: Tensor, return_to_go: Tensor) Tensor[來源]

定義每次呼叫時執行的計算。

應被所有子類別覆寫。

注意

雖然 forward pass 的方法需要在這個函數中定義,但應該在之後呼叫 Module 實例,而不是這個函數,因為前者會處理已註冊的 hooks 的執行,而後者會靜默地忽略它們。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並取得您的問題解答

檢視資源