捷徑

torch.jit.fork

torch.jit.fork(func, *args, **kwargs)[source][source]

建立執行 func 的非同步任務,以及對此執行結果值的參考。

fork 將立即返回,因此 func 的返回值可能尚未計算完成。若要強制完成任務並存取返回值,請在 Future 物件上調用 torch.jit.wait。使用返回 T 值的 func 調用 fork,其類型為 torch.jit.Future[T]fork 調用可以任意嵌套,並且可以使用位置參數和關鍵字參數來調用。非同步執行僅在 TorchScript 中運行時才會發生。如果在純 Python 中運行,fork 不會並行執行。fork 在追蹤 (tracing) 時調用也不會並行執行,但是 forkwait 的調用會被捕獲到導出的 IR Graph 中。

警告

fork 任務將以非確定性的方式執行。我們建議僅為不修改其輸入、模組屬性或全域狀態的純函數產生並行 fork 任務。

參數
  • func (callabletorch.nn.Module) – 將被調用的 Python 函數或 torch.nn.Module。 如果在 TorchScript 中執行,它將非同步執行,否則不會。 fork 的追蹤調用將被捕獲在 IR 中。

  • *args – 用於調用 func 的參數。

  • **kwargs – 用於調用 func 的參數。

返回值

func 執行的引用。 只有通過 torch.jit.wait 強制完成 func 後才能存取值 T

返回類型

torch.jit.Future[T]

範例 (fork 一個自由函數)

import torch
from torch import Tensor
def foo(a : Tensor, b : int) -> Tensor:
    return a + b
def bar(a):
    fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
    return torch.jit.wait(fut)
script_bar = torch.jit.script(bar)
input = torch.tensor(2)
# only the scripted version executes asynchronously
assert script_bar(input) == bar(input)
# trace is not run asynchronously, but fork is captured in IR
graph = torch.jit.trace(bar, (input,)).graph
assert "fork" in str(graph)

範例 (fork 一個模組方法)

import torch
from torch import Tensor
class AddMod(torch.nn.Module):
    def forward(self, a: Tensor, b : int):
        return a + b
class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super(self).__init__()
        self.mod = AddMod()
    def forward(self, input):
        fut = torch.jit.fork(self.mod, a, b=2)
        return torch.jit.wait(fut)
input = torch.tensor(2)
mod = Mod()
assert mod(input) == torch.jit.script(mod).forward(input)

文件

存取 PyTorch 的綜合開發者文件

檢視文件

教學課程

取得針對初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並取得您問題的解答

檢視資源