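torch.export Tutorial

Created On: Oct 02, 2023 | Last Updated: Jan 27, 2025 | Last Verified: Nov 05, 2024

Author: William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan

Warning

torch.export and its related features are in prototype status and are subject to backwards-compatibility-breaking changes. This tutorial provides a snapshot of torch.export usage as of PyTorch 2.5.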

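torch.export() is the PyTorch 2.X way to export PyTorch models into standardized model representations, intended to be run in different (i.e. Python-less) environments. See the official torch.export documentation for the full reference.

In this tutorial, you will learn how to use torch.export() to extract an ExportedProgram (i.e. a single-graph representation) from a PyTorch program. We also detail some considerations and modifications that you may need to make in order to make your model compatible with torch.export.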

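Basic Usage

torch.export extracts single-graph representations from PyTorch programs by tracing the target function, given example inputs. torch.export.export() is the main entry point for torch.export.

In this tutorial, torch.export and torch.export.export() are practically synonymous, though torch.export generally refers to the PyTorch 2.X export process, and torch.export.export() generally refers to the actual function call.

The signature of torch.export.export() is: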

export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
) -> ExportedProgram

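torch.export.export() traces the tensor computation graph by calling mod(*args, **kwargs) and wraps it in an ExportedProgram, which can be serialized or executed later with different inputs. To execute the ExportedProgram, we can call .module() on it to get back a torch.nn.Module that is callable just like the original program. We will detail the dynamic_shapes argument later in the tutorial.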

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))
<class 'torch.export.exported_program.ExportedProgram'>
tensor([[0.8632, 0.8407, 0.0407, 0.0000, 0.4132, 0.0000, 0.0000, 0.1538, 0.6111,
         0.0000],
        [0.0000, 0.0000, 0.0273, 0.8057, 0.0000, 1.0162, 0.8042, 0.0000, 0.2660,
         0.0000],
        [0.9481, 0.1396, 1.0225, 0.9563, 0.5832, 0.2546, 0.4095, 0.4591, 0.0000,
         2.0053],
        [1.1300, 0.4873, 0.0000, 0.9663, 1.2275, 1.4015, 0.0000, 0.9444, 0.0000,
         0.0000],
        [0.0000, 0.8724, 1.1648, 0.6867, 0.0000, 0.2833, 0.3202, 0.5848, 0.0000,
         0.0833],
        [1.1311, 0.1324, 0.0000, 1.7842, 0.0000, 0.3474, 0.9916, 0.3571, 0.0000,
         0.0000],
        [1.4348, 1.0570, 0.1771, 0.0000, 0.9510, 0.0000, 0.0000, 0.0000, 0.2618,
         0.0000],
        [0.8853, 0.0000, 0.0000, 0.4486, 0.0000, 0.0000, 0.5841, 0.7604, 0.0000,
         0.0000]], grad_fn=<ReluBackward0>)

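Let's review some attributes of ExportedProgram that are of interest.

The graph attribute is an FX graph traced from the function we exported, that is, the computation graph of all PyTorch operations. The FX graph is in "ATen IR", meaning that it contains only "ATen-level" operations.

The graph_signature attribute gives a more detailed description of the input and output nodes in the exported graph, describing which ones are parameters, buffers, user inputs, or user outputs.

The range_constraints attribute will be covered later.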

print(exported_mod)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_lin_weight: "f32[10, 100]", p_lin_bias: "f32[10]", x: "f32[8, 100]", y: "f32[8, 100]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
            add: "f32[8, 100]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            linear: "f32[8, 10]" = torch.ops.aten.linear.default(add, p_lin_weight, p_lin_bias);  add = p_lin_weight = p_lin_bias = None
            relu_: "f32[8, 10]" = torch.ops.aten.relu_.default(linear);  linear = None
            return (relu_,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lin_weight'), target='lin.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lin_bias'), target='lin.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='relu_'), target=None)])
Range constraints: {}

See the torch.export documentation for more details.

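Graph Breaks

Although torch.export shares components with torch.compile, the key limitation of torch.export, especially when compared to torch.compile, is that it does not support graph breaks. This is because handling graph breaks involves interpreting the unsupported operation with default Python evaluation, which is incompatible with the export use case. Therefore, in order to make your model code compatible with torch.export, you will need to modify your code to remove graph breaks.

A graph break is necessary in cases such as:

  • data-dependent control flow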

class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)

import traceback as tb
try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 122, in <module>
    export(Bad1(), (torch.randn(3, 3),))
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 640, in inner
    raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.dev.org.tw/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward
    if x.sum() > 0:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
  • accessing tensor data with .data

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • calling unsupported functions (such as many built-in functions)

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 148, in <module>
    export(Bad3(), (torch.randn(3, 3),))
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1004, in call_function
    return handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 843, in builtin_dispatch
    rv = handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 772, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1936, in call_id
    return tensor_variable.call_id(tx)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/tensor.py", line 469, in call_id
    unimplemented("call_id not supported for sourceless TensorVariable")
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_id not supported for sourceless TensorVariable

from user code:
   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 145, in forward
    return x + id(x)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

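Non-Strict Export

To trace the program, torch.export by default uses TorchDynamo, a bytecode analysis engine, to symbolically analyze the Python code and build a graph based on the results. This analysis allows torch.export to provide stronger safety guarantees, but not all Python code is supported, which is what causes these graph breaks.

To address this issue, PyTorch 2.3 introduced a new export mode called non-strict mode. In this mode, the program is traced by the Python interpreter, executing exactly as it would in eager mode, which allows unsupported Python features to be skipped over. It is enabled by passing the strict=False flag.

Looking at some of the previous examples that resulted in graph breaks:

  • Calling unsupported functions (such as many built-in functions) now traces through.

However, in this case id(x) is specialized as a constant integer in the graph. This is because id(x) is not a tensor operation, so the operation is not recorded in the graph.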

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:179 in forward, code: x = x + 1
            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, 1);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:180 in forward, code: return x + id(x)
            add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, 139872432249072);  add = None
            return (add_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {}

tensor([[1.3987e+14, 1.3987e+14, 1.3987e+14],
        [1.3987e+14, 1.3987e+14, 1.3987e+14],
        [1.3987e+14, 1.3987e+14, 1.3987e+14]])

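However, there are still some features that require rewrites to the original module:

Control Flow Ops

torch.export actually does support data-dependent control flow, but it must be expressed using control flow ops. For example, we can fix the control flow example above using the cond op, like so: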

class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)
        def false_fn(x):
            return torch.cos(x)
        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed)
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:205 in forward, code: return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
            sum_1: "f32[]" = torch.ops.aten.sum.default(x)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None

             # File: /usr/local/lib/python3.10/dist-packages/torch/_higher_order_ops/cond.py:144 in cond, code: return cond_op(pred, true_fn, false_fn, operands)
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x]);  gt = true_graph_0 = false_graph_0 = x = None
            getitem: "f32[3, 3]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:202 in true_fn, code: return torch.sin(x)
                sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None
                return (sin,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:204 in false_fn, code: return torch.cos(x)
                cos: "f32[3, 3]" = torch.ops.aten.cos.default(x);  x = None
                return (cos,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

tensor([[0.8415, 0.8415, 0.8415],
        [0.8415, 0.8415, 0.8415],
        [0.8415, 0.8415, 0.8415]])
tensor([[0.5403, 0.5403, 0.5403],
        [0.5403, 0.5403, 0.5403],
        [0.5403, 0.5403, 0.5403]])

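There are some limitations to cond that one should be aware of:

  • The predicate (i.e. x.sum() > 0) must result in a boolean or a single-element tensor.

  • The operands (i.e. [x]) must be tensors.

  • The branch functions (i.e. true_fn and false_fn) must have signatures that match the operands, and both must return a single tensor with the same metadata (for example, dtype, shape, etc.).

  • Branch functions cannot mutate inputs or global variables.

  • Branch functions cannot access closure variables, except for self if the function is defined in the scope of a method.

For more details about cond, check out the cond documentation.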

We can also use map, which applies a function across the first dimension of the first tensor argument.

from torch._higher_order_ops.map import map as torch_map

class MapModule(torch.nn.Module):
    def forward(self, xs, y, z):
        def body(x, y, z):
            return x + y + z

        return torch_map(body, xs, y, z)

inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
exported_map_example = export(MapModule(), inps)
print(exported_map_example)
print(exported_map_example.module()(*inps))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[6, 4]", y: "i64[]", z: "i64[]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:236 in forward, code: return torch_map(body, xs, y, z)
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y, z]);  body_graph_0 = xs = y = z = None
            getitem: "f32[6, 4]" = map_impl[0];  map_impl = None
            return (getitem,)

        class body_graph_0(torch.nn.Module):
            def forward(self, xs: "f32[4]", y: "i64[]", z: "i64[]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:234 in body, code: return x + y + z
                add: "f32[4]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None
                add_1: "f32[4]" = torch.ops.aten.add.Tensor(add, z);  add = z = None
                return (add_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='xs'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

tensor([[10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.]])

Other control flow ops include while_loop, associative_scan, and scan. For more details, please refer to the documentation for each operator.
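
As a mental model for map, its result should match an eager-mode loop over the first dimension followed by a stack. The helper map_reference below is a hypothetical reference written for this illustration; it is not the implementation used by PyTorch.

```python
import torch


def map_reference(body, xs, *args):
    """Eager-mode sketch: apply body to each slice xs[i] and stack the results."""
    return torch.stack([body(x, *args) for x in xs])


def body(x, y, z):
    return x + y + z


xs = torch.ones(6, 4)
result = map_reference(body, xs, torch.tensor(5), torch.tensor(4))
print(result.shape)
print(result[0])
```

Each row of xs is combined with the same y and z, which is why every entry in the exported example above equals 10.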

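Constraints/Dynamic Shapes

This section covers the dynamic behavior and representation of exported programs. Dynamic behavior depends on the particular model being exported, so for most of this tutorial we will focus on the following toy model (with the resulting tensor shapes annotated):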

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [6, 5]
        x: torch.Tensor,  # [4]
        y: torch.Tensor,  # [8, 4]
        z: torch.Tensor,  # [32]
    ):
        x0 = x + y  # [8, 4]
        x1 = self.l(w)  # [6, 3]
        x2 = x0.flatten()  # [32]
        x3 = x2 + z  # [32]
        return x1, x3

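By default, torch.export produces a static program. One consequence of this is that at runtime, the program will not work on inputs with different shapes, even if they are valid in eager mode.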

w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)
model = DynamicModel()
ep = export(model, (w, x, y, z))
model(w, x, torch.randn(3, 4), torch.randn(12))
try:
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 286, in <module>
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 822, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 400, in __call__
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 387, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1772, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_unlift.py", line 49, in _check_input_constraints_pre_hook
    _check_input_constraints_for_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/utils.py", line 360, in _check_input_constraints_for_graph
    raise RuntimeError(
RuntimeError: Expected input at *args[2].shape[0] to be equal to 8, but got 3

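Basic concepts: symbols and guards

To enable dynamism, export() provides a dynamic_shapes argument. The easiest way to work with dynamic shapes is to use Dim.AUTO and look at the program that is returned. Dynamic behavior is specified at the input-dimension level; for each input we can specify a tuple of values: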

from torch.export.dynamic_shapes import Dim

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

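Before we look at the program that is produced, let's understand what specifying dynamic_shapes entails, and how it interacts with export. For every input dimension where a Dim object is specified, a symbol is allocated, taking on a range of [2, inf] (why not [0, inf] or [1, inf]? We will explain later in the 0/1 specialization section).

Export then runs model tracing, looking at each operation performed by the model. Each individual operation can emit what are called "guards": boolean conditions that must hold for the program to be valid. When guards involve symbols allocated for input dimensions, the program contains restrictions on which input shapes are valid; this is the program's dynamic behavior. The symbolic shapes subsystem is responsible for taking in all the emitted guards and producing a final program representation that adheres to all of them. Before we see this "final representation" in an ExportedProgram, let's look at the guards emitted by the toy model we are tracing.

Here, each forward input tensor is annotated with the symbol allocated at the start of tracing: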

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [s0, s1]
        x: torch.Tensor,  # [s2]
        y: torch.Tensor,  # [s3, s4]
        z: torch.Tensor,  # [s5]
    ):
        x0 = x + y  # guard: s2 == s4
        x1 = self.l(w)  # guard: s1 == 5
        x2 = x0.flatten()  # no guard added here
        x3 = x2 + z  # guard: s3 * s4 == s5
        return x1, x3

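Let's understand each of the operations and the emitted guards:

  • x0 = x + y: This is an element-wise add with broadcasting, since x is a 1-d tensor and y is a 2-d tensor. x is broadcast along the last dimension of y, emitting the guard s2 == s4.

  • x1 = self.l(w): Calling nn.Linear() performs a matrix multiplication with model parameters. In export, parameters, buffers, and constants are considered program state, and program state is considered static, so this is a matmul between a dynamic input (w: [s0, s1]) and a statically-shaped tensor. This emits the guard s1 == 5.

  • x2 = x0.flatten(): This call actually does not emit any guards! (At least none relevant to input shapes.)

  • x3 = x2 + z: x2 has shape [s3*s4] after flattening, and this element-wise add emits s3 * s4 == s5.

Writing all of these guards down and summarizing them is almost like a mathematical proof, which is exactly what the symbolic shapes subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid: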

  • w: [s0, 5]

  • x: [s2]

  • y: [s3, s2]

  • z: [s2*s3]
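
The three guards can be written down as an ordinary predicate over input shapes. The function below is a hand-written sketch for illustration only (satisfies_guards is a hypothetical name); it does not model the [2, inf] symbol ranges, only the equalities derived above.

```python
def satisfies_guards(w_shape, x_shape, y_shape, z_shape):
    """Check the three equality guards emitted while tracing the toy model."""
    _s0, s1 = w_shape  # s0 is unconstrained by the guards
    (s2,) = x_shape
    s3, s4 = y_shape
    (s5,) = z_shape
    return s1 == 5 and s2 == s4 and s3 * s4 == s5


# The example inputs used for tracing.
print(satisfies_guards((6, 5), (4,), (8, 4), (32,)))
# Different sizes that preserve the same relations.
print(satisfies_guards((10, 5), (3,), (7, 3), (21,)))
# Violates s1 == 5.
print(satisfies_guards((6, 4), (4,), (8, 4), (32,)))
```

Substituting s4 = s2 and s5 = s2*s3 into the input shapes yields exactly the four simplified shapes listed above.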

And when we finally print out the exported program to see the result, these shapes are what we see annotated on the corresponding inputs:

print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_l_weight: "f32[3, 5]", p_l_bias: "f32[3]", w: "f32[s0, 5]", x: "f32[s2]", y: "f32[s3, s2]", z: "f32[s2*s3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward, code: x0 = x + y  # [8, 4]
            add: "f32[s3, s2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward, code: x1 = self.l(w)  # [6, 3]
            linear: "f32[s0, 3]" = torch.ops.aten.linear.default(w, p_l_weight, p_l_bias);  w = p_l_weight = p_l_bias = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:270 in forward, code: x2 = x0.flatten()  # [32]
            flatten: "f32[s2*s3]" = torch.ops.aten.flatten.using_ints(add);  add = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward, code: x3 = x2 + z  # [32]
            add_1: "f32[s2*s3]" = torch.ops.aten.add.Tensor(flatten, z);  flatten = z = None
            return (linear, add_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_l_weight'), target='l.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_l_bias'), target='l.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='w'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {s0: VR[2, int_oo], s2: VR[2, int_oo], s3: VR[2, int_oo], s2*s3: VR[4, int_oo]}

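Another feature to notice is the range_constraints field above, which contains a valid range for each symbol. This is not very interesting yet, since this export call does not emit any guards related to symbol bounds and each base symbol has a generic bound, but it will come up later.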

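So far, because we have been exporting this toy model, the experience has not been representative of how hard it typically is to debug dynamic-shape guards and issues. In most cases it is not obvious which guards are being emitted, or which operations and parts of user code are responsible. For this toy model we can pinpoint the exact lines of code, and the guards are rather intuitive.

In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment variable TORCH_LOGS="+dynamic", or interactively with torch._logging.set_logs(dynamic=10):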

torch._logging.set_logs(dynamic=10)
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
I0203 16:55:11.092000 634 torch/fx/experimental/symbolic_shapes.py:3192] [12/0] create_env
I0203 16:55:11.094000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.094000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.095000 634 torch/fx/experimental/symbolic_shapes.py:6614] [12/0] runtime_assert True == True [statically known]
I0203 16:55:11.098000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.100000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.100000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.103000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.105000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s2, 1) == False [statically known]
V0203 16:55:11.106000 634 torch/fx/experimental/symbolic_shapes.py:6614] [12/0] runtime_assert True == True [statically known]
V0203 16:55:11.107000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s4, 1) == False [statically known]
I0203 16:55:11.108000 634 torch/fx/experimental/symbolic_shapes.py:5963] [12/0] set_replacement s4 = s2 (solve) VR[2, int_oo]
I0203 16:55:11.109000 634 torch/fx/experimental/symbolic_shapes.py:6281] [12/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # [8, 4]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
V0203 16:55:11.110000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Ne(s2, 1) == True [statically known]
V0203 16:55:11.111000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Ne(s3, 1) == True [statically known]
V0203 16:55:11.118000 634 torch/fx/experimental/symbolic_shapes.py:5802] [12/0] _update_var_to_range s1 = VR[5, 5] (update)
I0203 16:55:11.118000 634 torch/fx/experimental/symbolic_shapes.py:5963] [12/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0203 16:55:11.119000 634 torch/fx/experimental/symbolic_shapes.py:6281] [12/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0203 16:55:11.120000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s0, 1) == False [statically known]
V0203 16:55:11.126000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s2*s3, 1) == False [statically known]
V0203 16:55:11.127000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s5, 1) == False [statically known]
V0203 16:55:11.128000 634 torch/fx/experimental/symbolic_shapes.py:5802] [12/0] _update_var_to_range s5 = VR[4, int_oo] (update)
I0203 16:55:11.129000 634 torch/fx/experimental/symbolic_shapes.py:5963] [12/0] set_replacement s5 = s2*s3 (solve) VR[4, int_oo]
I0203 16:55:11.130000 634 torch/fx/experimental/symbolic_shapes.py:6281] [12/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0203 16:55:11.131000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Ne(s2*s3, 1) == True [statically known]
I0203 16:55:11.138000 634 torch/fx/experimental/symbolic_shapes.py:4547] [12/0] produce_guards
V0203 16:55:11.138000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].size()[0] s0 None
V0203 16:55:11.139000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].size()[1] 5 None
V0203 16:55:11.139000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].stride()[0] 5 None
V0203 16:55:11.139000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].stride()[1] 1 None
V0203 16:55:11.140000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].storage_offset() 0 None
V0203 16:55:11.140000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['x'].size()[0] s2 None
V0203 16:55:11.141000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['x'].stride()[0] 1 None
V0203 16:55:11.141000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.141000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].size()[0] s3 None
V0203 16:55:11.142000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].size()[1] s2 None
V0203 16:55:11.142000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].stride()[0] s2 None
V0203 16:55:11.142000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].stride()[1] 1 None
V0203 16:55:11.143000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].storage_offset() 0 None
V0203 16:55:11.143000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['z'].size()[0] s2*s3 None
V0203 16:55:11.143000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['z'].stride()[0] 1 None
V0203 16:55:11.144000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['z'].storage_offset() 0 None
V0203 16:55:11.179000 634 torch/fx/experimental/symbolic_shapes.py:6412] eval Ne(s0, 1) == True [statically known]

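Even for this simple toy model, the export emits a lot of information. The log lines here have been truncated at the start and end to omit unnecessary detail, but looking through them we can find the lines that correspond to what we described above; for example, the allocation of symbols: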

"""
create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
runtime_assert True == True [statically known]
create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
"""

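Lines containing create_symbol show when a new symbol was allocated, and the log also identifies the tensor variable name and dimension it was allocated for. In other lines we can see the guards that were emitted: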

"""
runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # output shape: [8, 4]  # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
"""

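Next to each [guard added] message we also see the line of user code responsible for it; luckily, the model here is simple enough for that to be obvious. In many real-world cases it is not so straightforward: high-level torch operations can have complicated fake-kernel implementations or operator decompositions, which obscure where a guard is emitted and what it contains. In such cases, the best way to dig deeper is to follow the log's suggestion and re-run with the environment variable TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..." to further attribute the guard of interest.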
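As a rough sketch of that workflow, the snippet below pulls the guard expressions out of a captured log and builds the environment for a re-run. The regular expression is modelled on the log excerpts shown above and is only an assumption about the format, which may differ between PyTorch versions. The names guards_in_log and debug_env, and the script name in the final comment, are hypothetical helpers for illustration, not part of PyTorch.

```python
import os
import re

# Pattern modelled on the "[guard added]" lines shown above (assumed format).
GUARD_ADDED = re.compile(
    r"runtime_assert (?P<guard>.+?) \[guard added\] (?P<code>.*?)\s+#"
)


def guards_in_log(text):
    """Return (guard expression, user code) pairs found in a captured log."""
    found = []
    for line in text.splitlines():
        match = GUARD_ADDED.search(line)
        if match:
            found.append((match["guard"], match["code"].strip()))
    return found


def debug_env(guard, base=None):
    """Copy an environment and add the variables the log message suggests."""
    env = dict(os.environ if base is None else base)
    env["TORCH_LOGS"] = "+dynamic"
    env["TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED"] = guard
    return env


# Abbreviated copies of two guard lines from the log above.
sample = """
runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # output shape: [8, 4]  # dynamic_shapes_tutorial.py:16 in forward
runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # dynamic_shapes_tutorial.py:17 in forward
"""

for guard, code in guards_in_log(sample):
    env = debug_env(guard, base={})
    print(f"{code!r} emitted {guard!r}; re-run with {env}")

# To re-run an export script with one of these environments (hypothetical
# file name), something like the following could be used:
#   subprocess.run([sys.executable, "my_export_script.py"], env=debug_env("Eq(s1, 5)"))
```

Launching the export script in a fresh interpreter with this environment keeps the setting independent of when the variable is read inside the running process.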

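Dim.AUTO is just one of the available options for interacting with dynamic_shapes; at the time of writing, two other options are available: Dim.DYNAMIC and Dim.STATIC. Dim.STATIC simply marks a dimension as static, while Dim.DYNAMIC behaves like Dim.AUTO in every way except one: it raises an error when the dimension is specialized to a constant, which is intended to preserve dynamism. See, for example, what happens when a static guard is emitted on a dimension marked as dynamic: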

dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I0203 16:55:11.200000 634 torch/fx/experimental/symbolic_shapes.py:3192] [13/0] create_env
I0203 16:55:11.202000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.203000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.204000 634 torch/fx/experimental/symbolic_shapes.py:6614] [13/0] runtime_assert True == True [statically known]
I0203 16:55:11.206000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.208000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.208000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.211000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.213000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s2, 1) == False [statically known]
V0203 16:55:11.214000 634 torch/fx/experimental/symbolic_shapes.py:6614] [13/0] runtime_assert True == True [statically known]
V0203 16:55:11.214000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s4, 1) == False [statically known]
I0203 16:55:11.216000 634 torch/fx/experimental/symbolic_shapes.py:5963] [13/0] set_replacement s4 = s2 (solve) VR[2, int_oo]
I0203 16:55:11.217000 634 torch/fx/experimental/symbolic_shapes.py:6281] [13/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # [8, 4]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
V0203 16:55:11.218000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Ne(s2, 1) == True [statically known]
V0203 16:55:11.219000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Ne(s3, 1) == True [statically known]
V0203 16:55:11.226000 634 torch/fx/experimental/symbolic_shapes.py:5802] [13/0] _update_var_to_range s1 = VR[5, 5] (update)
I0203 16:55:11.226000 634 torch/fx/experimental/symbolic_shapes.py:5963] [13/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0203 16:55:11.227000 634 torch/fx/experimental/symbolic_shapes.py:6281] [13/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0203 16:55:11.228000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s0, 1) == False [statically known]
V0203 16:55:11.234000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s2*s3, 1) == False [statically known]
V0203 16:55:11.235000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s5, 1) == False [statically known]
V0203 16:55:11.236000 634 torch/fx/experimental/symbolic_shapes.py:5802] [13/0] _update_var_to_range s5 = VR[4, int_oo] (update)
I0203 16:55:11.237000 634 torch/fx/experimental/symbolic_shapes.py:5963] [13/0] set_replacement s5 = s2*s3 (solve) VR[4, int_oo]
I0203 16:55:11.238000 634 torch/fx/experimental/symbolic_shapes.py:6281] [13/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0203 16:55:11.239000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Ne(s2*s3, 1) == True [statically known]
I0203 16:55:11.246000 634 torch/fx/experimental/symbolic_shapes.py:4547] [13/0] produce_guards
V0203 16:55:11.246000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].size()[0] s0 None
V0203 16:55:11.246000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].size()[1] 5 RelaxedUnspecConstraint(warn_only=False)
V0203 16:55:11.247000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].stride()[0] 5 None
V0203 16:55:11.247000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].stride()[1] 1 None
V0203 16:55:11.247000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].storage_offset() 0 None
V0203 16:55:11.248000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['x'].size()[0] s2 None
V0203 16:55:11.248000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['x'].stride()[0] 1 None
V0203 16:55:11.248000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.249000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].size()[0] s3 None
V0203 16:55:11.249000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].size()[1] s2 None
V0203 16:55:11.249000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].stride()[0] s2 None
V0203 16:55:11.250000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].stride()[1] 1 None
V0203 16:55:11.250000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].storage_offset() 0 None
V0203 16:55:11.250000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['z'].size()[0] s2*s3 None
V0203 16:55:11.251000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['z'].stride()[0] 1 None
V0203 16:55:11.251000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['z'].storage_offset() 0 None
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Error while creating guard:
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Name: ''
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]     Source: shape_env
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]     Create Function: SHAPE_ENV
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]     Guard Types: None
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]     Code List: None
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]     Object Weakref: None
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]     Guarded Class Weakref: None
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Traceback (most recent call last):
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]     return self.create_fn(builder, self)
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]     code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]     raise ConstraintViolationError(
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0]   - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0] Created at:
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 642, in transform
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0]     tracer = InstructionTranslator(
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2711, in __init__
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0]     output=OutputGraph(
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 336, in __init__
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0]     self.init_ambient_guards()
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 485, in init_ambient_guards
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1614, in inner
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 852, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2303, in __init__
    guard.create(builder)
  File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
    return self.create_fn(builder, self)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
    code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 418, in <module>
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 679, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).
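When an export fails this way on a larger model, it can help to extract which dimensions were specialized and to what value. The following is a minimal sketch that parses the message text shown above; the pattern is an assumption based on that output and may not match other PyTorch versions, and specialized_dims is a hypothetical helper name. In practice the message would come from str(exc) inside the except block.

```python
import re

# Pattern modelled on the "Not all values of ..." line in the error above
# (assumed format; the wording may change between PyTorch versions).
VIOLATION = re.compile(
    r"Not all values of \w+\((?P<source>.+?)\) are valid because "
    r"(?P=source) was inferred to be a constant \((?P<value>-?\d+)\)"
)


def specialized_dims(message):
    """Return (dimension source, constant) pairs named in the error message."""
    return [(m["source"], int(m["value"])) for m in VIOLATION.finditer(message)]


# Message text copied from the traceback above.
message = (
    "Constraints violated (L['w'].size()[1])! "
    'For more information, run with TORCH_LOGS="+dynamic".\n'
    "  - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid "
    "because L['w'].size()[1] was inferred to be a constant (5)."
)

for source, value in specialized_dims(message):
    print(f"{source} was specialized to {value}")
```

Each dimension reported this way was marked Dim.DYNAMIC but specialized during tracing; marking it Dim.AUTO or Dim.STATIC instead accepts the specialization rather than raising.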

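Static guards are not always inherent to the model; they can also come from user specifications. In fact, a common pitfall that leads to shape specialization is specifying conflicting markers for dimensions that must be equal, one dynamic and the other static. The same type of error is raised when this is the case for x.shape[0] and y.shape[1]: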

dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
dynamic_shapes["x"] = (Dim.STATIC,)
dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I0203 16:55:11.273000 634 torch/fx/experimental/symbolic_shapes.py:3192] [14/0] create_env
I0203 16:55:11.276000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.276000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.277000 634 torch/fx/experimental/symbolic_shapes.py:6614] [14/0] runtime_assert True == True [statically known]
I0203 16:55:11.280000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s2 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.280000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s3 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.283000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s4 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.286000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Eq(s3, 1) == False [statically known]
V0203 16:55:11.290000 634 torch/fx/experimental/symbolic_shapes.py:5802] [14/0] _update_var_to_range s3 = VR[4, 4] (update)
I0203 16:55:11.291000 634 torch/fx/experimental/symbolic_shapes.py:5963] [14/0] set_replacement s3 = 4 (range_refined_to_singleton) VR[4, 4]
I0203 16:55:11.291000 634 torch/fx/experimental/symbolic_shapes.py:6281] [14/0] runtime_assert Eq(s3, 4) [guard added] x0 = x + y  # [8, 4]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s3, 4)"
V0203 16:55:11.293000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Ne(s2, 1) == True [statically known]
V0203 16:55:11.299000 634 torch/fx/experimental/symbolic_shapes.py:5802] [14/0] _update_var_to_range s1 = VR[5, 5] (update)
I0203 16:55:11.300000 634 torch/fx/experimental/symbolic_shapes.py:5963] [14/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0203 16:55:11.300000 634 torch/fx/experimental/symbolic_shapes.py:6281] [14/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0203 16:55:11.302000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Eq(s0, 1) == False [statically known]
V0203 16:55:11.302000 634 torch/fx/experimental/symbolic_shapes.py:6614] [14/0] runtime_assert True == True [statically known]
V0203 16:55:11.309000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Eq(s4, 1) == False [statically known]
V0203 16:55:11.314000 634 torch/fx/experimental/symbolic_shapes.py:5802] [14/0] _update_var_to_range s4 = VR[8, int_oo] (update)
I0203 16:55:11.317000 634 torch/fx/experimental/symbolic_shapes.py:5963] [14/0] set_replacement s4 = 4*s2 (solve) VR[8, int_oo]
I0203 16:55:11.317000 634 torch/fx/experimental/symbolic_shapes.py:6281] [14/0] runtime_assert Eq(4*s2, s4) [guard added] x3 = x2 + z  # [32]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(4*s2, s4)"
I0203 16:55:11.324000 634 torch/fx/experimental/symbolic_shapes.py:4547] [14/0] produce_guards
V0203 16:55:11.324000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].size()[0] s0 None
V0203 16:55:11.325000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].size()[1] 5 None
V0203 16:55:11.325000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].stride()[0] 5 None
V0203 16:55:11.325000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].stride()[1] 1 None
V0203 16:55:11.326000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].storage_offset() 0 None
V0203 16:55:11.326000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['x'].size()[0] 4 None
V0203 16:55:11.326000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['x'].stride()[0] 1 None
V0203 16:55:11.327000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.327000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].size()[0] s2 None
V0203 16:55:11.327000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].size()[1] 4 RelaxedUnspecConstraint(warn_only=False)
V0203 16:55:11.327000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].stride()[0] 4 None
V0203 16:55:11.328000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].stride()[1] 1 None
V0203 16:55:11.328000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].storage_offset() 0 None
V0203 16:55:11.328000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['z'].size()[0] 4*s2 None
V0203 16:55:11.329000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['z'].stride()[0] 1 None
V0203 16:55:11.329000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['z'].storage_offset() 0 None
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Error while creating guard:
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Name: ''
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]     Source: shape_env
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]     Create Function: SHAPE_ENV
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]     Guard Types: None
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]     Code List: None
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]     Object Weakref: None
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]     Guarded Class Weakref: None
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Traceback (most recent call last):
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]     return self.create_fn(builder, self)
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]     code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]     raise ConstraintViolationError(
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0]   - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0] Created at:
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 642, in transform
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0]     tracer = InstructionTranslator(
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2711, in __init__
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0]     output=OutputGraph(
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 336, in __init__
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0]     self.init_ambient_guards()
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 485, in init_ambient_guards
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1614, in inner
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 852, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2303, in __init__
    guard.create(builder)
  File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
    return self.create_fn(builder, self)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
    code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 431, in <module>
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 679, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).

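Here you might ask why export "specializes", that is, why it resolves this static/dynamic conflict by taking the static route. The answer lies in the symbolic shapes system described above, with its symbols and guards. When x.shape[0] is marked static, no symbol is allocated for it, and compilation treats this shape as the concrete integer 4. A symbol is allocated for y.shape[1], so we end up emitting the guard s3 == 4, which leads to specialization.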

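One feature of export is that, during tracing, statements such as asserts, torch._check(), and if/else conditions also emit guards. See what happens when we augment the existing model with such statements: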

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        assert w.shape[0] <= 512
        torch._check(x.shape[0] >= 4)
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
            x2 = x0.flatten()
            x3 = x2 + z
            return x1, x3
        else:
            return w

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
try:
    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I0203 16:55:11.350000 634 torch/fx/experimental/symbolic_shapes.py:3192] [15/0] create_env
I0203 16:55:11.352000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.353000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.354000 634 torch/fx/experimental/symbolic_shapes.py:6614] [15/0] runtime_assert True == True [statically known]
I0203 16:55:11.356000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.358000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.358000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.361000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.367000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s0 = VR[2, 512] (update)
I0203 16:55:11.368000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert s0 <= 512 [guard added] assert w.shape[0] <= 512  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:450 in forward (_dynamo/symbolic_convert.py:522 in inner), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s0 <= 512"
V0203 16:55:11.372000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s2 = VR[4, int_oo] (update)
I0203 16:55:11.373000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert s2 >= 4 [guard added] torch._check(x.shape[0] >= 4)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:451 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s2 >= 4"
V0203 16:55:11.379000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s0 = VR[6, 512] (update)
V0203 16:55:11.382000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s2 = VR[4, 510] (update)
I0203 16:55:11.382000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s0 = s2 + 2 (solve) VR[6, 512]
I0203 16:55:11.383000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] eval Eq(s0, s2 + 2) [guard added] if w.shape[0] == x.shape[0] + 2:  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:452 in forward (_dynamo/variables/tensor.py:1201 in evaluate_expr), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, s2 + 2)"
V0203 16:55:11.384000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s2, 1) == False [statically known]
V0203 16:55:11.385000 634 torch/fx/experimental/symbolic_shapes.py:6614] [15/0] runtime_assert True == True [statically known]
V0203 16:55:11.386000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s4, 1) == False [statically known]
V0203 16:55:11.388000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s4 = VR[4, 510] (update)
I0203 16:55:11.389000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s4 = s2 (solve) VR[4, 510]
I0203 16:55:11.390000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:453 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
V0203 16:55:11.391000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Ne(s2, 1) == True [statically known]
V0203 16:55:11.392000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Ne(s3, 1) == True [statically known]
V0203 16:55:11.399000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s1 = VR[5, 5] (update)
I0203 16:55:11.399000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0203 16:55:11.400000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:454 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0203 16:55:11.409000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s2*s3, 1) == False [statically known]
V0203 16:55:11.410000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s5, 1) == False [statically known]
V0203 16:55:11.419000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s5 = VR[8, int_oo] (update)
I0203 16:55:11.420000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s5 = s2*s3 (solve) VR[8, int_oo]
I0203 16:55:11.421000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:456 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0203 16:55:11.422000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Ne(s2*s3, 1) == True [statically known]
V0203 16:55:11.426000 634 torch/fx/experimental/symbolic_shapes.py:6614] [15/0] runtime_assert s2 >= 4 == True [statically known]
I0203 16:55:11.432000 634 torch/fx/experimental/symbolic_shapes.py:4547] [15/0] produce_guards
V0203 16:55:11.432000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].size()[0] s2 + 2 None
V0203 16:55:11.433000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].size()[1] 5 None
V0203 16:55:11.433000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].stride()[0] 5 None
V0203 16:55:11.433000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].stride()[1] 1 None
V0203 16:55:11.434000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].storage_offset() 0 None
V0203 16:55:11.434000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['x'].size()[0] s2 None
V0203 16:55:11.434000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['x'].stride()[0] 1 None
V0203 16:55:11.434000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.435000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].size()[0] s3 None
V0203 16:55:11.435000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].size()[1] s2 None
V0203 16:55:11.435000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].stride()[0] s2 None
V0203 16:55:11.436000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].stride()[1] 1 None
V0203 16:55:11.436000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].storage_offset() 0 None
V0203 16:55:11.436000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['z'].size()[0] s2*s3 None
V0203 16:55:11.436000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['z'].stride()[0] 1 None
V0203 16:55:11.437000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['z'].storage_offset() 0 None

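Each of these statements emits an additional guard, and the exported program reflects the changes: s0 is eliminated in favor of s2 + 2, and s2 now has both a lower and an upper bound, which shows up in range_constraints.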

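For the if/else condition, you might ask why the True branch was taken, and why the guard emitted from tracing was not w.shape[0] != x.shape[0] + 2. The answer is that export is guided by the example inputs provided for tracing, and specializes on the branches taken. If example inputs with shapes that fail the if condition had been provided, export would have traced the else branch and emitted the corresponding guards. You might also ask why only the if branch was traced, and whether it is possible to preserve control flow in the program and keep both branches alive. For that, rewrite your model code following the Control Flow Ops section above.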

0/1 specialization

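Since we are discussing guards and specialization, this is a good time to return to the 0/1 specialization issue raised earlier. The bottom line is that export specializes on example input dimensions with value 0 or 1, because these shapes have trace-time properties that do not generalize to other shapes. For example, tensors of size 1 can broadcast, whereas other sizes fail; and tensors of size 0 ... . This simply means that you should provide 0/1 example inputs when you want your program to hardcode them, and non-0/1 example inputs when dynamic behavior is desired. See what happens at runtime when we export this linear layer: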

ep = export(
    torch.nn.Linear(4, 3),
    (torch.randn(1, 4),),
    dynamic_shapes={
        "input": (Dim.AUTO, Dim.STATIC),
    },
)
try:
    ep.module()(torch.randn(2, 4))
except Exception:
    tb.print_exc()
I0203 16:55:11.502000 634 torch/fx/experimental/symbolic_shapes.py:3192] [3/1] create_env
I0203 16:55:11.516000 634 torch/fx/experimental/symbolic_shapes.py:4547] [3/1] produce_guards
V0203 16:55:11.516000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].size()[0] 1 None
V0203 16:55:11.517000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].size()[1] 4 None
V0203 16:55:11.517000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].stride()[0] 4 None
V0203 16:55:11.517000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].stride()[1] 1 None
V0203 16:55:11.517000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].storage_offset() 0 None
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 500, in <module>
    ep.module()(torch.randn(2, 4))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 822, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 400, in __call__
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 387, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1772, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_unlift.py", line 49, in _check_input_constraints_pre_hook
    _check_input_constraints_for_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/utils.py", line 360, in _check_input_constraints_for_graph
    raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 1, but got 2

Named Dims

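So far we have discussed only three ways to specify dynamic shapes: Dim.AUTO, Dim.DYNAMIC, and Dim.STATIC. Their appeal is a low-friction user experience: all guards emitted during model tracing are adhered to, and dynamic behavior such as min/max ranges, relations, and static/dynamic dimensions is worked out automatically during export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards and presenting what export believes is the overall dynamic behavior of the program. The drawback of this design appears once the user has stronger expectations or beliefs about the dynamic behavior of a model. Perhaps dynamism is strongly desired and specialization on particular dimensions must be avoided at all costs, or perhaps we simply want to catch changes in dynamic behavior caused by edits to the original model code or to the underlying decompositions or meta-kernels. Such changes will not be detected, and the export() call will most likely succeed, unless tests are in place that check the resulting ExportedProgram representation.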

For such cases, we recommend the "traditional" way of specifying dynamic shapes, which longer-term users of export may already know: named Dims.

dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}

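This style of dynamic shapes lets the user specify which symbols are allocated for input dimensions and the min/max bounds on those symbols, and it places restrictions on the dynamic behavior of the resulting ExportedProgram. A ConstraintViolation error is raised if model tracing emits guards that conflict with the given relations or static/dynamic specifications. For example, the specification above asserts the following: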

  • x.shape[0] has range [4, 256], and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].

  • x.shape[1] is static.

  • y.shape[1] has range [2, 512], and is unrelated to any other dimension.

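In this design, relations between dimensions can be specified with univariate linear expressions: A * dim + B can be specified for any dimension. This lets users express more complex constraints on dynamic dimensions, such as integer divisibility: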

dx = Dim("dx", min=4, max=512)
dynamic_shapes = {
    "x": (4 * dx, None)  # x.shape[0] has range [16, 2048], and is divisible by 4.
}

Constraint violations, suggested fixes

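A common issue with this specification style (before Dim.AUTO was introduced) is that the specification often does not match what model tracing produces. This leads to ConstraintViolation errors and fixes suggested by export. Consider, for example, the following model and specification, where the model inherently requires dimension 0 of x and y to be equal and requires dimension 1 to be static: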

class Foo(torch.nn.Module):
    def forward(self, x, y):
        w = x + y
        return w + torch.ones(4)

dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
try:
    ep = export(
        Foo(),
        (torch.randn(6, 4), torch.randn(6, 4)),
        dynamic_shapes={
            "x": (dx, d1),
            "y": (dy, d1),
        },
    )
except Exception:
    tb.print_exc()
I0203 16:55:11.662000 634 torch/fx/experimental/symbolic_shapes.py:3192] [16/0] create_env
I0203 16:55:11.665000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s0 = 6 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.665000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s1 = 4 for L['x'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.666000 634 torch/fx/experimental/symbolic_shapes.py:6614] [16/0] runtime_assert True == True [statically known]
I0203 16:55:11.668000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s2 = 6 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.669000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s3 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.673000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s1, 1) == False [statically known]
V0203 16:55:11.673000 634 torch/fx/experimental/symbolic_shapes.py:6614] [16/0] runtime_assert True == True [statically known]
V0203 16:55:11.674000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s0, 1) == False [statically known]
V0203 16:55:11.675000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s3, 1) == False [statically known]
I0203 16:55:11.677000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s3 = s1 (solve) VR[2, int_oo]
I0203 16:55:11.678000 634 torch/fx/experimental/symbolic_shapes.py:6281] [16/0] runtime_assert Eq(s1, s3) [guard added] w = x + y  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:552 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, s3)"
V0203 16:55:11.679000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s2, 1) == False [statically known]
I0203 16:55:11.681000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s2 = s0 (solve) VR[2, int_oo]
I0203 16:55:11.681000 634 torch/fx/experimental/symbolic_shapes.py:6281] [16/0] runtime_assert Eq(s0, s2) [guard added] w = x + y  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:552 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, s2)"
V0203 16:55:11.683000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Ne(s1, 1) == True [statically known]
V0203 16:55:11.683000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Ne(s0, 1) == True [statically known]
V0203 16:55:11.690000 634 torch/fx/experimental/symbolic_shapes.py:5802] [16/0] _update_var_to_range s1 = VR[4, 4] (update)
I0203 16:55:11.691000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s1 = 4 (range_refined_to_singleton) VR[4, 4]
I0203 16:55:11.691000 634 torch/fx/experimental/symbolic_shapes.py:6281] [16/0] runtime_assert Eq(s1, 4) [guard added] return w + torch.ones(4)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:553 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 4)"
V0203 16:55:11.694000 634 torch/fx/experimental/symbolic_shapes.py:5802] [16/0] _update_var_to_range s3 = VR[4, 4] (update)
I0203 16:55:11.695000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s3 = 4 (find) VR[4, 4]
I0203 16:55:11.698000 634 torch/fx/experimental/symbolic_shapes.py:4547] [16/0] produce_guards
V0203 16:55:11.698000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0203 16:55:11.698000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0203 16:55:11.699000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].stride()[0] 4 None
V0203 16:55:11.699000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].stride()[1] 1 None
V0203 16:55:11.699000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.699000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0203 16:55:11.700000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0203 16:55:11.700000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].stride()[0] 4 None
V0203 16:55:11.700000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].stride()[1] 1 None
V0203 16:55:11.700000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].storage_offset() 0 None
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Error while creating guard:
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Name: ''
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]     Source: shape_env
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]     Create Function: SHAPE_ENV
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]     Guard Types: None
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]     Code List: None
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]     Object Weakref: None
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]     Guarded Class Weakref: None
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Traceback (most recent call last):
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]     return self.create_fn(builder, self)
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]     code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]     raise ConstraintViolationError(
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]   - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]   - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0]   - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0] Created at:
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 642, in transform
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0]     tracer = InstructionTranslator(
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2711, in __init__
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0]     output=OutputGraph(
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 336, in __init__
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0]     self.init_ambient_guards()
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 485, in init_ambient_guards
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1614, in inner
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 852, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2303, in __init__
    guard.create(builder)
  File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
    return self.create_fn(builder, self)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
    code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.

Suggested fixes:
  d1 = 4
  dy = dx

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 557, in <module>
    ep = export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 679, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.

Suggested fixes:
  d1 = 4
  dy = dx

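The intent behind the suggested fixes is that you can copy and paste them directly into your dynamic shapes specification and then export successfully.

Finally, a few things are worth knowing about the specification options:

  • None is a good option for static behavior:
      - dynamic_shapes=None (the default) exports with the entire model being static.
      - Specifying None at the input level exports with all dimensions of that tensor static, and is also required for non-tensor inputs.
      - Specifying None at the dimension level specializes that dimension, though this is deprecated in favor of Dim.STATIC.

  • Specifying an integer value per dimension also produces static behavior, and additionally checks that the provided sample input matches the specification.

These options are combined in the inputs and dynamic shapes specification below: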

inputs = (
    torch.randn(4, 4),
    torch.randn(3, 3),
    16,
    False,
)
dynamic_shapes = {
    "tensor_0": (Dim.AUTO, None),
    "tensor_1": None,
    "int_val": None,
    "bool_val": None,
}

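Data-dependent errors

While trying to export models, you may have encountered errors like "Could not guard on data-dependent expression" or "Could not extract specialized integer from data-dependent expression". These errors exist because torch.export() compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. FakeTensors have equivalent symbolic properties (e.g. sizes, strides, dtypes), but they do not contain any data values. This avoids unnecessary memory usage and expensive computation, but it also means that export may be unable to compile, out of the box, the parts of user code that rely on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out and complain that the value is not available.

Data-dependent values appear in many places. Common sources are calls like item(), tolist(), or torch.unbind() that extract scalar values from tensors. How are these values represented in the exported program? In the Constraints/Dynamic Shapes section, we talked about allocating symbols to represent dynamic input dimensions. The same happens here: we allocate a symbol for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols, in contrast to the "backed" symbols allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence or absence of a "hint" for the symbol: a concrete value backing the symbol that can inform the compiler on how to proceed.

In the case of input shape symbols (backed symbols), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the properties of the sample inputs. For data-dependent values, the symbols are taken from FakeTensor "data" during tracing, so the compiler does not know the actual values (hints) that these symbols will take on.

Let's see how these show up in exported programs: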

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.tolist()
        return b + [a]

inps = (
    torch.tensor(1),
    torch.tensor([2, 3]),
)
ep = export(Foo(), inps)
print(ep)
I0203 16:55:11.720000 634 torch/fx/experimental/symbolic_shapes.py:3192] [17/0] create_env
I0203 16:55:11.724000 634 torch/fx/experimental/symbolic_shapes.py:4103] [17/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.724000 634 torch/fx/experimental/symbolic_shapes.py:970] [17/0] compute_unbacked_bindings [u0]
I0203 16:55:11.727000 634 torch/fx/experimental/symbolic_shapes.py:4103] [17/0] create_unbacked_symint u1 [-int_oo, int_oo] b = y.tolist()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.727000 634 torch/fx/experimental/symbolic_shapes.py:970] [17/0] compute_unbacked_bindings [u1]
I0203 16:55:11.729000 634 torch/fx/experimental/symbolic_shapes.py:4103] [17/0] create_unbacked_symint u2 [-int_oo, int_oo] b = y.tolist()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.729000 634 torch/fx/experimental/symbolic_shapes.py:970] [17/0] compute_unbacked_bindings [u2]
I0203 16:55:11.733000 634 torch/fx/experimental/symbolic_shapes.py:4547] [17/0] produce_guards
V0203 16:55:11.734000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.734000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['y'].size()[0] 2 None
V0203 16:55:11.734000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['y'].stride()[0] 1 None
V0203 16:55:11.734000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['y'].storage_offset() 0 None
I0203 16:55:11.741000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u3 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.742000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u4 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.750000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u5 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.750000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u5]
I0203 16:55:11.750000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u5 = u0 (rename_unbacked_to) VR[-int_oo, int_oo]
I0203 16:55:11.752000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u6 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.753000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u6]
I0203 16:55:11.753000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u6 = u1 (rename_unbacked_to) VR[-int_oo, int_oo]
I0203 16:55:11.755000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u7 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.755000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u7]
I0203 16:55:11.755000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u7 = u2 (rename_unbacked_to) VR[-int_oo, int_oo]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "i64[2]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward, code: b = y.tolist()
            select: "i64[]" = torch.ops.aten.select.int(y, 0, 0)
            item_1: "Sym(u1)" = torch.ops.aten.item.default(select);  select = None
            select_1: "i64[]" = torch.ops.aten.select.int(y, 0, 1);  y = None
            item_2: "Sym(u2)" = torch.ops.aten.item.default(select_1);  select_1 = None
            return (item_1, item_2, item)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item'), target=None)])
Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[-int_oo, int_oo], u2: VR[-int_oo, int_oo], u3: VR[-int_oo, int_oo], u4: VR[-int_oo, int_oo], u5: VR[-int_oo, int_oo], u6: VR[-int_oo, int_oo], u7: VR[-int_oo, int_oo]}

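The result is that 3 unbacked symbols are allocated and returned (notice that they are prefixed with "u" instead of the usual "s" for input shape/backed symbols): 1 for the item() call, and 1 for each element of y in the tolist() call. Note from the range constraints field that these symbols take on the range [-int_oo, int_oo], not the default [0, int_oo] range allocated to input shape symbols, since we have no information about what these values are. They do not represent sizes, so they are not necessarily positive.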

Guards, torch._check()

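The case above is easy to export, however, because the concrete values of these symbols are not used in any compiler decision-making; all that matters is that the return values are unbacked symbols. The data-dependent errors highlighted in this section arise in cases like the following, where data-dependent guards are encountered: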

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

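Here we actually need the "hint", i.e. the concrete value of a, for the compiler to decide whether to trace return y + 2 or return y * 5 as the output. Because we trace with FakeTensors, we do not know what a // 2 >= 5 actually evaluates to, and export errors out with "Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)".

So how do we export this toy model? Unlike torch.compile(), export requires full graph compilation, so we cannot simply graph break here. Here are some basic options:

  1. Manual specialization: we can intervene by selecting the branch to trace, either by removing the control-flow code so that only the specialized branch remains, or by using torch.compiler.is_compiling() to guard what is traced at compile time.

  2. torch.cond(): we can rewrite the control-flow code to use torch.cond(), so that we do not specialize on a branch.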

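While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code in order to specialize, and torch.cond() is not a comprehensive system for handling data-dependent errors. As we will see, there are data-dependent errors that do not involve control flow.

The generally recommended approach is to start with torch._check() calls. While these give the impression of being purely assert statements, they are in fact a way of informing the compiler about properties of symbols. A torch._check() call does act as an assertion at runtime, but when it is traced at compile time, the checked expression is sent to the symbolic shapes subsystem for reasoning, and any symbol properties that follow from the expression being true are stored (provided the subsystem is smart enough to infer them). So even though unbacked symbols have no hints, if we can communicate properties that generally hold for these symbols via torch._check() calls, we can potentially bypass data-dependent guards without rewriting the offending model code.

For example, in the model above, inserting torch._check(a >= 10) tells the compiler that y + 2 can always be returned, and torch._check(a == 4) tells it to return y * 5. See what happens when we re-export this model.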

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 10)
        torch._check(a <= 60)
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

inps = (
    torch.tensor(32),
    torch.randn(4),
)
ep = export(Foo(), inps)
print(ep)
I0203 16:55:11.765000 634 torch/fx/experimental/symbolic_shapes.py:3192] [18/0] create_env
I0203 16:55:11.769000 634 torch/fx/experimental/symbolic_shapes.py:4103] [18/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.769000 634 torch/fx/experimental/symbolic_shapes.py:970] [18/0] compute_unbacked_bindings [u0]
V0203 16:55:11.772000 634 torch/fx/experimental/symbolic_shapes.py:5802] [18/0] _update_var_to_range u0 = VR[10, int_oo] (update)
I0203 16:55:11.773000 634 torch/fx/experimental/symbolic_shapes.py:6281] [18/0] runtime_assert u0 >= 10 [guard added] torch._check(a >= 10)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:673 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 10"
V0203 16:55:11.777000 634 torch/fx/experimental/symbolic_shapes.py:5802] [18/0] _update_var_to_range u0 = VR[10, 60] (update)
I0203 16:55:11.779000 634 torch/fx/experimental/symbolic_shapes.py:6281] [18/0] runtime_assert u0 <= 60 [guard added] torch._check(a <= 60)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:674 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 <= 60"
V0203 16:55:11.783000 634 torch/fx/experimental/symbolic_shapes.py:6412] [18/0] eval ((u0//2)) >= 5 == True [statically known]
V0203 16:55:11.786000 634 torch/fx/experimental/symbolic_shapes.py:6614] [18/0] runtime_assert u0 >= 10 == True [statically known]
V0203 16:55:11.787000 634 torch/fx/experimental/symbolic_shapes.py:6614] [18/0] runtime_assert u0 <= 60 == True [statically known]
I0203 16:55:11.790000 634 torch/fx/experimental/symbolic_shapes.py:4547] [18/0] produce_guards
V0203 16:55:11.790000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.791000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['y'].size()[0] 4 None
V0203 16:55:11.791000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['y'].stride()[0] 1 None
V0203 16:55:11.791000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['y'].storage_offset() 0 None
I0203 16:55:11.806000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.807000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u1]
V0203 16:55:11.807000 634 torch/fx/experimental/symbolic_shapes.py:5802] _update_var_to_range u1 = VR[10, 60] (update)
I0203 16:55:11.807000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u1 = u0 (rename_unbacked_to) VR[10, 60]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[4]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 10)" = item >= 10
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 10 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 60)" = item <= 60;  item = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 60 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:676 in forward, code: return y + 2
            add: "f32[4]" = torch.ops.aten.add.Tensor(y, 2);  y = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {u0: VR[10, 60], u1: VR[10, 60]}

Export succeeds, and note from the range constraints field that u0 takes on the range [10, 60].

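So what information do torch._check() calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, the following generally hold:

  1. Equality with non-data-dependent expressions: torch._check() calls that communicate equalities such as u0 == s0 + 4 or u0 == 5.

  2. Range refinement: calls that provide lower or upper bounds for symbols, as in the example above.

  3. Some basic reasoning around more complicated expressions: inserting torch._check(a < 4) will typically tell the compiler that a >= 4 is false. Checks on complex expressions such as torch._check(a ** 2 - 3 * a <= 10) will typically get you past identical guards.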

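As mentioned previously, torch._check() calls are applicable outside of data-dependent control flow. For example, here is a model where inserting torch._check() prevails, while manual specialization and torch.cond() do not: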

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps)
except Exception:
    tb.print_exc()
I0203 16:55:11.821000 634 torch/fx/experimental/symbolic_shapes.py:3192] [19/0] create_env
I0203 16:55:11.825000 634 torch/fx/experimental/symbolic_shapes.py:4103] [19/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:701 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.826000 634 torch/fx/experimental/symbolic_shapes.py:970] [19/0] compute_unbacked_bindings [u0]
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] Data dependent variable 'u0' allocated at:
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/bin/sphinx-build", line 8, in <module>
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     sys.exit(main())
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return make_main(argv)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return make_mode.run_make_mode(argv[1:])
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return make.run_generic_build(args[0])
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return build_main(args + opts)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 272, in build_main
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 256, in __init__
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     self._init_builder()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 314, in _init_builder
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     self.events.emit('builder-inited')
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     results.append(listener.handler(self.app, *args))
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     ) = generate_dir_rst(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     intro, title, cost = generate_file_rst(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     output_blocks, time_elapsed = execute_script(script_blocks,
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     output_blocks.append(execute_code_block(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     is_last_expr, mem_max = _exec_and_get_memory(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     mem_max, _ = gallery_conf['call_memory'](
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return 0., func()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     exec(self.code, self.fake_main.__dict__)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 709, in <module>
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     export(Foo(), inps)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return _export(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     ep = fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return _export_for_training(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     ep = fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     export_artifact = export_func(  # type: ignore[operator]
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     gm_torch_level = _export_to_torch_ir(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     gm_torch_level, _ = torch._dynamo.export(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     result_traced = opt_f(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return self._call_impl(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return forward_call(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return self._call_impl(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return forward_call(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return self._torchdynamo_orig_callable(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return _compile(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     guarded_code = compile_inner(code, one_graph, hooks, transform)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return _compile_inner(code, one_graph, hooks, transform)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return function(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     out_code = transform_code_object(code, transform)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     transformations(instructions, code_options)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     tracer.run()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     super().run()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     while self.step():
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     self.dispatch_table[inst.opcode](self, inst)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return inner_fn(self, inst)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     self.call_function(fn, args, {})
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py", line 1022, in call_function
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return self.obj.call_method(tx, self.name, args, kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/tensor.py", line 591, in call_method
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return wrap_fx_proxy(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return _wrap_fx_proxy(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     ret_val = wrap_fake_exception(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return fn()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     lambda: run_node(tx.output, node, args, kwargs, nnmodule)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2588, in run_node
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return getattr(args[0], node.target)(*args[1:], **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return self.dispatch(func, types, args, kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return self._cached_dispatch_impl(func, types, args, kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     output = self._dispatch_impl(func, types, args, kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     op_impl_out = op_impl(self, func, *args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 403, in local_scalar_dense
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     r = fake_mode.shape_env.create_unbacked_symint()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return retlog(fn(*args, **kwargs))
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]
W0203 16:55:11.836000 634 torch/fx/experimental/symbolic_shapes.py:6307] [19/0] failed during evaluate_expr(-u0 > 60, hint=None, size_oblivious=True, forcing_spec=False
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] failed while running evaluate_expr(*(-u0 > 60, None), **{'fx_node': False, 'size_oblivious': True})
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] Traceback (most recent call last):
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]     return retlog(fn(*args, **kwargs))
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]     return self._evaluate_expr(
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]     raise self._make_data_dependent_error(
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] Caused by: return y[a]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] For more information, run with TORCH_LOGS="dynamic"
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] User Stack (most recent call last):
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]   (snipped, see stack below for prefix)
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]     return y[a]
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] failed while attempting to run meta for aten.select.int
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] Traceback (most recent call last):
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]     r = func(*args, **kwargs)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]     return self._op(*args, **kwargs)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 4874, in meta_select
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]     guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 407, in guard_size_oblivious
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]     return expr.node.guard_size_oblivious("", 0)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]     r = self.shape_env.evaluate_expr(
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]     return retlog(fn(*args, **kwargs))
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]     return self._evaluate_expr(
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]     raise self._make_data_dependent_error(
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] Caused by: return y[a]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] For more information, run with TORCH_LOGS="dynamic"
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] User Stack (most recent call last):
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]   (snipped, see stack below for prefix)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]     return y[a]
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 4874, in meta_select
    guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 407, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)

Caused by: return y[a]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
    return y[a]

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2604, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 4874, in meta_select
    guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 407, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
    raise self._make_data_dependent_error(
RuntimeError: Failed running call_function <built-in method select of type object at 0x7f3973a1fec0>(*(FakeTensor(..., size=(60,)), 0, u0), **{}):
Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)

Caused by: return y[a]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
    return y[a]

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 709, in <module>
    export(Foo(), inps)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 314, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1004, in call_function
    return handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 980, in _handle_insert_op_in_graph
    return wrap_fx_proxy(tx, proxy)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2526, in get_fake_value
    raise UserError(  # noqa: B904
torch._dynamo.exc.UserError: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)

Caused by: return y[a]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
    return y[a]

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.dev.org.tw/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
    return y[a]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

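Here is a scenario where inserting torch._check() calls is needed simply to keep an operation from failing. The export call fails with "Could not guard on data-dependent expression -u0 > 60", which means the compiler cannot tell whether this is a valid indexing operation, that is, whether the value of x is out of bounds for y. Manual specialization is too costly here, and torch.cond() does not apply. Instead, it is enough to tell the compiler the range of u0: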

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)
        torch._check(a < y.shape[0])
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps)
print(ep)
I0203 16:55:11.859000 634 torch/fx/experimental/symbolic_shapes.py:3192] [20/0] create_env
I0203 16:55:11.863000 634 torch/fx/experimental/symbolic_shapes.py:4103] [20/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.864000 634 torch/fx/experimental/symbolic_shapes.py:970] [20/0] compute_unbacked_bindings [u0]
V0203 16:55:11.866000 634 torch/fx/experimental/symbolic_shapes.py:5802] [20/0] _update_var_to_range u0 = VR[0, int_oo] (update)
I0203 16:55:11.866000 634 torch/fx/experimental/symbolic_shapes.py:6281] [20/0] runtime_assert u0 >= 0 [guard added] torch._check(a >= 0)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:722 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V0203 16:55:11.870000 634 torch/fx/experimental/symbolic_shapes.py:5802] [20/0] _update_var_to_range u0 = VR[0, 59] (update)
I0203 16:55:11.871000 634 torch/fx/experimental/symbolic_shapes.py:6281] [20/0] runtime_assert u0 < 60 [guard added] torch._check(a < y.shape[0])  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:723 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 < 60"
V0203 16:55:11.873000 634 torch/fx/experimental/symbolic_shapes.py:6412] [20/0] eval -u0 > 60 == False [statically known]
V0203 16:55:11.873000 634 torch/fx/experimental/symbolic_shapes.py:6412] [20/0] eval u0 >= 60 == False [statically known]
V0203 16:55:11.874000 634 torch/fx/experimental/symbolic_shapes.py:6412] [20/0] eval u0 >= 0 == True [statically known]
V0203 16:55:11.877000 634 torch/fx/experimental/symbolic_shapes.py:6614] [20/0] runtime_assert u0 >= 0 == True [statically known]
V0203 16:55:11.878000 634 torch/fx/experimental/symbolic_shapes.py:6614] [20/0] runtime_assert u0 <= 59 == True [statically known]
V0203 16:55:11.879000 634 torch/fx/experimental/symbolic_shapes.py:6614] [20/0] runtime_assert u0 < 60 == True [statically known]
I0203 16:55:11.882000 634 torch/fx/experimental/symbolic_shapes.py:4547] [20/0] produce_guards
V0203 16:55:11.883000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.883000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['y'].size()[0] 60 None
V0203 16:55:11.883000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['y'].stride()[0] 1 None
V0203 16:55:11.883000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['y'].storage_offset() 0 None
I0203 16:55:11.901000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.902000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u1]
V0203 16:55:11.902000 634 torch/fx/experimental/symbolic_shapes.py:5802] _update_var_to_range u1 = VR[0, 59] (update)
I0203 16:55:11.903000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u1 = u0 (rename_unbacked_to) VR[0, 59]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 59)" = item <= 59
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 59 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

             #
            lt_1: "Sym(u0 < 60)" = item < 60
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u0 < 60 on node 'lt_1'");  lt_1 = _assert_scalar_default_2 = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:724 in forward, code: return y[a]
            select: "f32[]" = torch.ops.aten.select.int(y, 0, item);  y = item = None
            return (select,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='select'), target=None)])
Range constraints: {u0: VR[0, 59], u1: VR[0, 59]}
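
To make the log above easier to read, the following is a toy model in plain Python of how a value range can settle a guard. It is only a sketch and not PyTorch's implementation: the names ValueRange, eval_neg_gt and eval_ge are hypothetical and exist only for this illustration. It mirrors the two expressions from meta_select (-index > size and index >= size) and shows why they are undecidable for an unconstrained u0 but become statically known once the range is narrowed to [0, 59].

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ValueRange:
    """Hypothetical stand-in for the VR[lower, upper] entries in the log."""
    lower: float = -math.inf
    upper: float = math.inf

    def check_ge(self, bound):
        # Plays the role of torch._check(a >= bound): raise the lower bound.
        return ValueRange(max(self.lower, bound), self.upper)

    def check_lt(self, bound):
        # Plays the role of torch._check(a < bound); for integers this is a <= bound - 1.
        return ValueRange(self.lower, min(self.upper, bound - 1))


def eval_neg_gt(vr, size):
    """Evaluate `-u > size` for every u in vr; None means 'not statically known'."""
    lo, hi = -vr.upper, -vr.lower  # range of -u
    if lo > size:
        return True
    if hi <= size:
        return False
    return None


def eval_ge(vr, size):
    """Evaluate `u >= size` for every u in vr; None means 'not statically known'."""
    if vr.lower >= size:
        return True
    if vr.upper < size:
        return False
    return None


# Unconstrained u0, as produced by x.item(): neither guard can be decided.
unknown = ValueRange()
print(eval_neg_gt(unknown, 60), eval_ge(unknown, 60))

# After the two checks the range is [0, 59], and both guards are decidable.
narrowed = unknown.check_ge(0).check_lt(60)
print(narrowed, eval_neg_gt(narrowed, 60), eval_ge(narrowed, 60))
```

In this sketch a result of None corresponds to the situation where export raises the guard error, while a definite True or False corresponds to the "[statically known]" entries in the log.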

Specialized values

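Another class of data-dependent error occurs when the program tries to extract a concrete data-dependent integer or float value while tracing. The message looks like "Could not extract specialized integer from data-dependent expression". It is analogous to the previous class of errors: these errors arise when tracing needs a concrete integer or float value, whereas data-dependent guard errors arise when it needs a concrete boolean value.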

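This error typically occurs when there is an explicit or implicit int() cast on a data-dependent expression. For example, this list comprehension has a range() call that implicitly performs an int() cast on the size of the list: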

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = torch.cat([y for y in range(a)], dim=0)
        return b + int(a)

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps, strict=False)
except Exception:
    tb.print_exc()
I0203 16:55:11.920000 634 torch/fx/experimental/symbolic_shapes.py:3192] create_env
I0203 16:55:11.926000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.927000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u0]
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] Data dependent variable 'u0' allocated at:
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/bin/sphinx-build", line 8, in <module>
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     sys.exit(main())
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return make_main(argv)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return make_mode.run_make_mode(argv[1:])
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return make.run_generic_build(args[0])
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return build_main(args + opts)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 272, in build_main
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 256, in __init__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     self._init_builder()
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 314, in _init_builder
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     self.events.emit('builder-inited')
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     results.append(listener.handler(self.app, *args))
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     ) = generate_dir_rst(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     intro, title, cost = generate_file_rst(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     output_blocks, time_elapsed = execute_script(script_blocks,
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     output_blocks.append(execute_code_block(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     is_last_expr, mem_max = _exec_and_get_memory(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     mem_max, _ = gallery_conf['call_memory'](
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return 0., func()
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     exec(self.code, self.fake_main.__dict__)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     export(Foo(), inps, strict=False)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return _export(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     ep = fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return _export_for_training(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     ep = fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     export_artifact = export_func(  # type: ignore[operator]
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1772, in _non_strict_export
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     aten_export_artifact = _to_aten_func(  # type: ignore[operator]
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1564, in _export_to_aten_ir_make_fx
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     gm, graph_signature = transform(_make_fx_helper)(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1702, in _aot_export_non_strict
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1485, in _make_fx_helper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     gm = make_fx(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2196, in wrapped
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return make_fx_tracer.trace(f, *args)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2134, in trace
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self._trace_inner(f, *args)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2105, in _trace_inner
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     t = dispatch_trace(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return disable_fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1138, in dispatch_trace
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1694, in trace
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     res = super().trace(root, concrete_args)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in trace
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     (self.create_arg(fn(*args)),),
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1193, in wrapped
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     out = f(*tensors)  # type:ignore[call-arg]
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "<string>", line 1, in <lambda>
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1469, in wrapped_fn
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return tuple(flat_fn(*args))
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     tree_out = fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 879, in functional_call
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     out = mod(*args[params_len:], **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self.call_module(mod, forward, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return Tracer.call_module(self, m, forward, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     ret_val = forward(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return _orig_module_call(mod, *args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self._call_impl(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return forward_call(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1689, in forward
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     tree_out = mod(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self.call_module(mod, forward, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return Tracer.call_module(self, m, forward, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     ret_val = forward(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return _orig_module_call(mod, *args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self._call_impl(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return forward_call(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 747, in forward
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     a = x.item()
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1241, in __torch_function__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return func(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1288, in __torch_function__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return func(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 557, in __torch_function__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return func(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 840, in handler
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return torch._library.utils.handle_dispatch_mode(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_library/utils.py", line 295, in handle_dispatch_mode
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1343, in __torch_dispatch__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return proxy_call(self, func, self.pre_dispatch, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 912, in proxy_call
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     out = func(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self._op(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self.dispatch(func, types, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self._cached_dispatch_impl(func, types, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     output = self._dispatch_impl(func, types, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     op_impl_out = op_impl(self, func, *args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 403, in local_scalar_dense
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     r = fake_mode.shape_env.create_unbacked_symint()
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return retlog(fn(*args, **kwargs))
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]
W0203 16:55:11.936000 634 torch/fx/experimental/symbolic_shapes.py:6307] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] failed while running evaluate_expr(*(u0, None), **{'fx_node': False})
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] Traceback (most recent call last):
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299]     return retlog(fn(*args, **kwargs))
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299]     return self._evaluate_expr(
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299]     raise self._make_data_dependent_error(
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299]
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] Caused by: (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] For more information, run with TORCH_LOGS="dynamic"
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299]
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
    export(Foo(), inps, strict=False)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1772, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1564, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1702, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1485, in _make_fx_helper
    gm = make_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2196, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2134, in trace
    return self._trace_inner(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2105, in _trace_inner
    t = dispatch_trace(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1138, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1694, in trace
    res = super().trace(root, concrete_args)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in trace
    (self.create_arg(fn(*args)),),
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1193, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1469, in wrapped_fn
    return tuple(flat_fn(*args))
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 879, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1689, in forward
    tree_out = mod(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 748, in forward
    b = torch.cat([y for y in range(a)], dim=0)
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 427, in __index__
    return self.node.int_()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 445, in int_
    return self.guard_int("", 0)  # NB: uses Python backtrace
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 492, in guard_int
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Caused by: (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
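
The last frames of the traceback above show range(a) calling __index__ on the symbolic value, which is where the tracer is asked for a concrete integer. The following is a minimal pure-Python sketch of that mechanism, not PyTorch code: FakeSymInt is a hypothetical stand-in for an unbacked symbol such as u0, and its error text only imitates the real message.

```python
class FakeSymInt:
    """Hypothetical stand-in for an unbacked symbolic integer such as u0.

    This is NOT PyTorch's SymInt. It only mimics the one behavior that
    matters here: refusing to turn into a concrete Python int.
    """

    def __init__(self, name):
        self.name = name

    def __index__(self):
        # range(), sequence indexing, and operator.index() all call
        # __index__ to obtain a concrete int. A tracer has none to give.
        raise RuntimeError(
            f"Could not extract specialized integer from {self.name}"
        )

    def __add__(self, other):
        # Arithmetic can stay symbolic: record the expression, don't evaluate.
        return FakeSymInt(f"({self.name} + {other})")


a = FakeSymInt("u0")

# Symbolic arithmetic is fine: no concrete value is required.
b = a + 1
print(b.name)

# range(a) needs a concrete length, so Python calls a.__index__().
try:
    range(a)
except RuntimeError as err:
    message = str(err)
    print(message)
```

In this sketch the addition succeeds because it never needs the value of u0, while range(a) fails because Python must know how many iterations to produce. This is why tensor-level operations that accept a symbolic size tend to trace where Python-level loops over the same value do not.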

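For these errors, some basic options you have are:

  1. Avoid unnecessary int() cast calls, in this case the int(a) in the return statement.

  2. Use torch._check() calls; unfortunately, all you may be able to do in this case is specialize (with torch._check(a == 60)).

  3. Try rewriting the offending code at a higher level. For example, the list comprehension is semantically a repeat() op, which doesn't involve an int() cast. The following rewrite avoids data-dependent errors: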

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.unsqueeze(0).repeat(a, 1)
        return b + a

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps, strict=False)
print(ep)
I0203 16:55:11.946000 634 torch/fx/experimental/symbolic_shapes.py:3192] create_env
I0203 16:55:11.951000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.952000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u0]
V0203 16:55:11.956000 634 torch/fx/experimental/symbolic_shapes.py:5802] _update_var_to_range u0 = VR[0, int_oo] (update)
I0203 16:55:11.957000 634 torch/fx/experimental/symbolic_shapes.py:6281] runtime_assert u0 >= 0 [guard added] (_refs/__init__.py:4800 in new_empty), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V0203 16:55:11.959000 634 torch/fx/experimental/symbolic_shapes.py:6412] eval Eq(u0, 0) == False [statically known]
V0203 16:55:11.962000 634 torch/fx/experimental/symbolic_shapes.py:6412] eval Eq(u0, 1) == False [statically known]
V0203 16:55:11.962000 634 torch/fx/experimental/symbolic_shapes.py:6614] runtime_assert True == True [statically known]
I0203 16:55:11.966000 634 torch/fx/experimental/symbolic_shapes.py:4547] produce_guards
V0203 16:55:11.966000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][0].storage_offset() 0 None
V0203 16:55:11.967000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][1].size()[0] 60 None
V0203 16:55:11.967000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][1].stride()[0] 1 None
V0203 16:55:11.967000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][1].storage_offset() 0 None
V0203 16:55:11.969000 634 torch/fx/experimental/symbolic_shapes.py:6614] runtime_assert u0 >= 0 == True [statically known]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

             #
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
            ge: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge = _assert_scalar_default = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:770 in forward, code: b = y.unsqueeze(0).repeat(a, 1)
            unsqueeze: "f32[1, 60]" = torch.ops.aten.unsqueeze.default(y, 0);  y = None
            repeat: "f32[u0, 60]" = torch.ops.aten.repeat.default(unsqueeze, [item, 1]);  unsqueeze = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:771 in forward, code: return b + a
            add: "f32[u0, 60]" = torch.ops.aten.add.Tensor(repeat, item);  repeat = item = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {u0: VR[0, int_oo]}

Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: torch._check_is_size(), guard_size_oblivious(), or real-tensor tracing, as starters. For more in-depth guides, please refer to the Export Programming Model, or Dealing with GuardOnDataDependentSymNode errors.

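Custom Ops

torch.export can export PyTorch programs with custom operators. Please refer to this page on how to author a custom operator in either C++ or Python.

The following is an example of registering a custom operator in Python to be used by torch.export. The important thing to note is that the custom op must have a FakeTensor kernel.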

@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(x: torch.Tensor) -> torch.Tensor:
    print("custom_op called!")
    return torch.relu(x)

@custom_op.register_fake
def custom_op_meta(x):
    # Returns an empty tensor with the same shape as the expected output
    return torch.empty_like(x)

Here is an example of exporting a program with the custom op.

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x

exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
print(exported_custom_op_example)
print(exported_custom_op_example.module()(torch.randn(3, 3)))
I0203 16:55:11.985000 634 torch/fx/experimental/symbolic_shapes.py:3192] [21/0] create_env
I0203 16:55:11.995000 634 torch/fx/experimental/symbolic_shapes.py:4547] [21/0] produce_guards
V0203 16:55:11.996000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].size()[0] 3 None
V0203 16:55:11.996000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].size()[1] 3 None
V0203 16:55:11.996000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].stride()[0] 3 None
V0203 16:55:11.997000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].stride()[1] 1 None
V0203 16:55:11.997000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].storage_offset() 0 None
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:812 in forward, code: x = torch.sin(x)
            sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:813 in forward, code: x = torch.ops.my_custom_library.custom_op(x)
            custom_op: "f32[3, 3]" = torch.ops.my_custom_library.custom_op.default(sin);  sin = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:814 in forward, code: x = torch.cos(x)
            cos: "f32[3, 3]" = torch.ops.aten.cos.default(custom_op);  custom_op = None
            return (cos,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}

custom_op called!
tensor([[0.5499, 0.6889, 0.7180],
        [0.5413, 1.0000, 1.0000],
        [0.8332, 0.5524, 1.0000]])

Note that in the ExportedProgram, the custom operator is included in the graph.

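IR/Decompositions

The graph produced by torch.export contains only ATen operators, which are the basic unit of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.

By default, export produces the most generic IR, which contains all ATen operators, including both functional and non-functional operators. A functional operator is one that does not mutate or alias its inputs. You can find a list of all ATen operators here, and you can inspect whether an operator is functional by checking op._schema.is_mutable, for example: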

print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)
False
True

This generic IR can be used to train in eager PyTorch Autograd. This IR can be reached more explicitly through the API torch.export.export_for_training, which was introduced in PyTorch 2.5, but calling torch.export.export should produce the same graph as of PyTorch 2.6.

class DecompExample(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export_for_training(DecompExample(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph)
I0203 16:55:12.023000 634 torch/fx/experimental/symbolic_shapes.py:3192] [22/0] create_env
I0203 16:55:12.054000 634 torch/fx/experimental/symbolic_shapes.py:4547] [22/0] produce_guards
V0203 16:55:12.054000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[0] 1 None
V0203 16:55:12.054000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[1] 1 None
V0203 16:55:12.055000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[2] 3 None
V0203 16:55:12.055000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[3] 3 None
V0203 16:55:12.055000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[0] 9 None
V0203 16:55:12.056000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[1] 9 None
V0203 16:55:12.056000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[2] 3 None
V0203 16:55:12.056000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[3] 1 None
V0203 16:55:12.056000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].storage_offset() 0 None
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05, True), kwargs = {})
    return (batch_norm,)

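We can then lower this exported program to an operator set which only contains functional ATen operators through the API run_decompositions, which decomposes ATen operators into the ones specified in the decomposition table and functionalizes the graph. By specifying an empty decomposition table, we only perform functionalization and do not apply any additional decompositions. This results in an IR which contains ~2000 operators (instead of the 3000 operators above), and is ideal for inference cases.

ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph)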
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

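As we can see, the previously mutable operator torch.ops.aten.add_.Tensor has now been replaced with torch.ops.aten.add.Tensor, a functional operator.

We can also further lower this exported program to an operator set which only contains the Core ATen Operator Set, a collection of only ~180 operators. This IR is optimal for backends that do not want to reimplement all ATen operators.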

from torch.export import default_decompositions

core_aten_decomp_table = default_decompositions()
core_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)
print(core_aten_ep.graph)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%convolution, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

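We now see that torch.ops.aten.conv2d.default has been decomposed into torch.ops.aten.convolution.default. This is because convolution is a more "core" operator, as operations like conv1d and conv2d can be implemented using the same op.

We can also specify our own decomposition behaviors: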

my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%mul, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

Notice that instead of torch.ops.aten.conv2d.default being decomposed into torch.ops.aten.convolution.default alone, it is now decomposed into torch.ops.aten.convolution.default and torch.ops.aten.mul.Tensor, which matches our custom decomposition rule.

ExportDB

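torch.export will only ever export a single computation graph from a PyTorch program. Because of this requirement, there will be Python or PyTorch features that are not compatible with torch.export, which will require users to rewrite parts of their model code. We have seen examples of this earlier in the tutorial, for example rewriting if-statements using cond.

ExportDB is the standard reference that documents supported and unsupported Python/PyTorch features for torch.export. It is essentially a list of program samples, each of which represents the usage of one particular Python/PyTorch feature and its interaction with torch.export. Examples are also tagged by category so that they can be searched more easily.

For example, let's use ExportDB to get a better understanding of how the predicate works in the cond operator. We can look at the example called cond_predicate, which has a torch.cond tag. The example code looks like: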

def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
    - ``torch.Tensor`` with a single element
    - boolean expression
    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

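More generally, ExportDB can be used as a reference when one of the following occurs:

  1. Before attempting torch.export, you know ahead of time that your model uses some tricky Python/PyTorch features and you want to know if torch.export covers that feature.

  2. When attempting torch.export, there is a failure and it's unclear how to work around it.

ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by torch.export.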

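Running the Exported Program

As torch.export is only a graph capturing mechanism, calling the artifact produced by torch.export eagerly is equivalent to running the eager module. To optimize the execution of the exported program, we can pass this exported artifact to backends such as Inductor (through torch.compile or AOTInductor) or TensorRT.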

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
I0203 16:55:12.677000 634 torch/fx/experimental/symbolic_shapes.py:3192] [23/0] create_env
I0203 16:55:12.692000 634 torch/fx/experimental/symbolic_shapes.py:4547] [23/0] produce_guards
V0203 16:55:12.692000 634 torch/fx/experimental/symbolic_shapes.py:4755] [23/0] track_symint L['x'].size()[0] 2 None
V0203 16:55:12.692000 634 torch/fx/experimental/symbolic_shapes.py:4755] [23/0] track_symint L['x'].size()[1] 3 None
V0203 16:55:12.693000 634 torch/fx/experimental/symbolic_shapes.py:4755] [23/0] track_symint L['x'].stride()[0] 3 None
V0203 16:55:12.693000 634 torch/fx/experimental/symbolic_shapes.py:4755] [23/0] track_symint L['x'].stride()[1] 1 None
V0203 16:55:12.693000 634 torch/fx/experimental/symbolic_shapes.py:4755] [23/0] track_symint L['x'].storage_offset() 0 None
tensor([[ 0.4830, -0.5149,  0.3888],
        [-0.9247,  0.8408, -0.2184]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
I0203 16:55:12.720000 634 torch/fx/experimental/symbolic_shapes.py:3192] [24/0] create_env
I0203 16:55:13.317000 634 torch/fx/experimental/symbolic_shapes.py:4547] [24/0] produce_guards
I0203 16:55:13.341000 634 torch/fx/experimental/symbolic_shapes.py:4547] [24/0] produce_guards
V0203 16:55:13.342000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['x'].size()[0] 2 None
V0203 16:55:13.342000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['x'].size()[1] 3 None
V0203 16:55:13.342000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['x'].stride()[0] 3 None
V0203 16:55:13.343000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['x'].stride()[1] 1 None
V0203 16:55:13.343000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:13.343000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].size()[0] 3 None
V0203 16:55:13.343000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].size()[1] 3 None
V0203 16:55:13.344000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].stride()[0] 3 None
V0203 16:55:13.344000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].stride()[1] 1 None
V0203 16:55:13.344000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].storage_offset() 0 None
V0203 16:55:13.345000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['bias'].size()[0] 3 None
V0203 16:55:13.345000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['bias'].stride()[0] 1 None
V0203 16:55:13.345000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['bias'].storage_offset() 0 None
V0203 16:55:13.346000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['x'].size()[0] == 2
V0203 16:55:13.346000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['x'].size()[1] == 3
V0203 16:55:13.346000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['x'].stride()[0] == 3
V0203 16:55:13.347000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['x'].stride()[1] == 1
V0203 16:55:13.347000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['x'].storage_offset() == 0
V0203 16:55:13.347000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].size()[0] == 3
V0203 16:55:13.347000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].size()[1] == 3
V0203 16:55:13.348000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].stride()[0] == 3
V0203 16:55:13.348000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].stride()[1] == 1
V0203 16:55:13.348000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].storage_offset() == 0
V0203 16:55:13.349000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['bias'].size()[0] == 3
V0203 16:55:13.349000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['bias'].stride()[0] == 1
V0203 16:55:13.349000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['bias'].storage_offset() == 0
tensor([[ 0.4830, -0.5149,  0.3888],
        [-0.9247,  0.8408, -0.2184]], device='cuda:0',
       grad_fn=<CompiledFunctionBackward>)
import torch._inductor

# Note: these APIs are subject to change
# Compile the exported program to a PT2 archive using ``AOTInductor``
with torch.no_grad():
    pt2_path = torch._inductor.aoti_compile_and_package(ep)

# Load and run the packaged .pt2 artifact in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.dev.org.tw/docs/main/torch.compiler_aot_inductor.html
aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
res = aoti_compiled(inp)

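Conclusion

We introduced torch.export, the new PyTorch 2.X way to export single computation graphs from PyTorch programs. In particular, we demonstrated several code modifications and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.

Total running time of the script: (0 minutes 2.787 seconds)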
