• 文件 >
  • 匯出 tensordict 模組
捷徑

匯出 tensordict 模組

作者: Vincent Moens

先決條件

建議先閱讀 TensorDictModule 教學課程,以充分受益於本教學課程。

一旦使用 tensordict.nn 撰寫了模組,通常需要隔離計算圖並匯出該圖。 這樣做的目的可能是為了在硬體(例如,機器人、無人機、邊緣裝置)上執行模型,或完全消除對 tensordict 的依賴。

PyTorch 提供了多種匯出模組的方法,包括 onnxtorch.export,兩者都與 tensordict 相容。

在這個簡短的教學課程中,我們將了解如何使用 torch.export 來隔離模型的計算圖。torch.onnx 支援遵循相同的邏輯。

主要學習內容

  • 在沒有 TensorDict 輸入的情況下執行 tensordict.nn 模組;

  • 選擇模型的輸出;

  • 處理隨機模型;

  • 使用 torch.export 匯出此類模型;

  • 將模型儲存到檔案;

  • 隔離 pytorch 模型;

import time

import torch
from tensordict.nn import (
    InteractionType,
    NormalParamExtractor,
    ProbabilisticTensorDictModule as Prob,
    set_interaction_type,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
)
from torch import distributions as dists, nn

設計模型

在許多應用中,使用隨機模型非常有用,也就是說,模型輸出的變數不是確定性定義的,而是根據參數化分佈進行抽樣的。 例如,生成式 AI 模型在提供相同輸入時通常會產生不同的輸出,因為它們根據分佈對輸出進行抽樣,而分佈的參數由輸入定義。

tensordict 庫透過 ProbabilisticTensorDictModule 類別來處理這個問題。 這個基本元素是使用分佈類別(在我們的例子中是 Normal)和輸入鍵的指標來建構的,這些輸入鍵將在執行時用於建立該分佈。

因此,我們正在建構的網路將是三個主要元件的組合

  • 將輸入對應到潛在參數的網路;

  • 一個 tensordict.nn.NormalParamExtractor 模組,將輸入分割成位置 “loc”“scale” 參數,以傳遞到 Normal 分佈;

  • 一個分佈建構函式模組。

model = Seq(
    # 1. A small network for embedding
    Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
    Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
    Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
    # 2. Extracting params
    Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
    # 3. Probabilistic module
    Prob(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=dists.Normal,
    ),
)

讓我們執行這個模型,看看輸出是什麼樣子

x = torch.randn(1, 3)
print(model(x=x))
(tensor([[0.0000, 0.2604, 0.0000, 0.0000]], grad_fn=<ReluBackward0>), tensor([[-0.1580, -0.5222, -0.3319,  0.5519]], grad_fn=<AddmmBackward0>), tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>), tensor([[0.8046, 1.3804]], grad_fn=<ClampMinBackward0>), tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>))

正如預期的那樣,使用張量輸入執行模型會傳回與模組輸出鍵一樣多的張量! 對於大型模型,這可能會非常煩人且浪費。 稍後,我們將了解如何限制模型的輸出數量來解決這個問題。

torch.exportTensorDictModule 搭配使用

現在我們已經成功建構了我們的模型,我們希望將其計算圖提取到一個獨立於 tensordict 的單一物件中。torch.export 是一個 PyTorch 模組,專門用於隔離模組的圖並以標準化的方式表示它。 它的主要進入點是 export(),它會傳回一個 ExportedProgram 物件。 反過來,這個物件有幾個我們將在下面探討的感興趣的屬性:一個 graph_module,表示由 export 捕獲的 FX 圖,一個具有圖的輸入、輸出等的 graph_signature,最後是一個 module(),它傳回一個可以代替原始模組使用的可呼叫物件。

雖然我們的模組接受 args 和 kwargs,但我們將重點放在使用 kwargs,因為這樣更清楚。

from torch.export import export

model_export = export(model, args=(), kwargs={"x": x})

讓我們看看這個模組

print("module:", model_export.module())
module: GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0]
    getitem_3 = broadcast_tensors[1];  broadcast_tensors = None
    return pytree.tree_unflatten((relu, linear_1, getitem_2, getitem_3, getitem_2), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`

這個模組的執行方式與我們原來的模組完全相同(但開銷更低)

t0 = time.time()
model(x=x)
print(f"Time for TDModule: {(time.time()-t0)*1e6: 4.2f} micro-seconds")
exported = model_export.module()

# Exported version
t0 = time.time()
exported(x=x)
print(f"Time for exported module: {(time.time()-t0)*1e6: 4.2f} micro-seconds")
Time for TDModule:  469.45 micro-seconds
Time for exported module:  340.70 micro-seconds

以及 FX 圖形

print("fx graph:", model_export.graph_module.print_readable())
class GraphModule(torch.nn.Module):
    def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
         # File: /pytorch/tensordict/tensordict/nn/common.py:1010 in _call_module, code: out = self.module(*tensors, **kwargs)
        linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias);  x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
        relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear);  linear = None
        linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias);  p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None

         # File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:129 in forward, code: loc, scale = tensor.chunk(2, -1)
        split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
        getitem: "f32[1, 2]" = split[0]
        getitem_1: "f32[1, 2]" = split[1];  split = None

         # File: /pytorch/tensordict/tensordict/nn/utils.py:68 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
        add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
        softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add);  add = None
        add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None

         # File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:130 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
        clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None

         # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:55 in broadcast_all, code: return torch.broadcast_tensors(*values)
        broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
        getitem_2: "f32[1, 2]" = broadcast_tensors[0]
        getitem_3: "f32[1, 2]" = broadcast_tensors[1];  broadcast_tensors = None
        return (relu, linear_1, getitem_2, getitem_3, getitem_2)

fx graph: class GraphModule(torch.nn.Module):
    def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
         # File: /pytorch/tensordict/tensordict/nn/common.py:1010 in _call_module, code: out = self.module(*tensors, **kwargs)
        linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias);  x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
        relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear);  linear = None
        linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias);  p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None

         # File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:129 in forward, code: loc, scale = tensor.chunk(2, -1)
        split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
        getitem: "f32[1, 2]" = split[0]
        getitem_1: "f32[1, 2]" = split[1];  split = None

         # File: /pytorch/tensordict/tensordict/nn/utils.py:68 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
        add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
        softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add);  add = None
        add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None

         # File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:130 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
        clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None

         # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:55 in broadcast_all, code: return torch.broadcast_tensors(*values)
        broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
        getitem_2: "f32[1, 2]" = broadcast_tensors[0]
        getitem_3: "f32[1, 2]" = broadcast_tensors[1];  broadcast_tensors = None
        return (relu, linear_1, getitem_2, getitem_3, getitem_2)

處理巢狀鍵

巢狀鍵是 tensordict 函式庫的核心功能,因此能夠匯出讀取和寫入巢狀條目的模組是一項重要的支援功能。由於關鍵字引數必須是常規字串,因此 dispatch 無法直接使用它們。相反,dispatch 將會解壓縮以常規底線(“_”)連接的巢狀鍵,如下面的範例所示。

model_nested = Seq(
    Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
    Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
).select_out_keys(("some", "output"))

model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
print("exported module with nested input:", model_nested_export.module())
exported module with nested input: GraphModule()



def forward(self, some_key):
    some_key, = fx_pytree.tree_flatten_spec(([], {'some_key':some_key}), self._in_spec)
    add = torch.ops.aten.add.Tensor(some_key, 1);  some_key = None
    sub = torch.ops.aten.sub.Tensor(add, 1);  add = None
    return pytree.tree_unflatten((sub,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`

請注意,module() 傳回的可呼叫物件是一個純 Python 可呼叫物件,可以使用 compile() 進行編譯。

儲存匯出的模組

torch.export 有其自身的序列化協定,save()load()。按照慣例,應使用 “.pt2” 副檔名

>>> torch.export.save(model_export, "model.pt2")

選擇輸出

回想一下,tensordict.nn 會將每個中間值保留在輸出中,除非使用者特別要求僅保留特定值。在訓練期間,這非常有用:可以輕鬆記錄圖形的中間值,或將它們用於其他目的(例如,根據其儲存的參數重建分佈,而不是儲存 Distribution 物件本身)。還可以說,在訓練期間,註冊中間值對記憶體的影響可以忽略不計,因為它們是 torch.autograd 用於計算參數梯度之計算圖形的一部分。

但是,在推論期間,我們最有可能只對模型的最終樣本感興趣。因為我們想要提取模型以用於獨立於 tensordict 函式庫的用途,所以隔離我們想要的唯一輸出是有意義的。為此,我們有幾種選擇

  1. 使用 selected_out_keys 關鍵字引數建構 TensorDictSequential(),這將導致在呼叫模組期間選擇所需的條目;

  2. 使用 select_out_keys() 方法,該方法將就地修改 out_keys 屬性(可以透過 reset_out_keys() 恢復)。

  3. 將現有實例包裝在 TensorDictSequential() 中,該實例將過濾掉不需要的鍵

    >>> module_filtered = Seq(module, selected_out_keys=["sample"])
    

讓我們在選擇其輸出鍵後測試模型。當提供 x 輸入時,我們期望我們的模型輸出一個與分佈樣本對應的單一張量

model.select_out_keys("sample")
print(model(x=x))
tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>)

我們看到輸出現在是一個單一張量,對應於分佈的樣本。我們可以從此建立一個新的匯出圖形。它的計算圖形應該被簡化

model_export = export(model, args=(), kwargs={"x": x})
print("module:", model_export.module())
module: GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  relu = module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1);  linear_1 = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0];  broadcast_tensors = None
    return pytree.tree_unflatten((getitem_2,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`

控制取樣策略

我們尚未討論 ProbabilisticTensorDictModule 如何從分佈中取樣。透過取樣,我們指的是根據特定策略在分佈定義的空間內取得一個值。例如,可能希望在訓練期間獲得隨機樣本,但在推論期間獲得確定性樣本(例如,平均值或眾數)。為了應對這個問題,tensordict 使用 set_interaction_type 裝飾器和上下文管理器,它接受 InteractionType 列舉輸入

>>> with set_interaction_type(InteractionType.MEAN):
...     output = module(input)  # takes the input of the distribution, if ProbabilisticTensorDictModule is invoked

預設的 InteractionTypeInteractionType.DETERMINISTIC,如果未直接實現,則對於具有實數域的分佈是平均值,對於具有離散域的分佈是眾數。可以使用 ProbabilisticTensorDictModuledefault_interaction_type 關鍵字引數來變更此預設值。

讓我們回顧一下:為了控制我們網路的取樣策略,我們可以在建構函式中定義一個預設的取樣策略,或者在執行時透過 set_interaction_type 上下文管理器覆蓋它。

正如我們從以下範例中看到的那樣,torch.export 正確回應了裝飾器的使用:如果我們要求一個隨機樣本,則輸出與我們要求平均值時不同

with set_interaction_type(InteractionType.RANDOM):
    model_export = export(model, args=(), kwargs={"x": x})
    print(model_export.module())

with set_interaction_type(InteractionType.MEAN):
    model_export = export(model, args=(), kwargs={"x": x})
    print(model_export.module())
GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  relu = module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1);  linear_1 = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0]
    getitem_3 = broadcast_tensors[1];  broadcast_tensors = None
    empty = torch.ops.aten.empty.memory_format([1, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    normal_functional = torch.ops.aten.normal_functional.default(empty);  empty = None
    mul = torch.ops.aten.mul.Tensor(normal_functional, getitem_3);  normal_functional = getitem_3 = None
    add_2 = torch.ops.aten.add.Tensor(getitem_2, mul);  getitem_2 = mul = None
    return pytree.tree_unflatten((add_2,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  relu = module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1);  linear_1 = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0];  broadcast_tensors = None
    return pytree.tree_unflatten((getitem_2,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`

這就是您使用 torch.export 所需了解的全部內容。有關更多資訊,請參閱 官方文件

後續步驟和延伸閱讀

  • 查看 torch.export 教程,網址為 這裡

  • ONNX 支援:查看 ONNX 教程,以了解有關此功能的更多資訊。匯出到 ONNX 與此處解釋的 torch.export 非常相似。

  • 若要在沒有 Python 環境的伺服器上部署 PyTorch 程式碼,請參閱 AOTInductor 文件。

腳本總執行時間: (0 分鐘 1.695 秒)

Gallery 由 Sphinx-Gallery 產生

文件

取得 PyTorch 完整的開發者文件

查看文件

教學

取得適合初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源