捷徑

TensorDictModule

作者: Nicolas Dufour, Vincent Moens

在本教學中,您將學習如何使用 TensorDictModuleTensorDictSequential 建立通用且可重複使用的模組,這些模組可以接受 TensorDict 作為輸入。

為了方便將 TensorDict 類別與 Module 一起使用,tensordict 提供了兩者之間的一個介面,命名為 TensorDictModule

TensorDictModule 類別是一個 Module,它在被呼叫時會接收一個 TensorDict 作為輸入。它將讀取一系列輸入鍵,將它們作為輸入傳遞給包裝的模組或函式,並在執行完成後將輸出寫入同一個 tensordict 中。

使用者有責任定義要讀取作為輸入和輸出的鍵。

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

簡單範例:編碼循環層

下面舉例說明了 TensorDictModule 的最簡單用法。 如果起初看起來使用這個類別會引入不必要的複雜性,我們稍後會看到,這個 API 允許使用者以程式設計的方式將模組連接在一起,在模組之間緩存值或以程式設計的方式構建一個模組。其中一個最簡單的例子是 ResNet 等架構中的循環模組,其中模組的輸入被緩存並添加到一個小型多層感知器 (MLP) 的輸出中。

首先,讓我們考慮如何對 MLP 進行分塊,並使用 tensordict.nn 進行編碼。堆疊的第一層可能是 Linear 層,它接收一個條目作為輸入(讓我們稱之為 x)並輸出另一個條目(我們將稱之為 y)。

為了饋送到我們的模組,我們有一個 TensorDict 實例,其中只有一個條目 "x"

tensordict = TensorDict(
    x=torch.randn(5, 3),
    batch_size=[5],
)

現在,我們使用 tensordict.nn.TensorDictModule 構建我們的簡單模組。預設情況下,這個類別會將輸入寫入到輸入 tensordict 中(意味著條目會被寫入到與輸入相同的 tensordict 中,而不是條目被覆寫!),因此我們不需要明確指出輸出是什麼

linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)

assert "linear0" in tensordict

如果模組輸出多個張量(或 tensordict!),則它們的條目必須以正確的順序傳遞給 TensorDictModule

支援可呼叫物件

在設計模型時,經常會遇到希望將任意的非參數函數整合到網路中的情況。 例如,您可能希望在將影像傳遞到卷積網路或視覺轉換器時,對影像的維度進行排列,或將數值除以 255。 有幾種方法可以做到這一點:例如,您可以使用 forward_hook,或設計一個新的 Module 來執行此操作。

TensorDictModule 可以與任何可呼叫物件(而不僅僅是模組)一起使用,這使得將任意函數整合到模組中變得容易。 例如,讓我們看看如何整合 relu 激活函數,而無需使用 ReLU 模組

relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])

堆疊模組

我們的 MLP 並非由單個層組成,因此我們現在需要向其中添加另一個層。 這一層將是一個激活函數,例如 ReLU。 我們可以使用 TensorDictSequential 來堆疊這個模組和前一個模組。

注意

這就是 tensordict.nn 的真正威力所在:與 Sequential 不同,TensorDictSequential 會將所有先前的輸入和輸出保存在記憶體中(並且可以在之後將它們過濾掉),這使得可以輕鬆地即時且以程式化的方式建立複雜的網路結構。

block0 = TensorDictSequential(linear0, relu0)

block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict

我們可以重複這個邏輯來獲得一個完整的 MLP

linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)

多個輸入鍵

殘差網路的最後一步是將輸入添加到最後一個線性層的輸出。 無需為此編寫特殊的 Module 子類! TensorDictModule 也可以用於包裝簡單的函數

residual = TensorDictModule(
    lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)

現在我們可以將 block0block1residual 組合在一起,以構成一個完整的殘差區塊

block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict

一個真正的問題可能是用作輸入的 tensordict 中條目的累積:在某些情況下(例如,當需要梯度時),無論如何都會緩存中間值,但情況並非總是如此,讓垃圾收集器知道可以丟棄某些條目會很有用。 tensordict.nn.TensorDictModuleBase 及其子類(包括 tensordict.nn.TensorDictModuletensordict.nn.TensorDictSequential)可以選擇在執行後過濾其輸出鍵。 為此,只需呼叫 tensordict.nn.TensorDictModuleBase.select_out_keys 方法即可。 這將就地更新模組,並且所有不需要的條目都將被丟棄

block.select_out_keys("y")

tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict

assert "linear1" not in tensordict

但是,輸入鍵會被保留

assert "x" in tensordict

附帶一提,也可以將 selected_out_keys 傳遞給 tensordict.nn.TensorDictSequential,以避免單獨呼叫此方法。

在沒有 tensordict 的情況下使用 TensorDictModule

tensordict.nn.TensorDictSequential 提供的隨時構建複雜架構的機會並不意味著必須切換到 tensordict 來表示資料。 由於 dispatch,來自 tensordict.nn 的模組也支援與條目名稱匹配的引數和關鍵字引數

x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)

在底層,dispatch 會重建 tensordict,執行模組,然後解構它。 這可能會導致一些額外負擔,但正如我們稍後將看到的,有一種解決方案可以擺脫這種情況。

執行時間

在執行時,tensordict.nn.TensorDictModuletensordict.nn.TensorDictSequential 的確會產生一些額外負擔,因為它們需要從 tensordict 讀取和寫入。 但是,我們可以通過使用 compile() 來大大減少這種額外負擔。 為此,讓我們比較一下此程式碼的三個版本,無論是否進行編譯

class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(3, 128)
        self.relu0 = nn.ReLU()
        self.linear1 = nn.Linear(128, 128)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(128, 3)

    def forward(self, x):
        y = self.linear0(x)
        y = self.relu0(y)
        y = self.linear1(y)
        y = self.relu1(y)
        return self.linear2(y) + x


print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block

from torch.utils.benchmark import Timer

print(
    f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
    f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
    f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)

print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_notd_c(x)
print(
    f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tdm_c(x=x)
print(
    f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tds_c(x=x)
print(
    f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
Without compile
Regular:  219.5165 us
TDM:  260.3091 us
Sequential:  375.0590 us
Compiled versions
Compiled regular:  326.0555 us
Compiled TDM:  333.1850 us
Compiled sequential:  342.4750 us

正如人們所看到的,TensorDictSequential 引入的額外負擔已完全消除。

TensorDictModule 的使用須知

  • 請勿在來自 tensordict.nn 的模組周圍使用 Sequence。這會破壞輸入/輸出鍵的結構。請始終嘗試依賴 nn:TensorDictSequential

  • 請勿將輸出 tensordict 指定給新變數,因為輸出 tensordict 只是就地修改的輸入。嚴格來說,不禁止指定新變數名稱,但這表示您可能希望在刪除其中一個變數時,兩個變數都消失,但實際上垃圾收集器仍然會看到工作區中的張量,並且不會釋放任何記憶體。

    >>> tensordict = module(tensordict)  # ok!
    >>> tensordict_out = module(tensordict)  # don't!
    

處理分佈:ProbabilisticTensorDictModule

ProbabilisticTensorDictModule 是一個非參數模組,表示機率分佈。 分佈參數從 tensordict 輸入讀取,輸出寫入輸出 tensordict。 輸出根據某些規則進行採樣,由輸入 default_interaction_type 參數和 interaction_type() 全域函數指定。 如果它們衝突,則上下文管理器優先。

它可以與 TensorDictModule 結合使用,該模組使用 ProbabilisticTensorDictSequential 返回一個用分佈參數更新的 tensordict。 這是 TensorDictSequential 的一個特例,其最後一層是一個 ProbabilisticTensorDictModule 實例。

ProbabilisticTensorDictModule 負責建構分佈(透過 get_dist() 方法)和/或從這個分佈中採樣(透過對模組的常規 forward 呼叫)。相同的 get_dist() 方法在 ProbabilisticTensorDictSequential 中公開。

如果需要,可以在輸出 tensordict 中找到參數以及對數機率。

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist

td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    net,
    extractor,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=dist.Normal,
        return_log_prob=True,
    ),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

結論

我們已經了解了如何使用 tensordict.nn 來動態建構複雜的隨選神經網路架構。 這開啟了建構與模型簽名無關的管道的可能性,即編寫以靈活方式使用具有任意數量輸入或輸出的網路的通用程式碼。

我們還看到了 dispatch 如何能夠使用 tensordict.nn 來建構此類網路並使用它們,而無需直接求助於 TensorDict。 感謝 compile()tensordict.nn.TensorDictSequential 引入的額外負荷可以完全消除,為使用者留下一個簡潔、無 tensordict 版本的模組。

在下一個教學課程中,我們將了解如何使用 torch.export 來隔離模組並匯出它。

腳本的總運行時間: (0 分鐘 18.375 秒)

由 Sphinx-Gallery 產生的圖片集

文件

取得 PyTorch 的全面開發人員文件

查看文件

教學

取得初學者和高級開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源