torch.jit.fork¶
- torch.jit.fork(func, *args, **kwargs)[source][source]¶
建立執行 func 的非同步任務,以及對此執行結果值的參考。
fork 將立即返回,因此 func 的返回值可能尚未計算完成。若要強制完成任務並存取返回值,請在 Future 物件上調用 torch.jit.wait。使用返回 T 值的 func 調用 fork,其類型為 torch.jit.Future[T]。fork 調用可以任意嵌套,並且可以使用位置參數和關鍵字參數來調用。非同步執行僅在 TorchScript 中運行時才會發生。如果在純 Python 中運行,fork 不會並行執行。fork 在追蹤 (tracing) 時調用也不會並行執行,但是 fork 和 wait 的調用會被捕獲到導出的 IR Graph 中。
警告
fork 任務將以非確定性的方式執行。我們建議僅為不修改其輸入、模組屬性或全域狀態的純函數產生並行 fork 任務。
- 參數
func (callable 或 torch.nn.Module) – 將被調用的 Python 函數或 torch.nn.Module。 如果在 TorchScript 中執行,它將非同步執行,否則不會。 fork 的追蹤調用將被捕獲在 IR 中。
*args – 用於調用 func 的參數。
**kwargs – 用於調用 func 的參數。
- 返回值
對 func 執行的引用。 只有通過 torch.jit.wait 強制完成 func 後才能存取值 T。
- 返回類型
torch.jit.Future[T]
範例 (fork 一個自由函數)
import torch from torch import Tensor def foo(a : Tensor, b : int) -> Tensor: return a + b def bar(a): fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2) return torch.jit.wait(fut) script_bar = torch.jit.script(bar) input = torch.tensor(2) # only the scripted version executes asynchronously assert script_bar(input) == bar(input) # trace is not run asynchronously, but fork is captured in IR graph = torch.jit.trace(bar, (input,)).graph assert "fork" in str(graph)
範例 (fork 一個模組方法)
import torch from torch import Tensor class AddMod(torch.nn.Module): def forward(self, a: Tensor, b : int): return a + b class Mod(torch.nn.Module): def __init__(self) -> None: super(self).__init__() self.mod = AddMod() def forward(self, input): fut = torch.jit.fork(self.mod, a, b=2) return torch.jit.wait(fut) input = torch.tensor(2) mod = Mod() assert mod(input) == torch.jit.script(mod).forward(input)