快捷鍵

torch.cuda.jiterator._create_multi_output_jit_fn

torch.cuda.jiterator._create_multi_output_jit_fn(code_string, num_outputs, **kwargs)[source][source]

為 elementwise op 建立一個 jiterator 產生的 cuda 核心,該核心支援返回一個或多個輸出。

參數
  • code_string (str) – 由 jiterator 編譯的 CUDA 程式碼字串。 進入函式物件必須通過引用傳回值。

  • num_outputs (int) – 核心返回的輸出數量

  • kwargs (Dict, optional) – 生成函數的關鍵字參數

回傳類型

可呼叫物件 (Callable)

範例

code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# invoke jitted function like a regular python function
result = jitted_fn(a, b, alpha=3.14)

警告

此 API 處於 Beta 階段,未來版本可能會變更。

警告

此 API 僅支援最多 8 個輸入和 8 個輸出

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得您問題的解答

查看資源