快捷方式

MultiAgentMLP

class torchrl.modules.MultiAgentMLP(n_agent_inputs: int | None, n_agent_outputs: int, n_agents: int, *, centralized: ~typing.Optional[bool] = None, share_params: ~typing.Optional[bool] = None, device: ~typing.Optional[~typing.Union[~torch.device, str, int]] = None, depth: ~typing.Optional[int] = None, num_cells: ~typing.Optional[~typing.Union[~typing.Sequence, int]] = None, activation_class: ~typing.Optional[~typing.Type[~torch.nn.modules.module.Module]] = <class 'torch.nn.modules.activation.Tanh'>, use_td_params: bool = True, **kwargs)[source]

多代理 MLP。

這是一個可用於多代理情境中的 MLP。 例如,作為策略或作為價值函數。 有關範例,請參閱 examples/multiagent

它期望形狀為 (*B, n_agents, n_agent_inputs) 的輸入。它傳回形狀為 (*B, n_agents, n_agent_outputs) 的輸出

如果 share_params 為 True,則將使用相同的 MLP 為所有代理執行前向傳遞(同質策略)。 否則,每個代理將使用不同的 MLP 來處理其輸入(異質策略)。

如果 centralized 為 True,則每個代理將使用所有代理的輸入來計算其輸出(n_agent_inputs * n_agents 將是一個代理的輸入數量)。 否則,每個代理只會使用其資料作為輸入。

參數:
  • n_agent_inputs (intNone) – 每個代理的輸入數量。 如果 None,則輸入的數量會在第一次呼叫時延遲實例化。

  • n_agent_outputs (int) – 每個代理的輸出數量。

  • n_agents (int) – 代理的數量。

關鍵字引數:
  • centralized (bool) – 如果 centralized 為 True,則每個代理將使用所有代理的輸入來計算其輸出(n_agent_inputs * n_agents 將是一個代理的輸入數量)。 否則,每個代理只會使用其資料作為輸入。

  • share_params (bool) – 如果 share_params 為 True,則將使用相同的 MLP 為所有代理執行前向傳遞(同質策略)。 否則,每個代理將使用不同的 MLP 來處理其輸入(異質策略)。

  • device (strtoech.device, optional) – 用於建立模組的裝置。

  • depth (int, optional) – 網路的深度。深度為 0 將產生一個具有所需輸入和輸出大小的單個線性層網路。長度為 1 將建立 2 個線性層,依此類推。如果未指示深度,則深度資訊應包含在 num_cells 參數中 (見下文)。如果 num_cells 是一個可迭代物件且指示了深度,則兩者應匹配:len(num_cells) 必須等於 depth。預設值:3。

  • num_cells (intSequence[int], optional) – 輸入和輸出之間每一層的單元數量。如果提供一個整數,則每一層將具有相同數量的單元。如果提供一個可迭代物件,則線性層的 out_features 將與 num_cells 的內容匹配。預設值:32。

  • activation_class (Type[nn.Module]) – 要使用的激活類別。預設值:nn.Tanh。

  • use_td_params (bool, optional) – 如果 True,則可以在 self.params 中找到參數,它是一個 TensorDictParams 物件(它同時繼承自 TensorDictnn.Module)。如果 False,則參數包含在 self._empty_net 中。總體而言,這兩種方法應該大致相同但不可互換:例如,使用 use_td_params=True 建立的 state_dictuse_td_params=False 時無法使用。

  • **kwargs – 用於 torchrl.modules.models.MLP 的可以傳遞以自訂 MLPs。

注意

要使用 torch.nn.init 模組初始化 MARL 模組參數,請參閱 get_stateful_net()from_stateful_net() 方法。

範例

>>> from torchrl.modules import MultiAgentMLP
>>> import torch
>>> n_agents = 6
>>> n_agent_inputs=3
>>> n_agent_outputs=2
>>> batch = 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
>>> # instantiate a local network shared by all agents (e.g. a parameter-shared policy)
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=False,
...     share_params=True,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0): MLP(
      (0): Linear(in_features=3, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
Now let's instantiate a centralized network shared by all agents (e.g. a centalised value function)
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=True,
...     share_params=True,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0): MLP(
      (0): Linear(in_features=18, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
We can see that the input to the first layer is n_agents * n_agent_inputs,
this is because in the case the net acts as a centralized mlp (like a single huge agent)
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
Outputs will be identical for all agents.
Now we can do both examples just shown but with an independent set of parameters for each agent
Let's show the centralized=False case.
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=False,
...     share_params=False,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0-5): 6 x MLP(
      (0): Linear(in_features=3, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
We can see that this is the same as in the first example, but now we have 6 MLPs, one per agent!
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得您問題的解答

檢視資源