注意
前往結尾下載完整的範例程式碼。
匯出 tensordict 模組¶
作者: Vincent Moens
先決條件¶
建議先閱讀 TensorDictModule 教學課程,以充分受益於本教學課程。
一旦使用 tensordict.nn
撰寫了模組,通常需要隔離計算圖並匯出該圖。 這樣做的目的可能是為了在硬體(例如,機器人、無人機、邊緣裝置)上執行模型,或完全消除對 tensordict 的依賴。
PyTorch 提供了多種匯出模組的方法,包括 onnx
和 torch.export
,兩者都與 tensordict
相容。
在這個簡短的教學課程中,我們將了解如何使用 torch.export
來隔離模型的計算圖。torch.onnx
支援遵循相同的邏輯。
主要學習內容¶
在沒有
TensorDict
輸入的情況下執行tensordict.nn
模組;選擇模型的輸出;
處理隨機模型;
使用 torch.export 匯出此類模型;
將模型儲存到檔案;
隔離 pytorch 模型;
import time
import torch
from tensordict.nn import (
InteractionType,
NormalParamExtractor,
ProbabilisticTensorDictModule as Prob,
set_interaction_type,
TensorDictModule as Mod,
TensorDictSequential as Seq,
)
from torch import distributions as dists, nn
設計模型¶
在許多應用中,使用隨機模型非常有用,也就是說,模型輸出的變數不是確定性定義的,而是根據參數化分佈進行抽樣的。 例如,生成式 AI 模型在提供相同輸入時通常會產生不同的輸出,因為它們根據分佈對輸出進行抽樣,而分佈的參數由輸入定義。
tensordict
庫透過 ProbabilisticTensorDictModule
類別來處理這個問題。 這個基本元素是使用分佈類別(在我們的例子中是 Normal
)和輸入鍵的指標來建構的,這些輸入鍵將在執行時用於建立該分佈。
因此,我們正在建構的網路將是三個主要元件的組合
將輸入對應到潛在參數的網路;
一個
tensordict.nn.NormalParamExtractor
模組,將輸入分割成位置 “loc” 和 “scale” 參數,以傳遞到Normal
分佈;一個分佈建構函式模組。
model = Seq(
# 1. A small network for embedding
Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
# 2. Extracting params
Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
# 3. Probabilistic module
Prob(
in_keys=["loc", "scale"],
out_keys=["sample"],
distribution_class=dists.Normal,
),
)
讓我們執行這個模型,看看輸出是什麼樣子
x = torch.randn(1, 3)
print(model(x=x))
(tensor([[0.0000, 0.2604, 0.0000, 0.0000]], grad_fn=<ReluBackward0>), tensor([[-0.1580, -0.5222, -0.3319, 0.5519]], grad_fn=<AddmmBackward0>), tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>), tensor([[0.8046, 1.3804]], grad_fn=<ClampMinBackward0>), tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>))
正如預期的那樣,使用張量輸入執行模型會傳回與模組輸出鍵一樣多的張量! 對於大型模型,這可能會非常煩人且浪費。 稍後,我們將了解如何限制模型的輸出數量來解決這個問題。
將 torch.export
與 TensorDictModule
搭配使用¶
現在我們已經成功建構了我們的模型,我們希望將其計算圖提取到一個獨立於 tensordict
的單一物件中。torch.export
是一個 PyTorch 模組,專門用於隔離模組的圖並以標準化的方式表示它。 它的主要進入點是 export()
,它會傳回一個 ExportedProgram
物件。 反過來,這個物件有幾個我們將在下面探討的感興趣的屬性:一個 graph_module
,表示由 export
捕獲的 FX 圖,一個具有圖的輸入、輸出等的 graph_signature
,最後是一個 module()
,它傳回一個可以代替原始模組使用的可呼叫物件。
雖然我們的模組接受 args 和 kwargs,但我們將重點放在使用 kwargs,因為這樣更清楚。
from torch.export import export
model_export = export(model, args=(), kwargs={"x": x})
讓我們看看這個模組
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); module_2_module_weight = module_2_module_bias = None
split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
getitem = split[0]
getitem_1 = split[1]; split = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = None
return pytree.tree_unflatten((relu, linear_1, getitem_2, getitem_3, getitem_2), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
這個模組的執行方式與我們原來的模組完全相同(但開銷更低)
Time for TDModule: 469.45 micro-seconds
Time for exported module: 340.70 micro-seconds
以及 FX 圖形
print("fx graph:", model_export.graph_module.print_readable())
class GraphModule(torch.nn.Module):
def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/tensordict/nn/common.py:1010 in _call_module, code: out = self.module(*tensors, **kwargs)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias); x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias); p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:129 in forward, code: loc, scale = tensor.chunk(2, -1)
split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
getitem: "f32[1, 2]" = split[0]
getitem_1: "f32[1, 2]" = split[1]; split = None
# File: /pytorch/tensordict/tensordict/nn/utils.py:68 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:130 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:55 in broadcast_all, code: return torch.broadcast_tensors(*values)
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2: "f32[1, 2]" = broadcast_tensors[0]
getitem_3: "f32[1, 2]" = broadcast_tensors[1]; broadcast_tensors = None
return (relu, linear_1, getitem_2, getitem_3, getitem_2)
fx graph: class GraphModule(torch.nn.Module):
def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/tensordict/nn/common.py:1010 in _call_module, code: out = self.module(*tensors, **kwargs)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias); x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias); p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:129 in forward, code: loc, scale = tensor.chunk(2, -1)
split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
getitem: "f32[1, 2]" = split[0]
getitem_1: "f32[1, 2]" = split[1]; split = None
# File: /pytorch/tensordict/tensordict/nn/utils.py:68 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:130 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:55 in broadcast_all, code: return torch.broadcast_tensors(*values)
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2: "f32[1, 2]" = broadcast_tensors[0]
getitem_3: "f32[1, 2]" = broadcast_tensors[1]; broadcast_tensors = None
return (relu, linear_1, getitem_2, getitem_3, getitem_2)
處理巢狀鍵¶
巢狀鍵是 tensordict 函式庫的核心功能,因此能夠匯出讀取和寫入巢狀條目的模組是一項重要的支援功能。由於關鍵字引數必須是常規字串,因此 dispatch
無法直接使用它們。相反,dispatch
將會解壓縮以常規底線(“_”)連接的巢狀鍵,如下面的範例所示。
model_nested = Seq(
Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
).select_out_keys(("some", "output"))
model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
print("exported module with nested input:", model_nested_export.module())
exported module with nested input: GraphModule()
def forward(self, some_key):
some_key, = fx_pytree.tree_flatten_spec(([], {'some_key':some_key}), self._in_spec)
add = torch.ops.aten.add.Tensor(some_key, 1); some_key = None
sub = torch.ops.aten.sub.Tensor(add, 1); add = None
return pytree.tree_unflatten((sub,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
請注意,module() 傳回的可呼叫物件是一個純 Python 可呼叫物件,可以使用 compile()
進行編譯。
儲存匯出的模組¶
torch.export
有其自身的序列化協定,save()
和 load()
。按照慣例,應使用 “.pt2” 副檔名
>>> torch.export.save(model_export, "model.pt2")
選擇輸出¶
回想一下,tensordict.nn
會將每個中間值保留在輸出中,除非使用者特別要求僅保留特定值。在訓練期間,這非常有用:可以輕鬆記錄圖形的中間值,或將它們用於其他目的(例如,根據其儲存的參數重建分佈,而不是儲存 Distribution
物件本身)。還可以說,在訓練期間,註冊中間值對記憶體的影響可以忽略不計,因為它們是 torch.autograd
用於計算參數梯度之計算圖形的一部分。
但是,在推論期間,我們最有可能只對模型的最終樣本感興趣。因為我們想要提取模型以用於獨立於 tensordict
函式庫的用途,所以隔離我們想要的唯一輸出是有意義的。為此,我們有幾種選擇
使用
selected_out_keys
關鍵字引數建構TensorDictSequential()
,這將導致在呼叫模組期間選擇所需的條目;使用
select_out_keys()
方法,該方法將就地修改out_keys
屬性(可以透過reset_out_keys()
恢復)。將現有實例包裝在
TensorDictSequential()
中,該實例將過濾掉不需要的鍵>>> module_filtered = Seq(module, selected_out_keys=["sample"])
讓我們在選擇其輸出鍵後測試模型。當提供 x 輸入時,我們期望我們的模型輸出一個與分佈樣本對應的單一張量
tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>)
我們看到輸出現在是一個單一張量,對應於分佈的樣本。我們可以從此建立一個新的匯出圖形。它的計算圖形應該被簡化
model_export = export(model, args=(), kwargs={"x": x})
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
split = torch.ops.aten.split.Tensor(linear_1, 2, -1); linear_1 = None
getitem = split[0]
getitem_1 = split[1]; split = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]; broadcast_tensors = None
return pytree.tree_unflatten((getitem_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
控制取樣策略¶
我們尚未討論 ProbabilisticTensorDictModule
如何從分佈中取樣。透過取樣,我們指的是根據特定策略在分佈定義的空間內取得一個值。例如,可能希望在訓練期間獲得隨機樣本,但在推論期間獲得確定性樣本(例如,平均值或眾數)。為了應對這個問題,tensordict
使用 set_interaction_type
裝飾器和上下文管理器,它接受 InteractionType
列舉輸入
>>> with set_interaction_type(InteractionType.MEAN):
... output = module(input) # takes the input of the distribution, if ProbabilisticTensorDictModule is invoked
預設的 InteractionType
是 InteractionType.DETERMINISTIC
,如果未直接實現,則對於具有實數域的分佈是平均值,對於具有離散域的分佈是眾數。可以使用 ProbabilisticTensorDictModule
的 default_interaction_type
關鍵字引數來變更此預設值。
讓我們回顧一下:為了控制我們網路的取樣策略,我們可以在建構函式中定義一個預設的取樣策略,或者在執行時透過 set_interaction_type
上下文管理器覆蓋它。
正如我們從以下範例中看到的那樣,torch.export
正確回應了裝飾器的使用:如果我們要求一個隨機樣本,則輸出與我們要求平均值時不同
with set_interaction_type(InteractionType.RANDOM):
model_export = export(model, args=(), kwargs={"x": x})
print(model_export.module())
with set_interaction_type(InteractionType.MEAN):
model_export = export(model, args=(), kwargs={"x": x})
print(model_export.module())
GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
split = torch.ops.aten.split.Tensor(linear_1, 2, -1); linear_1 = None
getitem = split[0]
getitem_1 = split[1]; split = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = None
empty = torch.ops.aten.empty.memory_format([1, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
normal_functional = torch.ops.aten.normal_functional.default(empty); empty = None
mul = torch.ops.aten.mul.Tensor(normal_functional, getitem_3); normal_functional = getitem_3 = None
add_2 = torch.ops.aten.add.Tensor(getitem_2, mul); getitem_2 = mul = None
return pytree.tree_unflatten((add_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
split = torch.ops.aten.split.Tensor(linear_1, 2, -1); linear_1 = None
getitem = split[0]
getitem_1 = split[1]; split = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]; broadcast_tensors = None
return pytree.tree_unflatten((getitem_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
這就是您使用 torch.export
所需了解的全部內容。有關更多資訊,請參閱 官方文件。
後續步驟和延伸閱讀¶
查看
torch.export
教程,網址為 這裡;ONNX 支援:查看 ONNX 教程,以了解有關此功能的更多資訊。匯出到 ONNX 與此處解釋的 torch.export 非常相似。
若要在沒有 Python 環境的伺服器上部署 PyTorch 程式碼,請參閱 AOTInductor 文件。
腳本總執行時間: (0 分鐘 1.695 秒)