快捷鍵

MultiAgentConvNet

class torchrl.modules.MultiAgentConvNet(n_agents: int, centralized: ~typing.Optional[bool] = None, share_params: ~typing.Optional[bool] = None, *, in_features: ~typing.Optional[int] = None, device: ~typing.Optional[~typing.Union[~torch.device, str, int]] = None, num_cells: ~typing.Optional[~typing.Sequence[int]] = None, kernel_sizes: ~typing.Union[~typing.Sequence[~typing.Union[int, ~typing.Sequence[int]]], int] = 5, strides: ~typing.Union[~typing.Sequence, int] = 2, paddings: ~typing.Union[~typing.Sequence, int] = 0, activation_class: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.ELU'>, use_td_params: bool = True, **kwargs)[原始碼]

多代理 CNN。

在 MARL 設定中,代理人可能會或可能不會共享相同的動作策略:我們說參數可以共享或不共享。類似地,網路可以採用整個觀察空間(跨代理人),也可以基於每個代理人來計算其輸出,我們分別將其稱為「集中式」和「非集中式」。

它期望形狀為 (*B, n_agents, channels, x, y) 的輸入。

注意

若要使用 torch.nn.init 模組初始化 MARL 模組參數,請參閱 get_stateful_net()from_stateful_net() 方法。

參數:
  • n_agents (int) – 代理人數量。

  • centralized (bool) – 若為 True,則每個 agent 將使用所有 agent 的輸入來計算其輸出,導致輸入的形狀為 (*B, n_agents * channels, x, y)。 否則,每個 agent 將僅使用其自身的資料作為輸入。

  • share_params (bool) – 若為 True,則將使用相同的 ConvNet 來為所有 agent 進行前向傳遞(同質策略)。 否則,每個 agent 將使用不同的 ConvNet 來處理其輸入(異質策略)。

關鍵字參數:
  • in_features (int, optional) – 輸入特徵維度。 如果保留為 None,則使用 lazy module。

  • device (str or torch.device, optional) – 在其上建立 module 的裝置。

  • num_cells (int or Sequence[int], optional) – 輸入和輸出之間每層的 cell 數量。 如果提供整數,則每一層將具有相同數量的 cell。 如果提供可迭代物件,則線性層的 out_features 將與 num_cells 的內容相符。

  • kernel_sizes (int, Sequence[Union[int, Sequence[int]]]) – 卷積網路的 Kernel size。 預設為 5

  • strides (int or Sequence[int]) – 卷積網路的步幅。 如果是可迭代物件,則長度必須與深度匹配,該深度由 num_cells 或 depth 參數定義。 預設為 2

  • activation_class (Type[nn.Module]) – 要使用的 activation 類別。 預設為 torch.nn.ELU

  • use_td_params (bool, optional) – 如果 True,則可以在 self.params 中找到參數,它是 TensorDictParams 物件(繼承自 TensorDictnn.Module)。 如果 False,則參數包含在 self._empty_net 中。 總而言之,這兩種方法應該大致相同但不可以互換:例如,使用 use_td_params=True 建立的 state_dictuse_td_params=False 時無法使用。

  • **kwargs – 用於 ConvNet,可用於自訂 ConvNet。

範例

>>> import torch
>>> from torchrl.modules import MultiAgentConvNet
>>> batch = (3,2)
>>> n_agents = 7
>>> channels, x, y = 3, 100, 100
>>> obs = torch.randn(*batch, n_agents, channels, x, y)
>>> # Let's consider a centralized network with shared parameters.
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized = True,
...     share_params = True
... )
>>> print(cnn)
MultiAgentConvNet(
    (agent_networks): ModuleList(
        (0): ConvNet(
        (0): LazyConv2d(0, 32, kernel_size=(5, 5), stride=(2, 2))
        (1): ELU(alpha=1.0)
        (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (3): ELU(alpha=1.0)
        (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (5): ELU(alpha=1.0)
        (6): SquashDims()
        )
    )
)
>>> result = cnn(obs)
>>> # The final dimension of the resulting tensor would be determined based on the layer definition arguments and the shape of input 'obs'.
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> # Since both observations and parameters are shared, we expect all agents to have identical outputs (eg. for a value function)
>>> print(all(result[0,0,0] == result[0,0,1]))
True
>>> # Alternatively, a local network with parameter sharing (eg. decentralized weight sharing policy)
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized = False,
...     share_params = True
... )
>>> print(cnn)
MultiAgentConvNet(
    (agent_networks): ModuleList(
        (0): ConvNet(
        (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2))
        (1): ELU(alpha=1.0)
        (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (3): ELU(alpha=1.0)
        (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (5): ELU(alpha=1.0)
        (6): SquashDims()
        )
    )
)
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> # Parameters are shared but not observations, hence each agent has a different output.
>>> print(all(result[0,0,0] == result[0,0,1]))
False
>>> # Or multiple local networks identical in structure but with differing weights.
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized = False,
...     share_params = False
... )
>>> print(cnn)
MultiAgentConvNet(
    (agent_networks): ModuleList(
        (0-6): 7 x ConvNet(
        (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2))
        (1): ELU(alpha=1.0)
        (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (3): ELU(alpha=1.0)
        (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (5): ELU(alpha=1.0)
        (6): SquashDims()
        )
    )
)
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> print(all(result[0,0,0] == result[0,0,1]))
False
>>> # Or where inputs are shared but not parameters.
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized = True,
...     share_params = False
... )
>>> print(cnn)
MultiAgentConvNet(
    (agent_networks): ModuleList(
        (0-6): 7 x ConvNet(
        (0): Conv2d(28, 32, kernel_size=(5, 5), stride=(2, 2))
        (1): ELU(alpha=1.0)
        (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (3): ELU(alpha=1.0)
        (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (5): ELU(alpha=1.0)
        (6): SquashDims()
        )
    )
)
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> print(all(result[0,0,0] == result[0,0,1]))
False

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源