ConvNet
- class torchrl.modules.ConvNet(in_features: Optional[int] = None, depth: Optional[int] = None, num_cells: Optional[Union[Sequence[int], int]] = None, kernel_sizes: Union[Sequence[int], int] = 3, strides: Union[Sequence[int], int] = 1, paddings: Union[Sequence[int], int] = 0, activation_class: Union[Type[nn.Module], Callable] = <class 'torch.nn.modules.activation.ELU'>, activation_kwargs: Optional[Union[dict, List[dict]]] = None, norm_class: Optional[Union[Type[nn.Module], Callable]] = None, norm_kwargs: Optional[Union[dict, List[dict]]] = None, bias_last_layer: bool = True, aggregator_class: Optional[Union[Type[nn.Module], Callable]] = <class 'torchrl.modules.models.utils.SquashDims'>, aggregator_kwargs: Optional[dict] = None, squeeze_output: bool = False, device: Optional[Union[torch.device, str, int]] = None)[source]
A convolutional neural network.
- Parameters:
  - in_features (int, optional) – number of input features. If None, a LazyConv2d module is used for the first layer.
  - depth (int, optional) – depth of the network. A depth of 1 produces a network with a single convolutional layer with the desired input features, and with an output size equal to the last element of the num_cells argument. If no depth is specified, the depth information must be contained in the num_cells argument (see below). If num_cells is an iterable and depth is specified, both must match: len(num_cells) must equal depth.
  - num_cells (int or sequence of int, optional) – number of cells in each layer between the input and the output. If an integer is provided, every layer has the same number of cells. If an iterable is provided, the output channels of the convolutional layers match its content. Defaults to [32, 32, 32].
  - kernel_sizes (int or sequence of int, optional) – kernel size(s) of the convolutional network. If an iterable, its length must match the depth, as defined by the num_cells or depth arguments. Defaults to 3.
  - strides (int or sequence of int, optional) – stride(s) of the convolutional network. If an iterable, its length must match the depth, as defined by the num_cells or depth arguments. Defaults to 1.
  - paddings (int or sequence of int, optional) – padding(s) of the convolutional network. If an iterable, its length must match the depth, as defined by the num_cells or depth arguments. Defaults to 0.
  - activation_class (Type[nn.Module] or callable, optional) – activation class or constructor to use. Defaults to ELU.
  - activation_kwargs (dict or list of dicts, optional) – kwargs to use with the activation class. A list of kwargs of length depth may also be passed, with one element per layer.
  - norm_class (Type or callable, optional) – normalization class or constructor, if any.
  - norm_kwargs (dict or list of dicts, optional) – kwargs to use with the normalization layers. A list of kwargs of length depth may also be passed, with one element per layer.
  - bias_last_layer (bool) – if True, the last layer will have a bias parameter. Defaults to True.
  - aggregator_class (Type[nn.Module] or callable) – aggregator class or constructor to use at the end of the chain. Defaults to torchrl.modules.models.utils.SquashDims.
  - aggregator_kwargs (dict, optional) – kwargs for the aggregator_class.
  - squeeze_output (bool) – whether singleton dimensions of the output should be squeezed. Defaults to False.
  - device (torch.device, optional) – device on which to create the module.
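The kernel_sizes, strides, and paddings arguments jointly determine the spatial size that reaches the aggregator, and hence the flattened feature count after SquashDims. As a minimal sketch (plain Python, using the standard Conv2d output-size formula with dilation 1; the 64x64 input size is an arbitrary example):

```python
from math import floor

def conv2d_out_hw(h, w, kernel, stride=1, padding=0):
    """Spatial size after one Conv2d layer (dilation 1): floor((size + 2p - k) / s) + 1."""
    def one(size, k, s, p):
        return floor((size + 2 * p - k) / s) + 1
    return one(h, kernel, stride, padding), one(w, kernel, stride, padding)

# Default-style stack: num_cells=[32, 32, 32], kernel 3, stride 1, padding 0,
# applied to a hypothetical 64x64 input image.
h, w = 64, 64
for _ in [32, 32, 32]:            # one shape update per convolutional layer
    h, w = conv2d_out_hw(h, w, kernel=3, stride=1, padding=0)

# After SquashDims flattens C x H x W, the feature count is channels * h * w.
print((h, w), 32 * h * w)         # (58, 58) 107648
```

Each stride-1, padding-0, kernel-3 layer trims 2 pixels from each spatial dimension, which is why the default three-layer stack turns 64x64 into 58x58.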
Examples
>>> # All of the following examples provide valid, working ConvNets
>>> cnet = ConvNet(in_features=3, depth=1, num_cells=[32,])  # a ConvNet made of a single convolutional layer
>>> print(cnet)
ConvNet(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ELU(alpha=1.0)
  (2): SquashDims()
)
>>> cnet = ConvNet(in_features=3, depth=4, num_cells=32)
>>> print(cnet)
ConvNet(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ELU(alpha=1.0)
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ELU(alpha=1.0)
  (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (5): ELU(alpha=1.0)
  (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (7): ELU(alpha=1.0)
  (8): SquashDims()
)
>>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35])  # defines the depth by the num_cells arg
>>> print(cnet)
ConvNet(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ELU(alpha=1.0)
  (2): Conv2d(32, 33, kernel_size=(3, 3), stride=(1, 1))
  (3): ELU(alpha=1.0)
  (4): Conv2d(33, 34, kernel_size=(3, 3), stride=(1, 1))
  (5): ELU(alpha=1.0)
  (6): Conv2d(34, 35, kernel_size=(3, 3), stride=(1, 1))
  (7): ELU(alpha=1.0)
  (8): SquashDims()
)
>>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35], kernel_sizes=[3, 4, 5, (2, 3)])  # defines kernels, possibly rectangular
>>> print(cnet)
ConvNet(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ELU(alpha=1.0)
  (2): Conv2d(32, 33, kernel_size=(4, 4), stride=(1, 1))
  (3): ELU(alpha=1.0)
  (4): Conv2d(33, 34, kernel_size=(5, 5), stride=(1, 1))
  (5): ELU(alpha=1.0)
  (6): Conv2d(34, 35, kernel_size=(2, 3), stride=(1, 1))
  (7): ELU(alpha=1.0)
  (8): SquashDims()
)
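To make the layer layout concrete, a rough hand-built equivalent of the default configuration can be assembled in plain torch (a sketch, assuming nn.Flatten stands in for SquashDims, which collapses the trailing C x H x W dimensions into one feature dimension):

```python
import torch
from torch import nn

# Hand-built stack mirroring ConvNet(in_features=3, num_cells=[32, 32, 32]) with
# the default ELU activation, kernel 3, stride 1, padding 0.
net = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=1),
    nn.ELU(alpha=1.0),
    nn.Conv2d(32, 32, kernel_size=3, stride=1),
    nn.ELU(alpha=1.0),
    nn.Conv2d(32, 32, kernel_size=3, stride=1),
    nn.ELU(alpha=1.0),
    nn.Flatten(-3, -1),   # flatten C x H x W, playing the role of SquashDims
)

x = torch.randn(4, 3, 64, 64)  # batch of 4 three-channel 64x64 images
print(net(x).shape)            # torch.Size([4, 107648]), i.e. 32 * 58 * 58
```

Each stride-1 kernel-3 convolution shrinks the spatial size by 2, so 64x64 becomes 58x58 after three layers, and flattening yields 32 * 58 * 58 = 107648 features per sample.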