概觀¶
TensorDict 使組織資料和編寫可重複使用的通用 PyTorch 程式碼變得容易。最初是為 TorchRL 開發的,我們已將其分離到一個單獨的函式庫中。
TensorDict 主要是一個字典,但也是一個類似張量的類別:它支援多個主要與形狀和儲存相關的張量操作。它被設計成可以有效地序列化或從節點到節點或進程到進程傳輸。最後,它附帶了自己的 tensordict.nn
模組,該模組與 functorch
相容,旨在使模型集成和參數操作更容易。
在本頁中,我們將說明 TensorDict
的動機,並提供一些它可以做的事情的範例。
動機¶
TensorDict 允許您編寫可在不同範例中重複使用的通用程式碼模組。例如,以下迴圈可以在大多數 SL、SSL、UL 和 RL 任務中重複使用。
>>> for i, tensordict in enumerate(dataset):
... # the model reads and writes tensordicts
... tensordict = model(tensordict)
... loss = loss_module(tensordict)
... loss.backward()
... optimizer.step()
... optimizer.zero_grad()
憑藉其 tensordict.nn
模組,該套件提供了許多工具,可以在程式碼庫中使用 TensorDict
,而幾乎不需要任何工作。
在多處理或分散式設定中,tensordict
允許您無縫地將資料分派給每個工作人員
>>> # creates batches of 10 datapoints
>>> splits = torch.arange(tensordict.shape[0]).split(10)
>>> for worker in range(workers):
... idx = splits[worker]
... pipe[worker].send(tensordict[idx])
TensorDict 提供的一些操作也可以通過 tree_map 完成,但複雜度更高
>>> td = TensorDict(
... {"a": torch.randn(3, 11), "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": td["a"], "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
... {"a": regular_dicts["a"][i], "b": regular_dicts["b"][i]}
... for i in range(3)]
嵌套案例更引人注目
>>> td = TensorDict(
... {"a": {"c": torch.randn(3, 11)}, "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": {"c": td["a", "c"]}, "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
... {"a": {"c": regular_dicts["a"]["c"][i]}, "b": regular_dicts["b"][i]}
... for i in range(3)
在應用 unbind 操作後,將輸出字典分解為三個結構相似的字典,當天真地使用 pytree 時,很快就會變得非常繁瑣。使用 tensordict,我們為想要 unbind 或分割嵌套結構的使用者提供了一個簡單的 API,而不是計算嵌套分割/ unbound 嵌套結構。
特性¶
TensorDict
是一個類似 dict 的張量容器。要實例化 TensorDict
,您必須指定鍵值對以及批次大小。TensorDict
中任何值的前導維度必須與批次大小相容。
>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict(
... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)},
... batch_size=[2, 3],
... )
設定或檢索值的語法與常規字典非常相似。
>>> zeros = tensordict["zeros"]
>>> tensordict["twos"] = 2 * torch.ones(2, 3)
也可以沿著它的 batch_size 索引一個 tensordict,這使得只需幾個字元即可獲得資料的同餘切片(請注意,使用 tree_map 和省略符號索引第 n 個前導維度需要更多程式碼)
>>> sub_tensordict = tensordict[..., :2]
也可以使用 inplace=True
的 set 方法或 set_
方法來就地更新內容。前者是後者的容錯版本:如果找不到匹配的鍵,它將寫入一個新的鍵。
現在可以集體操作 TensorDict 的內容。例如,要將所有內容放置到特定裝置上,只需執行
>>> tensordict = tensordict.to("cuda:0")
要重塑批次維度,可以執行
>>> tensordict = tensordict.reshape(6)
該類別支援許多其他操作,包括 squeeze、unsqueeze、view、permute、unbind、stack、cat 等等。如果沒有某個操作,TensorDict.apply 方法通常會提供所需的解決方案。
命名維度¶
TensorDict 和相關類別也支援維度名稱。名稱可以在建構時給出,也可以稍後細化。語義與 torch.Tensor 維度名稱特性類似
>>> tensordict = TensorDict({}, batch_size=[3, 4], names=["a", None])
>>> tensordict.refine_names(..., "b")
>>> tensordict.names = ["z", "y"]
>>> tensordict.rename("m", "n")
>>> tensordict.rename(m="h")
嵌套 TensorDicts¶
TensorDict
中的值本身可以是 TensorDicts(以下範例中的嵌套字典將轉換為嵌套 TensorDicts)。
>>> tensordict = TensorDict(
... {
... "inputs": {
... "image": torch.rand(100, 28, 28),
... "mask": torch.randint(2, (100, 28, 28), dtype=torch.uint8)
... },
... "outputs": {"logits": torch.randn(100, 10)},
... },
... batch_size=[100],
... )
可以使用字串元組來存取或設定嵌套鍵
>>> image = tensordict["inputs", "image"]
>>> logits = tensordict.get(("outputs", "logits")) # alternative way to access
>>> tensordict["outputs", "probabilities"] = torch.sigmoid(logits)
延遲評估¶
對 TensorDict
的一些操作會延遲執行,直到存取項目為止。例如,堆疊、擠壓、取消擠壓、置換批次維度和建立視圖不會立即在 TensorDict
的所有內容上執行。相反,它們會在存取 TensorDict
中的值時延遲執行。如果 TensorDict
包含許多值,這可以節省許多不必要的計算。
>>> tensordicts = [TensorDict({
... "a": torch.rand(10),
... "b": torch.rand(10, 1000, 1000)}, [10])
... for _ in range(3)]
>>> stacked = torch.stack(tensordicts, 0) # no stacking happens here
>>> stacked_a = stacked["a"] # we stack the a values, b values are not stacked
它還具有以下優點:我們可以操作堆疊中的原始 tensordicts
>>> stacked["a"] = torch.zeros_like(stacked["a"])
>>> assert (tensordicts[0]["a"] == 0).all()
需要注意的是,get 方法現在變成了一個昂貴的操作,如果重複多次,可能會導致一些額外負擔。可以透過在執行 stack 之後簡單地呼叫 tensordict.contiguous() 來避免這種情況。為了進一步緩解這個問題,TensorDict 配備了自己的元數據類別 (MetaTensor),它可以追蹤字典中每個條目的類型、形狀、dtype 和 device,而無需執行昂貴的操作。
延遲預分配¶
假設我們有一個函數 foo() -> TensorDict,並且我們執行類似以下的操作
>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in range(N):
... tensordict[i] = foo()
當 i == 0
時,空的 TensorDict
將會自動填充具有批次大小 N 的空 tensors。在迴圈的後續迭代中,更新將會全部就地寫入。
TensorDictModule¶
為了方便將 TensorDict
整合到自己的程式碼庫中,我們提供了一個 tensordict.nn 套件,允許使用者將 TensorDict
實例傳遞給 nn.Module
物件。
TensorDictModule
包裝了 nn.Module
並接受單個 TensorDict
作為輸入。您可以指定底層模組應該從哪裡獲取其輸入,以及應該將其輸出寫入到哪裡。這是我們能夠編寫可重複使用、通用的高階程式碼(例如動機章節中的訓練迴圈)的關鍵原因。
>>> from tensordict.nn import TensorDictModule
>>> class Net(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.LazyLinear(1)
...
... def forward(self, x):
... logits = self.linear(x)
... return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
... Net(),
... in_keys=["input"],
... out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> tensordict = module(tensordict)
>>> # outputs can now be retrieved from the tensordict
>>> logits = tensordict["outputs", "logits"]
>>> probabilities = tensordict.get(("outputs", "probabilities"))
為了促進這個類別的使用,也可以將 tensors 作為 kwargs 傳遞
>>> tensordict = module(input=torch.randn(32, 100))
這將返回一個與前一個程式碼框中相同的 TensorDict
。
多個 PyTorch 使用者的主要痛點是 nn.Sequential 無法處理具有多個輸入的模組。使用基於 key 的圖可以輕鬆解決這個問題,因為序列中的每個節點都知道需要讀取哪些資料以及將其寫入到哪裡。
為此,我們提供了 TensorDictSequential
類別,它通過一系列 TensorDictModules
傳遞資料。序列中的每個模組都從原始 TensorDict
中獲取其輸入並將其輸出寫入到原始 TensorDict
,這意味著序列中的模組可以忽略其前任的輸出,或者根據需要從 tensordict 中獲取額外的輸入。這是一個例子。
>>> class Net(nn.Module):
... def __init__(self, input_size=100, hidden_size=50, output_size=10):
... super().__init__()
... self.fc1 = nn.Linear(input_size, hidden_size)
... self.fc2 = nn.Linear(hidden_size, output_size)
...
... def forward(self, x):
... x = torch.relu(self.fc1(x))
... return self.fc2(x)
...
... class Masker(nn.Module):
... def forward(self, x, mask):
... return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
... Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
... Masker(),
... in_keys=[("intermediate", "x"), ("input", "mask")],
... out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> tensordict = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> tensordict = module(tensordict)
>>> intermediate_x = tensordict["intermediate", "x"]
>>> probabilities = tensordict["output", "probabilities"]
在這個範例中,第二個模組將第一個模組的輸出與儲存在 TensorDict
中 (“inputs”, “mask”) 下的遮罩結合在一起。
TensorDictSequential
提供了許多其他功能:可以通過查詢 in_keys 和 out_keys 屬性來訪問輸入和輸出鍵的列表。也可以透過使用所需的輸入和輸出鍵集合查詢 select_subsequence()
來請求子圖。這將返回另一個 TensorDictSequential
,其中僅包含滿足這些要求所必需的模組。TensorDictModule
也與 vmap
和其他 functorch
功能相容。
函數式編程¶
我們提供了一個 API,用於將 TensorDict
與 functorch
結合使用。 例如,TensorDict
使連接模型權重以進行模型集成變得容易
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(separator=".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(separator=".")
>>> params = make_functional(model)
>>> # params provided by make_functional match state_dict:
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params) # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])
如果沒有更快,則函數式 API 與 functorch
中實作的當前 FunctionalModule
相當。