torch.export¶
警告
此功能是一個正在積極開發中的原型,未來將會出現重大變更。
概述¶
torch.export.export()
接受任意 Python 可呼叫物件 (一個 torch.nn.Module
、一個函式或一個方法),並產生一個追蹤圖,以預先編譯 (Ahead-of-Time, AOT) 的方式呈現該函式的 Tensor 計算,隨後可以使用不同的輸出執行或序列化。
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(
Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
# code: a = torch.sin(x)
sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)
# code: b = torch.cos(y)
cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)
# code: return a + b
add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos)
return (add,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='y'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
)
]
)
Range constraints: {}
torch.export
產生一個乾淨的中間表示法 (Intermediate Representation, IR),具有以下不變性。關於 IR 的更多規範可以在這裡找到。
健全性 (Soundness):保證它是原始程式的一個健全表示法,並保持原始程式相同的呼叫慣例。
標準化 (Normalized):圖中沒有 Python 語義。原始程式中的子模組會被內聯,形成一個完全扁平化的計算圖。
圖的屬性 (Graph properties):該圖是純函數式的,意味著它不包含具有副作用的操作,例如變異 (mutations) 或別名 (aliasing)。它不會變異任何中間值、參數或緩衝區。
元數據 (Metadata):該圖包含在追蹤期間捕獲的元數據,例如來自使用者程式碼的堆疊追蹤 (stacktrace)。
在底層,torch.export
利用了以下最新技術
TorchDynamo (torch._dynamo) 是一個內部 API,它使用一個名為 Frame Evaluation API 的 CPython 功能來安全地追蹤 PyTorch 圖。這提供了大幅改進的圖捕獲體驗,只需更少的重寫即可完全追蹤 PyTorch 程式碼。
AOT Autograd 提供了一個函數化的 PyTorch 圖,並確保該圖被分解/降低到 ATen 運算子集合。
Torch FX (torch.fx) 是該圖的底層表示法,允許基於 Python 的靈活轉換。
現有框架¶
torch.compile()
也使用了與 torch.export
相同的 PT2 堆疊,但略有不同
JIT vs. AOT:
torch.compile()
是一個 JIT 編譯器,而torch.export
不旨在用於在部署之外產生編譯的成品。部分 vs. 完整圖捕獲 (Partial vs. Full Graph Capture):當
torch.compile()
遇到模型中無法追蹤的部分時,它會「中斷圖 (graph break)」並退回到在 eager Python 運行時中執行程式。相比之下,torch.export
旨在獲得 PyTorch 模型的完整圖表示,因此當達到無法追蹤的內容時,它會出錯。由於torch.export
產生一個與任何 Python 功能或運行時分離的完整圖,因此該圖可以被保存、載入並在不同的環境和語言中執行。可用性權衡 (Usability tradeoff):由於
torch.compile()
能夠在達到無法追蹤的內容時回退到 Python 運行時,因此它更加靈活。torch.export
則需要使用者提供更多資訊或重寫他們的程式碼以使其可追蹤。
與 torch.fx.symbolic_trace()
相比,torch.export
使用在 Python 字節碼層級運作的 TorchDynamo 進行追蹤,使其能夠追蹤任意 Python 結構,而不受 Python 運算子重載支援的限制。此外,torch.export
精細地追蹤 Tensor 元數據,因此對 Tensor 形狀等事物的條件判斷不會導致追蹤失敗。總體而言,預期 torch.export
可以在更多的使用者程式上運作,並產生更低層級的圖 (在 torch.ops.aten
運算子層級)。請注意,使用者仍然可以使用 torch.fx.symbolic_trace()
作為 torch.export
之前的一個預處理步驟。
與 torch.jit.script()
相比,torch.export
不捕獲 Python 控制流或數據結構,但它比 TorchScript 支援更多的 Python 語言特性 (因為更容易全面覆蓋 Python 字節碼)。生成的圖更簡單,並且只有直線控制流 (除了顯式的控制流運算子)。
與 torch.jit.trace()
相比,torch.export
是健全的:它能夠追蹤對大小執行整數計算的程式碼,並記錄所有必要的副作用條件,以表明特定追蹤對於其他輸入有效。
導出 PyTorch 模型¶
一個例子¶
主要的進入點是透過 torch.export.export()
,它接受一個可呼叫物件 (torch.nn.Module
、函式或方法) 和範例輸入,並將計算圖捕獲到一個 torch.export.ExportedProgram
中。一個例子
import torch
from torch.export import export
# Simple module for demonstration
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
a = self.conv(x)
a.add_(constant)
return self.maxpool(self.relu(a))
example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
# code: a = self.conv(x)
conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1])
# code: a.add_(constant)
add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant)
# code: return self.maxpool(self.relu(a))
relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_)
max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3])
return (max_pool2d,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='constant'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='max_pool2d'),
target=None
)
]
)
Range constraints: {}
檢查 ExportedProgram
,我們可以注意到以下幾點
torch.fx.Graph
包含原始程式的計算圖,以及原始程式碼的記錄,以便於除錯。該圖僅包含
torch.ops.aten
運算子 (可在這裡找到) 和自定義運算子,並且是完全函數式的,沒有任何 inplace 運算子,例如torch.add_
。參數 (卷積的權重和偏差) 作為圖的輸入被提升,導致圖中沒有
get_attr
節點,而這些節點之前存在於torch.fx.symbolic_trace()
的結果中。torch.export.ExportGraphSignature
對輸入和輸出簽名進行建模,並指定哪些輸入是參數。圖中每個節點產生的張量的結果形狀 (shape) 和資料類型 (dtype) 都會被記錄。 例如,
convolution
節點將產生一個資料類型為torch.float32
,形狀為 (1, 16, 256, 256) 的張量。
非嚴格匯出¶
在 PyTorch 2.3 中,我們引入了一種新的追蹤模式,稱為非嚴格模式。它仍在強化階段,因此如果您遇到任何問題,請在 Github 上提交 issue,並加上 "oncall: export" 標籤。
在非嚴格模式中,我們使用 Python 直譯器追蹤程式。您的程式碼將完全按照 eager 模式中的方式執行;唯一的區別是,所有 Tensor 物件都將被 ProxyTensors 替換,ProxyTensors 會將它們的所有操作記錄到一個圖中。
在嚴格模式下 (目前是預設模式),我們首先使用 TorchDynamo (一個 bytecode 分析引擎) 追蹤程式。 TorchDynamo 實際上不會執行您的 Python 程式碼。 相反,它會以符號方式分析它,並根據結果建立一個圖。 這種分析使 torch.export 能夠提供更強的安全保證,但並非所有 Python 程式碼都受支援。
一個可能需要使用非嚴格模式的例子是,當您遇到一個不受支援的 TorchDynamo 功能,而該功能可能不容易解決,並且您知道 Python 程式碼對於計算而言並非絕對必要時。例如:
import contextlib
import torch
class ContextManager():
def __init__(self):
self.count = 0
def __enter__(self):
self.count += 1
def __exit__(self, exc_type, exc_value, traceback):
self.count -= 1
class M(torch.nn.Module):
def forward(self, x):
with ContextManager():
return x.sin() + x.cos()
export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully
export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager
在這個例子中,第一次使用非嚴格模式的呼叫 (透過 strict=False
標誌) 成功追蹤,而第二次使用嚴格模式的呼叫 (預設) 則失敗,因為 TorchDynamo 無法支援 context managers。 一個選項是重寫程式碼 (請參閱torch.export 的限制),但由於 context manager 不會影響模型中的張量計算,因此我們可以採用非嚴格模式的結果。
用於訓練和推論的匯出¶
在 PyTorch 2.5 中,我們引入了一個新的 API,稱為 export_for_training()
。它仍在強化階段,因此如果您遇到任何問題,請在 Github 上提交 issue,並加上 "oncall: export" 標籤。
在這個 API 中,我們產生最通用的 IR,其中包含所有 ATen 運算符 (包括 functional 和 non-functional),可用於在 eager PyTorch Autograd 中進行訓練。 這個 API 適用於 eager training 的使用案例,例如 PT2 量化,並且很快將成為 torch.export.export 的預設 IR。 要進一步了解此變更背後的動機,請參閱 https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206
當此 API 與 run_decompositions()
結合使用時,您應該能夠獲得具有任何所需分解行為的推論 IR。
為了展示一些範例:
class ConvBatchnorm(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return (x,)
mod = ConvBatchnorm()
inp = torch.randn(1, 1, 3, 3)
ep_for_training = torch.export.export_for_training(mod, (inp,))
print(ep_for_training)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1)
batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True)
return (batch_norm,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_weight'),
target='bn.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_bias'),
target='bn.bias',
persistent=None
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_mean'),
target='bn.running_mean',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_var'),
target='bn.running_var',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_num_batches_tracked'),
target='bn.num_batches_tracked',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='batch_norm'),
target=None
)
]
)
Range constraints: {}
從上面的輸出中,您可以看到 export_for_training()
產生的 ExportedProgram 幾乎與 export()
相同,只是圖中的運算符不同。 您可以看到我們以最通用的形式捕獲了 batch_norm。 這個 op 是 non-functional 的,並且在運行推論時會被降級為不同的 ops。
您也可以透過具有任意自訂的 run_decompositions()
從此 IR 轉換為推論 IR。
# Lower to core aten inference IR, but keep conv2d
decomp_table = torch.export.default_decompositions()
del decomp_table[torch.ops.aten.conv2d.default]
ep_for_inference = ep_for_training.run_decompositions(decomp_table)
print(ep_for_inference)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
return (getitem_3, getitem_4, add, getitem)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_weight'),
target='bn.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_bias'),
target='bn.bias',
persistent=None
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_mean'),
target='bn.running_mean',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_var'),
target='bn.running_var',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_num_batches_tracked'),
target='bn.num_batches_tracked',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_3'),
target='bn.running_mean'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_4'),
target='bn.running_var'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='add'),
target='bn.num_batches_tracked'
),
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='getitem'),
target=None
)
]
)
Range constraints: {}
在這裡您可以看到我們保留了 IR 中的 conv2d
op,同時分解了其餘部分。 現在 IR 是一個 functional IR,包含核心 aten 運算符,除了 conv2d
。
您可以透過直接註冊您選擇的分解行為來進行更多自訂。
您可以透過直接註冊自訂的分解行為來進行更多自訂
# Lower to core aten inference IR, but customize conv2d
decomp_table = torch.export.default_decompositions()
def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)
decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function
ep_for_inference = ep_for_training.run_decompositions(decomp_table)
print(ep_for_inference)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2)
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];
return (getitem_3, getitem_4, add, getitem)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_weight'),
target='bn.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_bias'),
target='bn.bias',
persistent=None
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_mean'),
target='bn.running_mean',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_var'),
target='bn.running_var',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_num_batches_tracked'),
target='bn.num_batches_tracked',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_3'),
target='bn.running_mean'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_4'),
target='bn.running_var'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='add'),
target='bn.num_batches_tracked'
),
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='getitem'),
target=None
)
]
)
Range constraints: {}
表達動態性¶
預設情況下,torch.export
將追蹤程式,假設所有輸入形狀都是靜態的,並且將匯出的程式專門化為這些維度。 然而,某些維度 (例如批次維度) 可能是動態的,並且每次運行都會有所不同。 這些維度必須使用 torch.export.Dim()
API 來建立,並透過 dynamic_shapes
參數傳遞到 torch.export.export()
中。 例如:
import torch
from torch.export import Dim, export
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.branch1 = torch.nn.Sequential(
torch.nn.Linear(64, 32), torch.nn.ReLU()
)
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)
def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):
# code: out1 = self.branch1(x1)
linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias)
relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear)
# code: out2 = self.branch2(x2)
linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias)
relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1)
# code: return (out1 + self.buffer, out2)
add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer)
return (add, relu_1)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch1_0_weight'),
target='branch1.0.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch1_0_bias'),
target='branch1.0.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch2_0_weight'),
target='branch2.0.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch2_0_bias'),
target='branch2.0.bias',
persistent=None
),
InputSpec(
kind=<InputKind.CONSTANT_TENSOR: 4>,
arg=TensorArgument(name='c_buffer'),
target='buffer',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x1'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x2'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
),
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='relu_1'),
target=None
)
]
)
Range constraints: {s0: VR[0, int_oo]}
一些需要注意的額外事項:
透過
torch.export.Dim()
API 和dynamic_shapes
參數,我們指定了每個輸入的第一個維度是動態的。 查看輸入x1
和x2
,它們具有 (s0, 64) 和 (s0, 128) 的符號形狀,而不是我們作為範例輸入傳入的 (32, 64) 和 (32, 128) 形狀的張量。s0
是一個符號,表示此維度可以是一系列的值。exported_program.range_constraints
描述了圖中出現的每個符號的範圍。 在這種情況下,我們看到s0
的範圍是 [0, int_oo]。 由於技術原因,很難在此處解釋,它們被假定為非 0 或 1。 這不是錯誤,並且不一定表示匯出的程式不適用於維度 0 或 1。 有關此主題的深入討論,請參閱 0/1 特殊化問題。
我們還可以指定輸入形狀之間更具表達性的關係,例如一對形狀可能相差 1、一個形狀可能是另一個形狀的兩倍,或者一個形狀是偶數。 例如:
class M(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
exported_program = torch.export.export(
M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"):
# code: return x + y[1:]
slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807)
add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1)
return (add,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='y'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
)
]
)
Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]}
一些需要注意的事項:
透過為第一個輸入指定
{0: dimx}
,我們可以看到第一個輸入的結果形狀現在是動態的,變為[s0]
。現在,透過為第二個輸入指定{0: dimy}
,我們可以看到第二個輸入的結果形狀也是動態的。然而,因為我們表示dimy = dimx + 1
,所以y
的形狀並沒有包含新的符號,而是用與x
中相同的符號s0
來表示。我們可以觀察到dimy = dimx + 1
的關係透過s0 + 1
來呈現。查看範圍約束,我們可以看到
s0
的範圍是 [3, 6],這是最初指定的,並且我們可以看到s0 + 1
的已解範圍是 [4, 7]。
序列化¶
為了保存 ExportedProgram
,使用者可以使用 torch.export.save()
和 torch.export.load()
API。一個慣例是使用 .pt2
副檔名來保存 ExportedProgram
。
範例
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
exported_program = torch.export.export(MyModule(), torch.randn(5))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
特化 (Specializations)¶
理解 torch.export
行為的一個關鍵概念是 *靜態* 和 *動態* 值之間的差異。
動態 值是指在每次執行中都可能變化的值。它們的行為就像 Python 函數的普通參數一樣——您可以為參數傳遞不同的值,並期望您的函數能做正確的事情。Tensor *資料* 被視為動態的。
靜態 值是指在匯出時固定的值,並且在匯出程式的執行之間不能更改。當在追蹤期間遇到該值時,匯出器會將其視為常數並將其硬編碼到圖形中。
當執行操作(例如 x + y
)並且所有輸入都是靜態的時,則操作的輸出將直接硬編碼到圖形中,並且操作不會顯示(即,它將被常數摺疊)。
當一個值被硬編碼到圖形中時,我們說該圖形已被特化為該值。
以下值是靜態的
輸入 Tensor 形狀¶
預設情況下,torch.export
將追蹤程式,並根據輸入 tensors 的形狀進行特化,除非透過 torch.export
的 dynamic_shapes
引數將維度指定為動態。這意味著如果存在與形狀相關的控制流程,torch.export
將特化於使用給定範例輸入所採取的哪個分支。例如
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x):
if x.shape[0] > 5:
return x + 1
else:
return x - 1
example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 2]"):
# code: return x + 1
add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1)
return (add,)
(x.shape[0] > 5
) 的條件不會出現在 ExportedProgram
中,因為範例輸入具有 (10, 2) 的靜態形狀。由於 torch.export
根據輸入的靜態形狀進行特化,因此永遠不會到達 else 分支 (x - 1
)。為了保留基於 tensors 形狀的追蹤圖形中的動態分支行為,需要使用 torch.export.Dim()
來指定輸入 tensor (x.shape[0]
) 的維度為動態的,並且需要 重寫 原始碼。
請注意,作為模組狀態一部分的 tensors(例如,參數和緩衝區)始終具有靜態形狀。
Python 原始型別 (Primitives)¶
torch.export
還會特化 Python 原始型別,例如 int
、float
、bool
和 str
。但是,它們確實具有動態變體,例如 SymInt
、SymFloat
和 SymBool
。
例如
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, const: int, times: int):
for i in range(times):
x = x + const
return x
example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[2, 2]", const, times):
# code: x = x + const
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1)
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1)
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1)
return (add_2,)
由於整數被特化,因此 torch.ops.aten.add.Tensor
操作都使用硬編碼的常數 1
而不是 const
來計算。如果使用者在執行時為 const
傳遞了不同的值(例如 2),而不是匯出時使用的值 1,則會導致錯誤。此外,for
迴圈中使用的 times
迭代器也透過 3 個重複的 torch.ops.aten.add.Tensor
呼叫“內聯”到圖形中,並且永遠不會使用輸入 times
。
Python 容器¶
Python 容器(List
、Dict
、NamedTuple
等)被認為具有靜態結構。
torch.export
的限制¶
圖形中斷 (Graph Breaks)¶
由於 torch.export
是一個一次性的程序,用於從 PyTorch 程式中擷取計算圖,因此它最終可能會遇到無法追蹤的程式碼部分,因為幾乎不可能支援追蹤所有 PyTorch 和 Python 功能。在 torch.compile
的情況下,不支援的操作將會導致「圖形中斷」(graph break),並且不支援的操作將會使用預設的 Python 評估來執行。相反地,torch.export
將會要求使用者提供額外的資訊或重寫部分程式碼,使其可追蹤。由於追蹤是基於 TorchDynamo,而 TorchDynamo 是在 Python 位元組碼層級進行評估,因此與先前的追蹤框架相比,所需的重寫次數將會顯著減少。
當遇到圖形中斷時,ExportDB 是一個很棒的資源,可以了解支援和不支援哪些類型的程式,以及重寫程式以使其可追蹤的方法。
一個可以避免處理圖形中斷的選項是使用非嚴格匯出 (non-strict export)
資料/形狀相依的控制流程¶
當形狀未被特化時,在資料相依的控制流程 (例如 if x.shape[0] > 2
) 上也可能遇到圖形中斷,因為追蹤編譯器不可能處理,除非為組合式爆炸數量的路徑產生程式碼。在這種情況下,使用者需要使用特殊的控制流程運算子來重寫其程式碼。目前,我們支援 torch.cond 來表達類似 if-else 的控制流程 (更多功能即將推出!)。
運算子缺少 Fake/Meta/Abstract 核心¶
在追蹤時,所有運算子都需要一個 FakeTensor 核心(又稱 meta 核心、抽象實作)。這用於推斷此運算子的輸入/輸出形狀。
請參閱 torch.library.register_fake()
以取得更多詳細資訊。
在不幸的情況下,如果您的模型使用了一個尚未實作 FakeTensor 核心的 ATen 運算子,請提交 issue。
API 參考¶
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[source][source]¶
export()
接受任何 nn.Module 以及範例輸入,並產生一個追蹤的圖形,該圖形僅代表函數的 Tensor 計算,並以 Ahead-of-Time (AOT) 方式進行,隨後可以使用不同的輸入執行或序列化。追蹤的圖形 (1) 在 functional ATen 運算子集中產生標準化的運算子(以及任何使用者指定的自訂運算子),(2) 消除了所有 Python 控制流程和資料結構(除了一些例外),並且 (3) 記錄了顯示這種標準化和控制流程消除對於未來輸入是健全的所需的一組形狀約束。健全性保證
在追蹤時,
export()
會記錄使用者程式和底層 PyTorch 運算子核心所做的形狀相關假設。只有當這些假設成立時,輸出ExportedProgram
才被認為是有效的。追蹤會對輸入張量的形狀(而不是值)做出假設。對於
export()
成功,必須在圖形擷取時驗證這些假設。具體來說對輸入張量的靜態形狀的假設會自動驗證,無需額外工作。
對輸入張量的動態形狀的假設需要明確指定,方法是使用
Dim()
API 來建構動態維度,並透過dynamic_shapes
引數將其與範例輸入相關聯。
如果任何假設無法驗證,將會引發嚴重錯誤。發生這種情況時,錯誤訊息將會包含建議的規格修正,這些修正需要驗證假設。例如,
export()
可能會建議以下修正動態維度dim0_x
的定義,例如出現在與輸入x
關聯的形狀中,先前定義為Dim("dim0_x")
dim = Dim("dim0_x", max=5)
這個範例表示產生的程式碼需要輸入
x
的維度 0 小於或等於 5 才有效。您可以檢查針對動態維度定義提供的建議修正,然後將它們逐字複製到您的程式碼中,而無需更改對export()
呼叫的dynamic_shapes
參數。- 參數
mod (Module) – 我們將追蹤此模組的 forward 方法。
dynamic_shapes (Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]]) –
一個選擇性的引數,其類型應為: 1) 從
f
的引數名稱到其動態形狀規格的字典;2) 一個元組,指定每個輸入依原始順序的動態形狀規格。如果您要指定關鍵字引數的動態性,您需要依照原始函數簽章中定義的順序傳遞它們。張量引數的動態形狀可以指定為 (1) 從動態維度索引到
Dim()
類型的字典,其中不需要在此字典中包含靜態維度索引,但當它們存在時,應該映射到 None;或 (2)Dim()
類型或 None 的元組 / 列表,其中Dim()
類型對應於動態維度,而靜態維度以 None 表示。 作為張量的字典或元組 / 列表的引數,通過使用包含的規格的映射或序列來遞迴指定。strict (bool) – 啟用時 (預設),匯出函數將透過 TorchDynamo 追蹤程式,這將確保結果圖的健全性。 否則,匯出的程式將不會驗證烘焙到圖中的隱含假設,並且可能導致原始模型和匯出模型之間的行為差異。 當使用者需要解決追蹤器中的錯誤,或者只是想在其模型中逐步啟用安全性時,這很有用。 請注意,這不會影響結果 IR 規格的不同,並且無論傳遞的值是什麼,模型都將以相同的方式序列化。 警告:此選項是實驗性的,使用風險自負。
- 返回值
包含追蹤可調用物件的
ExportedProgram
。- 返回類型
可接受的輸入/輸出類型
可接受的輸入類型 (用於
args
和kwargs
) 和輸出包括基本類型,即
torch.Tensor
、int
、float
、bool
和str
。資料類別,但必須先呼叫
register_dataclass()
來註冊它們。包含所有上述類型的
dict
、list
、tuple
、namedtuple
和OrderedDict
組成的(巢狀)資料結構。
- torch.export.save(ep, f, *, extra_files=None, opset_version=None)[source][source]¶
警告
正在積極開發中,保存的檔案可能無法在較新版本的 PyTorch 中使用。
將
ExportedProgram
保存到類似檔案的物件。 然後可以使用 Python APItorch.export.load
載入它。- 參數
ep (ExportedProgram) – 要儲存的已匯出程式 (exported program)。
f (Union[str, os.PathLike, io.BytesIO) – 檔案類物件 (file-like object) (必須實作 write 和 flush) 或包含檔案名稱的字串。
extra_files (Optional[Dict[str, Any]]) – 從檔名到內容的映射,將會儲存為 f 的一部分。
opset_version (Optional[Dict[str, int]]) – opset 名稱到此 opset 版本的映射
範例
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file torch.export.save(ep, 'exported_program.pt2') # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files extra_files = {'foo.txt': b'bar'.decode('utf-8')} torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source][source]¶
警告
正在積極開發中,保存的檔案可能無法在較新版本的 PyTorch 中使用。
載入先前使用
torch.export.save
儲存的ExportedProgram
。- 參數
ep (ExportedProgram) – 要儲存的已匯出程式 (exported program)。
f (Union[str, os.PathLike, io.BytesIO) – 檔案類物件 (file-like object) (必須實作 write 和 flush) 或包含檔案名稱的字串。
extra_files (Optional[Dict[str, Any]]) – 在此映射中給定的額外檔案名稱將被載入,並且其內容將被儲存在提供的映射中。
expected_opset_version (Optional[Dict[str, int]]) – opset 名稱到預期 opset 版本的映射
- 返回值
一個
ExportedProgram
物件- 返回類型
範例
import torch import io # Load ExportedProgram from file ep = torch.export.load('exported_program.pt2') # Load ExportedProgram from io.BytesIO object with open('exported_program.pt2', 'rb') as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. extra_files = {'foo.txt': ''} # values will be replaced with data ep = torch.export.load('exported_program.pt2', extra_files=extra_files) print(extra_files['foo.txt']) print(ep(torch.randn(5)))
- torch.export.register_dataclass(cls, *, serialized_type_name=None)[source][source]¶
將資料類別 (dataclass) 註冊為
torch.export.export()
的有效輸入/輸出類型。- 參數
範例
import torch from dataclasses import dataclass @dataclass class InputDataClass: feature: torch.Tensor bias: int @dataclass class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) class Mod(torch.nn.Module): def forward(self, x: InputDataClass) -> OutputDataClass: res = x.feature + x.bias return OutputDataClass(res=res) ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), )) print(ep)
- torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source][source]¶
Dim()
構造一個類似於具有範圍的具名符號整數的類型。它可用於描述動態張量維度的多個可能值。 請注意,相同張量或不同張量的不同動態維度可以用同一種類型來描述。
- torch.export.exported_program.default_decompositions()[source][source]¶
這是預設的分解表,其中包含將所有 ATEN 運算符分解為核心 aten opset。 將此 API 與
run_decompositions()
一起使用- 返回類型
- class torch.export.dynamic_shapes.ShapesCollection[原始碼][原始碼]¶
dynamic_shapes 的建構器。用於將動態形狀規範指派給輸入中出現的張量。
- 範例:
args = ({“x”: tensor_x, “others”: [tensor_y, tensor_z]})
dim = torch.export.Dim(…) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (dim, dim + 1, 8) dynamic_shapes[tensor_y] = {0: dim * 2} # 這等同於以下(現在自動產生): # dynamic_shapes = {“x”: (dim, dim + 1, 8), “others”: [{0: dim * 2}, None]}
torch.export(…, args, dynamic_shapes=dynamic_shapes)
- torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[原始碼][原始碼]¶
用於處理 export 的動態形狀建議修正,和/或自動動態形狀。根據 ConstraintViolation 錯誤訊息和原始動態形狀,精煉給定的動態形狀規範。
在大多數情況下,行為很簡單 - 例如,對於專門化或精煉 Dim 範圍的建議修正,或建議導出關係的修正,新的動態形狀規範將如此更新。
例如,建議的修正
dim = Dim(‘dim’, min=3, max=6) -> 這只精煉了 dim 的範圍 dim = 4 -> 這專門化為常數 dy = dx + 1 -> dy 被指定為獨立的 dim,但實際上透過此關係與 dx 綁定
但是,與導出 dim 相關的建議修正可能更複雜。例如,如果為根 dim 提供了建議修正,則會根據根來評估新的導出 dim 值。
例如 dx = Dim(‘dx’) dy = dx + 2 dynamic_shapes = {“x”: (dx,), “y”: (dy,)}
建議的修正
dx = 4 # 專門化將導致 dy 也專門化 = 6 dx = Dim(‘dx’, max=6) # dy 現在具有 max = 8
導出 dim 的建議修正也可用於表達可除性約束。這涉及創建未與特定輸入形狀綁定的新根 dim。在這種情況下,根 dim 不會直接出現在新規範中,而是作為其中一個 dim 的根。
例如,建議的修正
_dx = Dim(‘_dx’, max=1024) # 這不會出現在回傳結果中,但 dx 會 dx = 4*_dx # dx 現在可以被 4 整除,最大值為 4096
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[原始碼][原始碼]¶
來自
export()
的程式包。它包含一個torch.fx.Graph
,代表張量計算、一個 state_dict,包含所有 lifted 參數和緩衝區的張量值,以及各種元資料。您可以像原始的可呼叫物件一樣呼叫 ExportedProgram,該物件由具有相同呼叫慣例的
export()
追蹤。要對圖形執行轉換,請使用
.module
屬性來存取torch.fx.GraphModule
。然後您可以使用 FX 轉換來重寫圖形。之後,您可以再次簡單地使用export()
來建構正確的 ExportedProgram。- run_decompositions(decomp_table=None)[原始碼][原始碼]¶
在導出的程式上執行一組分解,並傳回一個新的導出的程式。 預設情況下,我們將執行 Core ATen 分解,以取得 Core ATen Operator Set 中的運算子。
目前,我們不分解聯合圖形。
- 參數
decomp_table (Optional[Dict[OperatorBase, Callable]]) – 一個可選參數,用於指定 Aten 運算子的分解行為 (1) 如果為 None,我們會分解為核心 aten 分解 (2) 如果為空,我們不會分解任何運算子
- 返回類型
一些範例
如果您不想分解任何內容
ep = torch.export.export(model, ...) ep = ep.run_decompositions(decomp_table={})
如果您想要取得核心 aten 運算子集,但排除某些運算子,您可以執行以下操作
ep = torch.export.export(model, ...) decomp_table = torch.export.default_decompositions() decomp_table[your_op] = your_custom_decomp ep = ep.run_decompositions(decomp_table=decomp_table)
- class torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)[原始碼][原始碼]¶
- class torch.export.ExportGraphSignature(input_specs, output_specs)[source][source]¶
ExportGraphSignature
為 Export Graph 的輸入/輸出簽章建模,Export Graph 是一個 fx.Graph,具有更強的不變性保證。Export Graph 是函數式的,並且不透過
getattr
節點存取圖中的「狀態」,例如參數或緩衝區。取而代之的是,export()
保證參數、緩衝區和常數張量會被提升為圖的輸入。同樣地,對緩衝區的任何修改也不會包含在圖中,而是將修改後的緩衝區值建模為 Export Graph 的額外輸出。所有輸入和輸出的順序為
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果匯出以下模組
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
結果圖將會是
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
結果 ExportGraphSignature 將會是
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.ModuleCallSignature(inputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: Optional[List[str]] = None)[source][source]¶
- class torch.export.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[source][source]¶
- class torch.export.decomp_utils.CustomDecompTable[source][source]¶
這是一個自定義的字典,專門用於處理匯出時的 decomp_table。我們需要它的原因是,在新的架構中,你只能從 decomp table 中刪除一個 op 來保留它。這對於自定義的 op 來說是有問題的,因為我們不知道自定義的 op 何時會真正被載入到 dispatcher。因此,我們需要記錄自定義 op 的操作,直到我們真正需要將其具體化(也就是當我們執行分解 pass 的時候)。
- 我們維持的不變性如下:
所有的 aten decomp 都在初始化時載入
當使用者從表格中讀取任何內容時,我們會具體化所有 (ALL) 的 op,以提高 dispatcher 選擇自定義 op 的可能性。
如果是寫入操作,我們不一定會具體化
我們在匯出期間的最後一次載入,就在調用 run_decompositions() 之前
- class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)[原始碼][原始碼]¶
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[source][source]¶
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source][source]¶
ExportGraphSignature
模擬 Export Graph 的輸入/輸出簽章,Export Graph 是一個具有更強不變性保證的 fx.Graph。Export Graph 是函數式的,並且不會透過
getattr
節點存取圖形內的「狀態」,例如參數或緩衝區。相反地,export()
保證參數、緩衝區和常數張量會被提升為圖形的輸入。同樣地,對緩衝區的任何變更也不會包含在圖形中,而是將變更後的緩衝區值建模為 Export Graph 的額外輸出。所有輸入和輸出的順序為
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果匯出以下模組
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
結果圖將會是
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
結果 ExportGraphSignature 將會是
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[source][source]¶
- class torch.export.unflatten.InterpreterModule(graph)[原始碼][原始碼]¶
一個使用 torch.fx.Interpreter 執行,而不是使用 GraphModule 慣用的程式碼產生的模組。這提供了更好的堆疊追蹤資訊,並更容易偵錯執行。
- class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[原始碼][原始碼]¶
一個攜帶一系列 InterpreterModules 的模組,對應於該模組的一系列呼叫。每次呼叫該模組都會分派到下一個 InterpreterModule,並在最後一個之後迴繞。
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)[原始碼][原始碼]¶
展開一個 ExportedProgram,產生一個與原始 eager 模組具有相同模組層級結構的模組。如果您嘗試將
torch.export
與另一個期望模組層級結構而不是torch.export
通常產生的平面圖的系統一起使用,這可能會很有用。注意
展開模組的 args/kwargs 不一定與 eager 模組匹配,因此執行模組交換 (例如
self.submod = new_mod
) 不一定有效。如果您需要換出一個模組,您需要設定torch.export.export()
的preserve_module_call_signature
參數。- 參數
module (ExportedProgram) – 要展開的 ExportedProgram。
flat_args_adapter (Optional[FlatArgsAdapter]) – 如果輸入 TreeSpec 與匯出的模組不匹配,則調整扁平參數。
- 返回值
UnflattenedModule
的一個實例,它具有與匯出前的原始 eager 模組相同的模組層級結構。- 返回類型
UnflattenedModule
- torch.export.passes.move_to_device_pass(ep, location)[原始碼][原始碼]¶
將匯出的程式移動到給定的裝置。
- 參數
ep (ExportedProgram) – 要移動的匯出程式。
location (Union[torch.device, str, Dict[str, str]]) – 要將匯出的程式移動到的裝置。 如果是字串,則將其解釋為裝置名稱。 如果是字典,則將其解釋為從現有裝置到預期裝置的映射
- 返回值
移動後的匯出程式。
- 返回類型