tensordict.nn 套件¶
tensordict.nn 套件使得在 ML 管道中靈活使用 TensorDict 成為可能。
由於 TensorDict 將程式碼的部分轉換為基於鍵的結構,因此現在可以使用這些鍵作為掛鉤來建構複雜的圖形結構。基本構建模組是 TensorDictModule
,它使用輸入和輸出鍵的列表封裝 torch.nn.Module
實例
>>> from torch.nn import Transformer
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> import torch
>>> module = TensorDictModule(Transformer(), in_keys=["feature", "target"], out_keys=["prediction"])
>>> data = TensorDict({"feature": torch.randn(10, 11, 512), "target": torch.randn(10, 11, 512)}, [10, 11])
>>> data = module(data)
>>> print(data)
TensorDict(
fields={
feature: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
prediction: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
target: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32)},
batch_size=torch.Size([10, 11]),
device=None,
is_shared=False)
不一定需要使用 TensorDictModule
,具有輸入和輸出鍵(分別命名為 module.in_keys
和 module.out_keys
)的自訂 torch.nn.Module
即可。
多個 PyTorch 使用者的一個主要痛點是 nn.Sequential 無法處理具有多個輸入的模組。使用基於鍵的圖可以輕鬆解決該問題,因為序列中的每個節點都知道需要讀取哪些資料以及將資料寫入何處。
為此,我們提供了 TensorDictSequential 類別,它透過一系列 TensorDictModules 傳遞資料。序列中的每個模組從原始 TensorDict 獲取其輸入並將其輸出寫入其中,這意味著序列中的模組可以忽略其前任的輸出,或根據需要從 tensordict 獲取額外輸入。這是一個範例
>>> from tensordict.nn import TensorDictSequential
>>> class Net(nn.Module):
... def __init__(self, input_size=100, hidden_size=50, output_size=10):
... super().__init__()
... self.fc1 = nn.Linear(input_size, hidden_size)
... self.fc2 = nn.Linear(hidden_size, output_size)
...
... def forward(self, x):
... x = torch.relu(self.fc1(x))
... return self.fc2(x)
...
>>> class Masker(nn.Module):
... def forward(self, x, mask):
... return torch.softmax(x * mask, dim=1)
...
>>> net = TensorDictModule(
... Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
... Masker(),
... in_keys=[("intermediate", "x"), ("input", "mask")],
... out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>>
>>> td = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> td = module(td)
>>> print(td)
TensorDict(
fields={
input: TensorDict(
fields={
mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False),
intermediate: TensorDict(
fields={
x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False),
output: TensorDict(
fields={
probabilities: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)
我們也可以透過 select_subsequence()
方法輕鬆選擇子圖
>>> sub_module = module.select_subsequence(out_keys=[("intermediate", "x")])
>>> td = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> sub_module(td)
>>> print(td) # the "output" has not been computed
TensorDict(
fields={
input: TensorDict(
fields={
mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False),
intermediate: TensorDict(
fields={
x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)
最後,tensordict.nn
帶有一個 ProbabilisticTensorDictModule
,它允許從網路輸出建構分佈,並從中取得摘要統計資訊或樣本(以及分佈參數)
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.nn.distributions import NormalParamWrapper
>>> from tensordict.nn.prototype import (
... ProbabilisticTensorDictModule,
... ProbabilisticTensorDictSequential,
... )
>>> from torch.distributions import Normal
>>> td = TensorDict(
... {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.GRUCell(4, 8)
>>> module = TensorDictModule(
... NormalParamWrapper(net), in_keys=["input", "hidden"], out_keys=["loc", "scale"]
... )
>>> prob_module = ProbabilisticTensorDictModule(
... in_keys=["loc", "scale"],
... out_keys=["sample"],
... distribution_class=Normal,
... return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(module, prob_module)
>>> td_module(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(torch.Size([3, 4]), dtype=torch.float32),
hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),
input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
sample_log_prob: Tensor(torch.Size([3, 4]), dtype=torch.float32),
scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
|
TensorDict 模組的基底類別。 |
|
TensorDictModule 是一個 Python 封裝器,用於包裝讀取和寫入 TensorDict 的 |
|
概率 TD 模組。 |
|
TensorDictModules 的序列。 |
|
TensorDictModule 物件的包裝類別。 |
|
用於 PyTorch 可呼叫物件的 cudagraph 包裝器。 |
Ensembles¶
函數式方法能夠實現簡單明瞭的集成實例。 我們可以使用 tensordict.nn.EnsembleModule
複製和重新初始化模型副本。
>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import EnsembleModule
>>> from tensordict import TensorDict
>>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
>>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
>>> ensemble = EnsembleModule(mod, num_copies=3)
>>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
>>> ensemble(data)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 10]),
device=None,
is_shared=False)
|
模組封裝一個模組並重複它以形成一個集成模型。 |
編譯 TensorDictModules¶
從 v0.5 開始,TensorDict 組件與 compile()
兼容。 例如,TensorDictSequential
模組可以使用 torch.compile
進行編譯,並達到與封裝在 TensorDictModule
中的常規 PyTorch 模組相似的運行時。
分布¶
|
一個非參數化的 nn.Module,將其輸入分割為 loc 和 scale 參數。 |
一個 nn.Module,它添加可訓練的與狀態無關的比例參數。 |
|
|
分布的組合。 |
|
Delta 分布。 |
|
One-hot 分類分布。 |
|
截斷常態分佈。 |
工具¶
|
返回從關鍵字參數或輸入字典創建的 TensorDict。 |
|
允許使用 kwargs 調用期望 TensorDict 的函數。 |
|
將所有 ProbabilisticTDModules 抽樣設置為所需的類型。 |
|
反向 softplus 函數。 |
|
一個有偏差的 softplus 模組。 |
|
一個用於跳過 TensorDict 圖中現有節點的上下文管理器。 |
返回是否應由模組重新計算 tensordict 中的現有條目。 |