快捷鍵

DdpgCnnQNet

class torchrl.modules.DdpgCnnQNet(conv_net_kwargs: Optional[dict] = None, mlp_net_kwargs: Optional[dict] = None, use_avg_pooling: bool = True, device: Optional[Union[device, str, int]] = None)[原始碼]

DDPG 卷積 Q 值類別。

在「使用深度強化學習的連續控制」中介紹,https://arxiv.org/pdf/1509.02971.pdf

DDPG Q 值網路將觀察和動作作為輸入,並從中傳回一個純量。

參數:
  • conv_net_kwargs (dict, optional) –

    卷積網路的 kwargs。預設為

    >>> {
    ...     'in_features': None,
    ...     "num_cells": [32, 64, 128],
    ...     "kernel_sizes": [8, 4, 3],
    ...     "strides": [4, 2, 1],
    ...     "paddings": [0, 0, 1],
    ...     'activation_class': nn.ELU,
    ...     'norm_class': None,
    ...     'aggregator_class': nn.AdaptiveAvgPool2d,
    ...     'aggregator_kwargs': {},
    ...     'squeeze_output': True,
    ... }
    

  • mlp_net_kwargs (dict, optional) –

    MLP 的 kwargs。預設為

    >>> {
    ...     'in_features': None,
    ...     'out_features': 1,
    ...     'depth': 2,
    ...     'num_cells': 200,
    ...     'activation_class': nn.ELU,
    ...     'bias_last_layer': True,
    ... }
    

  • use_avg_pooling (bool, optional) – 如果 True,則會使用 AvgPooling 層來聚合輸出。預設值為 True

  • device (torch.device, optional) – 在其上建立模組的裝置。

範例

>>> from torchrl.modules import DdpgCnnQNet
>>> import torch
>>> net = DdpgCnnQNet()
>>> print(net)
DdpgCnnQNet(
  (convnet): ConvNet(
    (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ELU(alpha=1.0)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ELU(alpha=1.0)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ELU(alpha=1.0)
    (6): AdaptiveAvgPool2d(output_size=(1, 1))
    (7): Squeeze2dLayer()
  )
  (mlp): MLP(
    (0): LazyLinear(in_features=0, out_features=200, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=200, out_features=200, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=200, out_features=1, bias=True)
  )
)
>>> obs = torch.zeros(1, 3, 64, 64)
>>> action = torch.zeros(1, 4)
>>> value = net(obs, action)
>>> print(value.shape)
torch.Size([1, 1])
forward(observation: Tensor, action: Tensor) Tensor[source]

定義每次呼叫時執行的計算。

應該被所有子類別覆寫。

注意

雖然 forward pass 的方法需要在此函數中定義,但之後應該呼叫 Module 實例,而不是這個函數,因為前者會處理已註冊的 hooks 的執行,而後者會默默地忽略它們。

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得您的問題解答

檢視資源