DuelingCnnDQNet
- class torchrl.modules.DuelingCnnDQNet(out_features: int, out_features_value: int = 1, cnn_kwargs: Optional[dict] = None, mlp_kwargs: Optional[dict] = None, device: Optional[Union[device, str, int]] = None)
Dueling CNN Q-network.
Presented in https://arxiv.org/abs/1511.06581
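The dueling architecture from the paper above splits the Q-function into a state-value stream V(s) and an advantage stream A(s, a), then recombines them as Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a'). The sketch below shows only that aggregation step on plain Python lists, for a single state with a scalar value (the `out_features_value=1` case). The helper `dueling_q_values` is hypothetical and not part of torchrl, and it assumes the mean-subtracted variant of the aggregation rather than the max-subtracted one.

```python
from typing import List


def dueling_q_values(value: float, advantages: List[float]) -> List[float]:
    """Combine a scalar state value with per-action advantages.

    Hypothetical helper illustrating the mean-subtracted dueling
    aggregation: Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a').
    """
    if not advantages:
        raise ValueError("need at least one action advantage")
    mean_adv = sum(advantages) / len(advantages)
    # Centering the advantages makes V and A identifiable: the resulting
    # Q-values average exactly to V(s).
    return [value + a - mean_adv for a in advantages]


if __name__ == "__main__":
    q = dueling_q_values(2.0, [1.0, 3.0, 2.0])
    print(q)  # [1.0, 3.0, 2.0]
```

Subtracting the mean shifts every action by the same amount, so the greedy action (the argmax over Q) is the same as the argmax over the raw advantages.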
- Parameters:
out_features (int) – number of features for the advantage network.
out_features_value (int) – number of features for the value network.
cnn_kwargs (dict or list of dicts, optional) –
kwargs for the feature network. Default is
>>> cnn_kwargs = {
...     'num_cells': [32, 64, 64],
...     'strides': [4, 2, 1],
...     'kernel_sizes': [8, 4, 3],
... }
mlp_kwargs (dict or list of dicts, optional) –
kwargs for the advantage and value networks. Default is
>>> mlp_kwargs = {
...     "depth": 1,
...     "activation_class": nn.ELU,
...     "num_cells": 512,
...     "bias_last_layer": True,
... }
device (torch.device, optional) – device on which to create the module.
Examples
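With the default `cnn_kwargs` listed above, and assuming no padding and unit dilation (consistent with the convolution layers printed in the example), each convolution maps a spatial size `n` to `(n - k) // s + 1`. The sketch below traces an input through the three layers to get the length of the flattened feature vector whose size the lazy linear layers in the heads infer on the first forward pass. The helpers `conv_out_size` and `flattened_features` are hypothetical, written only for this arithmetic.

```python
from typing import Sequence


def conv_out_size(n: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of one spatial dimension of a Conv2d (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def flattened_features(
    height: int,
    width: int,
    num_cells: Sequence[int] = (32, 64, 64),
    kernels: Sequence[int] = (8, 4, 3),
    strides: Sequence[int] = (4, 2, 1),
) -> int:
    """Flattened size after a stack of unpadded square convolutions."""
    for k, s in zip(kernels, strides):
        height = conv_out_size(height, k, s)
        width = conv_out_size(width, k, s)
    # Channels of the last conv layer times the remaining spatial area.
    return num_cells[-1] * height * width


if __name__ == "__main__":
    print(flattened_features(64, 64))  # 1024: 64 channels x 4 x 4
    print(flattened_features(84, 84))  # 3136: 64 channels x 7 x 7
```

Because the heads start with lazy layers, the same network accepts either input size; only the inferred `in_features` of the first head layer changes.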
>>> import torch
>>> from torchrl.modules import DuelingCnnDQNet
>>> net = DuelingCnnDQNet(out_features=20)
>>> print(net)
DuelingCnnDQNet(
  (features): ConvNet(
    (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ELU(alpha=1.0)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ELU(alpha=1.0)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ELU(alpha=1.0)
    (6): SquashDims()
  )
  (advantage): MLP(
    (0): LazyLinear(in_features=0, out_features=512, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=512, out_features=20, bias=True)
  )
  (value): MLP(
    (0): LazyLinear(in_features=0, out_features=512, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=512, out_features=1, bias=True)
  )
)
>>> x = torch.zeros(1, 3, 64, 64)
>>> y = net(x)
>>> print(y.shape)
torch.Size([1, 20])
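The printed heads can be sized by hand. With `depth=1` and `num_cells=512`, each head is one 512-unit hidden layer followed by an output layer, and for a 64x64 input the default convolutions leave 64 channels x 4 x 4 = 1024 features for the lazy layers to infer. The sketch below counts the trainable parameters of both heads for `out_features=20` and `out_features_value=1`. The helpers `linear_params` and `mlp_head_params` are hypothetical; the count assumes every linear layer has a bias, as the printed modules show, and that ELU has no parameters.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of weights (plus biases) in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


def mlp_head_params(in_features: int, hidden: int, out_features: int) -> int:
    """Parameters of a depth-1 MLP: one hidden layer, then the output layer.

    Both layers carry a bias here, matching bias_last_layer=True; the
    ELU activation in between has no learnable parameters.
    """
    return linear_params(in_features, hidden) + linear_params(hidden, out_features)


if __name__ == "__main__":
    advantage = mlp_head_params(1024, 512, 20)  # 524800 + 10260
    value = mlp_head_params(1024, 512, 1)       # 524800 + 513
    print(advantage, value)  # 535060 525313
```

Almost all head parameters sit in the first (lazy) layer, so the input resolution, through the flattened feature size, dominates the model size.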