OnlineDTActor¶
- class torchrl.modules.OnlineDTActor(state_dim: int, action_dim: int, transformer_config: Optional[Union[Dict, DTConfig]] = None, device: Optional[Union[device, str, int]] = None)[source]¶
線上決策轉換器 Actor 類別。
用於線上決策轉換器的 Actor 類別,用於從高斯分佈中取樣動作,如 “Online Decision Transformer” 中所呈現。
傳回用於從中取樣動作的高斯分佈的平均值和標準差。
- 參數:
state_dim (int) – 狀態維度。
action_dim (int) – 動作維度。
transformer_config (Dict 或
DecisionTransformer.DTConfig
) – GPT2 transformer 的設定。預設為default_config()
。device (torch.device, optional) – 要使用的裝置。預設為 None。
範例
>>> model = OnlineDTActor(state_dim=4, action_dim=2, ... transformer_config=OnlineDTActor.default_config()) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) >>> return_to_go = torch.randn(32, 10, 1) >>> mu, std = model(observation, action, return_to_go) >>> mu.shape torch.Size([32, 10, 2]) >>> std.shape torch.Size([32, 10, 2])
- classmethod default_config()[原始碼]¶
OnlineDTActor
的預設設定。