捷徑

MultiAgentNetBase

class torchrl.modules.MultiAgentNetBase(*, n_agents: int, centralized: Optional[bool] = None, share_params: Optional[bool] = None, agent_dim: Optional[int] = None, vmap_randomness: str = 'different', use_td_params: bool = True, **kwargs)[原始碼]

多代理網路的基底類別。

注意

要使用 torch.nn.init 模組初始化 MARL 模組參數,請參考 get_stateful_net()from_stateful_net() 方法。

forward(*inputs: Tuple[Tensor]) Tensor[原始碼]

定義每次呼叫時執行的計算。

應該由所有子類別覆寫。

注意

雖然 forward pass 的方法需要在這個函數中定義,但應該呼叫 Module 實例,而不是呼叫這個函數,因為前者會處理已註冊的 hooks,而後者會靜默地忽略它們。

from_stateful_net(stateful_net: Module)[原始碼]

根據網路的有狀態版本填充參數。

有關如何收集網路的有狀態版本的詳細資訊,請參閱 get_stateful_net()

參數:

stateful_net (nn.Module) – 應從中收集參數的有狀態網路。

get_stateful_net(copy: bool = True)[原始碼]

傳回網路的有狀態版本。

這可用於初始化參數。

這樣的網路通常無法直接呼叫,需要呼叫 vmap 才能執行。

參數:

copy (bool, optional) – 如果 True,則會建立網路的深層副本。預設值為 True

如果在原地修改參數(建議),則無需將參數複製回 MARL 模組。有關如何使用已重新初始化到異地的參數重新填充 MARL 模型的詳細資訊,請參閱 from_stateful_net()

範例

>>> from torchrl.modules import MultiAgentMLP
>>> import torch
>>> n_agents = 6
>>> n_agent_inputs=3
>>> n_agent_outputs=2
>>> batch = 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=False,
...     share_params=False,
...     depth=2,
... )
>>> snet = mlp.get_stateful_net()
>>> def init(module):
...     if hasattr(module, "weight"):
...         torch.nn.init.kaiming_normal_(module.weight)
>>> snet.apply(init)
>>> # If the module has been updated out-of-place (not the case here) we can reset the params
>>> mlp.from_stateful_net(snet)
reset_parameters()[原始碼]

重設模型的參數。

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得適合初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您問題的解答

檢視資源