MultiAgentNetBase¶
- class torchrl.modules.MultiAgentNetBase(*, n_agents: int, centralized: Optional[bool] = None, share_params: Optional[bool] = None, agent_dim: Optional[int] = None, vmap_randomness: str = 'different', use_td_params: bool = True, **kwargs)[原始碼]¶
多代理網路的基底類別。
注意
要使用 torch.nn.init 模組初始化 MARL 模組參數,請參考
get_stateful_net()
和from_stateful_net()
方法。- forward(*inputs: Tuple[Tensor]) Tensor [原始碼]¶
定義每次呼叫時執行的計算。
應該由所有子類別覆寫。
注意
雖然 forward pass 的方法需要在這個函數中定義,但應該呼叫
Module
實例,而不是呼叫這個函數,因為前者會處理已註冊的 hooks,而後者會靜默地忽略它們。
- from_stateful_net(stateful_net: Module)[原始碼]¶
根據網路的有狀態版本填充參數。
有關如何收集網路的有狀態版本的詳細資訊,請參閱
get_stateful_net()
。- 參數:
stateful_net (nn.Module) – 應從中收集參數的有狀態網路。
- get_stateful_net(copy: bool = True)[原始碼]¶
傳回網路的有狀態版本。
這可用於初始化參數。
這樣的網路通常無法直接呼叫,需要呼叫 vmap 才能執行。
- 參數:
copy (bool, optional) – 如果
True
,則會建立網路的深層副本。預設值為True
。
如果在原地修改參數(建議),則無需將參數複製回 MARL 模組。有關如何使用已重新初始化到異地的參數重新填充 MARL 模型的詳細資訊,請參閱
from_stateful_net()
。範例
>>> from torchrl.modules import MultiAgentMLP >>> import torch >>> n_agents = 6 >>> n_agent_inputs=3 >>> n_agent_outputs=2 >>> batch = 64 >>> obs = torch.zeros(batch, n_agents, n_agent_inputs) >>> mlp = MultiAgentMLP( ... n_agent_inputs=n_agent_inputs, ... n_agent_outputs=n_agent_outputs, ... n_agents=n_agents, ... centralized=False, ... share_params=False, ... depth=2, ... ) >>> snet = mlp.get_stateful_net() >>> def init(module): ... if hasattr(module, "weight"): ... torch.nn.init.kaiming_normal_(module.weight) >>> snet.apply(init) >>> # If the module has been updated out-of-place (not the case here) we can reset the params >>> mlp.from_stateful_net(snet)