快捷方式

TensorDictModule

class tensordict.nn.TensorDictModule(*args, **kwargs)

TensorDictModule 是 nn.Module 的 Python 包裝器,用於讀取和寫入 TensorDict。

參數:
  • module (Callable) – 可呼叫物件,通常是 torch.nn.Module,用於將輸入映射到輸出參數空間。其 forward 方法可以傳回單個張量、張量元組甚至字典。在後一種情況下,TensorDictModule 的輸出鍵將用於填充輸出 tensordict (即 out_keys 中存在的鍵應該存在於 module forward 方法傳回的字典中)。

  • in_keys (NestedKeys 的可迭代物件, Dict[NestedStr, str]) – 要從輸入 tensordict 讀取並傳遞給模組的鍵。如果它包含多個元素,則這些值將以 in_keys 可迭代物件給定的順序傳遞。如果 in_keys 是一個字典,則其鍵必須對應於要在 tensordict 中讀取的鍵,並且其值必須與函數簽名中的關鍵字引數的名稱相符。

  • out_keys (str 的可迭代物件) – 要寫入輸入 tensordict 的鍵。out_keys 的長度必須與嵌入式模組傳回的張量數量相符。使用 "_" 作為鍵可以避免將張量寫入輸出。

關鍵字引數:

inplace (boolstring, optional) –

如果 True (預設),則模組的輸出會寫入提供給 forward() 方法的 tensordict 中。如果 False,則會建立一個新的 TensorDict,其批次大小為空,且沒有裝置。如果 "empty",則會使用 empty() 來建立輸出 tensordict。

注意

如果 inplace=False 且傳遞給模組的 tensordict 是另一個 TensorDictBase 子類別 (而不是 TensorDict),則輸出仍然會是 TensorDict 實例。其批次大小將為空,且沒有裝置。設定為 "empty" 以取得相同的 TensorDictBase 子類型、相同的批次大小和裝置。在執行階段使用 tensordict_out (請參閱下方) 以更精細地控制輸出。

注意

如果 inplace=False 且將 tensordict_out 傳遞給 forward() 方法,則 tensordict_out 將會優先。 這是取得一個 tensordict_out 的方式,即使傳遞給模組的 tensordict 是另一個 TensorDictBase 子類別,而非 TensorDict,輸出結果仍然會是 TensorDict 實例。

將神經網路嵌入 TensorDictModule 僅需要指定輸入和輸出鍵。 TensorDictModule 支援函數式和常規的 nn.Module 物件。 在函數式的情況下,必須指定 'params' (和 'buffers') 關鍵字參數。

範例

>>> from tensordict import TensorDict
>>> # one can wrap regular nn.Module
>>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"])
>>> input = torch.ones(2, 3, 128)
>>> tgt = torch.zeros(2, 3, 128)
>>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        input: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        tgt: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=None,
    is_shared=False)

我們也可以直接傳遞 tensors

範例

>>> out = module(input, tgt)
>>> assert out.shape == input.shape
>>> # we can also wrap regular functions
>>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")])
>>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[]))
TensorDict(
    fields={
        input: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                x+1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                x-1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

我們可以使用 TensorDictModule 來填充 tensordict

範例

>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"])
>>> print(module(TensorDict({}, batch_size=[])))
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

另一個特性是將字典作為輸入鍵傳遞,以控制值到特定關鍵字參數的分配。

範例

>>> module = TensorDictModule(lambda x, *, y: x+y,
...     in_keys={'1': 'x', '2': 'y'}, out_keys=['z'],
...     )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)

對 tensordict 模組進行函數式呼叫很簡單

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...    module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> params = TensorDict.from_module(td_module)
>>> # functional API
>>> with params.to_module(td_module):
...     td_functional = td_module(td.clone())
>>> print(td_functional)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
在有狀態的情況下
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...    module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> td_stateful = td_module(td.clone())
>>> print(td_stateful)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
forward(tensordict: TensorDictBase = None, args=None, *, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) TensorDictBase

當 tensordict 參數未設定時,kwargs 用於建立 TensorDict 的實例。

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深度教學課程

檢視教學課程

資源

尋找開發資源並取得您問題的解答

檢視資源