快捷方式

tensordict.nn.EnsembleModule

class tensordict.nn.EnsembleModule(*args, **kwargs)

該模組封裝一個模組並重複它以形成一個集成。

參數:
  • module (nn.Module) – 要複製和封裝的 nn.module。

  • num_copies (int) – 要建立的模組副本數量。

  • parameter_init_function (Callable) – 一個函式,它接受一個模組副本並初始化其參數。

  • expand_input (bool) – 是否擴展輸入 TensorDict 以匹配副本數量。 除非您將集成模組鏈接在一起,否則應設為 True,例如 EnsembleModule(cnn) -> EnsembleModule(mlp)。 如果為 False,EnsembleModule(mlp) 將預期先前的模組已經擴展了輸入。

範例

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule, EnsembleModule
>>> from tensordict import TensorDict
>>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
>>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
>>> ensemble = EnsembleModule(mod, num_copies=3)
>>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
>>> ensemble(data)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 10]),
    device=None,
    is_shared=False)

要堆疊 EnsembleModules 在一起,我們應該注意關閉第二個模組的 expand_input

範例

>>> import torch
>>> from tensordict.nn import TensorDictModule, TensorDictSequential, EnsembleModule
>>> from tensordict import TensorDict
>>> module = TensorDictModule(torch.nn.Linear(2,3), in_keys=['bork'], out_keys=['dork'])
>>> next_module = TensorDictModule(torch.nn.Linear(3,1), in_keys=['dork'], out_keys=['spork'])
>>> e0 = EnsembleModule(module, num_copies=4, expand_input=True)
>>> e1 = EnsembleModule(next_module, num_copies=4, expand_input=False)
>>> seq = TensorDictSequential(e0, e1)
>>> data = TensorDict({'bork': torch.randn(5,2)}, batch_size=[5])
>>> seq(data)
TensorDict(
    fields={
        bork: Tensor(shape=torch.Size([4, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        dork: Tensor(shape=torch.Size([4, 5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        spork: Tensor(shape=torch.Size([4, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 5]),
    device=None,
    is_shared=False)

文件

存取 PyTorch 的全面開發人員文件

檢視文件

教學

獲取適合初學者和高級開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得解答

檢視資源