torch.jit.trace¶
- torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)[原始碼][原始碼]¶
追蹤一個函式,並回傳一個可執行檔或
ScriptFunction
,它將使用即時編譯進行最佳化。追蹤非常適合僅對
Tensor
及其列表、字典和元組進行操作的程式碼。使用 torch.jit.trace 和 torch.jit.trace_module,您可以將現有的模組或 Python 函式轉換為 TorchScript
ScriptFunction
或ScriptModule
。 您必須提供範例輸入,我們會執行該函式,並記錄在所有張量上執行的操作。單獨函式的結果記錄會產生 ScriptFunction。
nn.Module.forward 或 nn.Module 的結果記錄會產生 ScriptModule。
此模組也包含原始模組擁有的任何參數。
警告
追蹤僅正確記錄不依賴資料的函式和模組(例如,沒有在張量資料上的條件語句),並且沒有任何未追蹤的外部依賴項(例如,執行輸入/輸出或存取全域變數)。 追蹤僅記錄在給定函式在給定張量上執行時所做的操作。 因此,回傳的 ScriptModule 將始終在任何輸入上執行相同的追蹤圖。 當您的模組預期會根據輸入和/或模組狀態執行不同的操作集時,這有一些重要的含義。 例如,
追蹤不會記錄任何控制流程,例如 if 語句或迴圈。 當此控制流程在您的模組中是恆定的時,這很好,並且它通常會內聯控制流程決策。 但有時控制流程實際上是模型本身的一部分。 例如,循環神經網路是對輸入序列(可能動態)長度的迴圈。
在回傳的
ScriptModule
中,在training
和eval
模式下具有不同行為的操作將始終表現得好像它處於追蹤期間的模式中,無論 ScriptModule 處於哪種模式。
在這些情況下,追蹤是不適當的,而
scripting
是一個更好的選擇。 如果您追蹤此類模型,則可能會在後續的模型調用中靜默地獲得不正確的結果。 當執行可能導致產生不正確追蹤的操作時,追蹤器將嘗試發出警告。- 參數
func (可呼叫物件 或 torch.nn.Module) – 一個 Python 函式或 torch.nn.Module,將使用 example_inputs 執行。 func 的引數和回傳值必須是張量或(可能巢狀的)包含張量的元組。 當一個模組被傳遞給 torch.jit.trace 時,只會執行和追蹤
forward
方法(請參閱torch.jit.trace
了解詳細資訊)。- 關鍵字引數
example_inputs (tuple 或 torch.Tensor 或 None, 選擇性) – 一個範例輸入的元組,在追蹤時將傳遞給該函式。預設值:
None
。 應該指定此引數或example_kwarg_inputs
。假設追蹤的操作支援這些類型和形狀,則可以使用不同類型和形狀的輸入來執行結果追蹤。example_inputs 也可以是單個 Tensor,在這種情況下,它會自動封裝在元組中。 當值為 None 時,應指定example_kwarg_inputs
。check_trace (
bool
, 選擇性) – 檢查透過追蹤程式碼執行的相同輸入是否產生相同的輸出。預設值:True
。 例如,如果您的網路包含非決定性操作,或者您確定網路是正確的,儘管檢查器失敗,您可能需要禁用此選項。check_inputs (list of tuples, 選擇性) – 一個輸入引數的元組列表,應用於根據預期檢查追蹤。每個元組等同於一組輸入引數,這些引數將在
example_inputs
中指定。為了獲得最佳結果,請傳入一組檢查輸入,這些輸入代表您期望網路看到的輸入形狀和類型的空間。如果未指定,則原始的example_inputs
用於檢查。check_tolerance (float, optional) – 用於檢查程序中的浮點數比較容差。在已知原因(例如運算符融合)導致結果在數值上存在差異時,可以使用此參數來放寬檢查器的嚴格性。
strict (
bool
, optional) – 是否以嚴格模式運行追蹤器(預設值:True
)。只有當您想要追蹤器記錄您的可變容器類型(目前為list
/dict
),並且您確定您在問題中使用的容器是一個constant
結構且不會用作控制流程(if,for)條件時,才關閉此選項。example_kwarg_inputs (dict, optional) – 此參數是一個範例輸入的關鍵字參數包,將在追蹤時傳遞給函數。預設值:
None
。應指定此參數或example_inputs
。該字典將通過被追蹤函數的參數名稱進行解包。如果字典的鍵與被追蹤函數的參數名稱不匹配,則會引發運行時異常。
- 返回
如果 func 是 nn.Module 或 nn.Module 的
forward
,則 trace 返回一個ScriptModule
物件,其中包含一個包含追蹤程式碼的forward
方法。返回的 ScriptModule 將具有與原始nn.Module
相同的子模組和參數集。如果func
是一個獨立函數,則trace
返回 ScriptFunction。
範例 (追蹤函數)
import torch def foo(x, y): return 2 * x + y # Run `foo` with the provided inputs and record the tensor operations traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) # `traced_foo` can now be run with the TorchScript interpreter or saved # and loaded in a Python-free environment
範例 (追蹤現有模組)
import torch import torch.nn as nn class Net(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input)