捷徑

torch.jit.trace_module

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[來源][來源]

追蹤模組並傳回一個可執行的 ScriptModule,它將使用即時編譯進行最佳化。

當一個模組傳遞給 torch.jit.trace 時,只會執行和追蹤 forward 方法。使用 trace_module,您可以指定一個方法名稱和範例輸入的字典來進行追蹤 (請參閱下面的 inputs 參數)。

有關追蹤的更多資訊,請參閱 torch.jit.trace

參數
  • mod (torch.nn.Module) – 一個 torch.nn.Module,包含其名稱在 inputs 中指定的方法。給定的方法將會被編譯為單個 ScriptModule 的一部分。

  • inputs (dict) – 一個字典,包含由 mod 中的方法名稱索引的範例輸入。輸入將會被傳遞給方法,其名稱對應於輸入的鍵。 例如: { 'forward' : example_forward_input, 'method2': example_method2_input}

關鍵字參數
  • check_trace (bool, optional) – 檢查透過追蹤程式碼執行的相同輸入是否產生相同的輸出。預設值:True。 如果您的網路包含非決定性的運算,或者您確定網路即使檢查器失敗也是正確的,您可能需要停用此功能。

  • check_inputs (list of dicts, optional) – 一個字典列表,包含應該用來檢查追蹤結果是否符合預期的輸入參數。每個元組等同於一組將在 inputs 中指定的輸入參數。 為了獲得最佳效果,請傳入一組具有代表性的檢查輸入,這些輸入代表您期望網路看到的輸入的形狀和類型空間。 如果未指定,則原始 inputs 將用於檢查。

  • check_tolerance (float, optional) – 在檢查程序中使用的浮點數比較容差。 這可以用於在已知原因(例如運算子融合)導致結果在數值上發散的情況下,放寬檢查器的嚴格性。

  • example_inputs_is_kwarg (bool, optional) – 此參數表示範例輸入是否為關鍵字參數的封包。預設值:False

返回值

一個 ScriptModule 物件,具有包含追蹤程式碼的單個 forward 方法。 當 func 是一個 torch.nn.Module 時,返回的 ScriptModule 將具有與 func 相同的子模組和參數集。

範例 (追蹤具有多個方法的模組)

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)

文件

訪問 PyTorch 的綜合開發者文件

查看文檔

教程

獲取針對初學者和高級開發人員的深入教程

查看教程

資源

尋找開發資源並獲得您問題的解答

查看資源