torch.cuda.jiterator._create_multi_output_jit_fn¶
- torch.cuda.jiterator._create_multi_output_jit_fn(code_string, num_outputs, **kwargs)[source][source]¶
為 elementwise op 建立一個 jiterator 產生的 cuda 核心,該核心支援返回一個或多個輸出。
- 參數
- 回傳類型
範例
code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }" jitted_fn = create_jit_fn(code_string, alpha=1.0) a = torch.rand(3, device='cuda') b = torch.rand(3, device='cuda') # invoke jitted function like a regular python function result = jitted_fn(a, b, alpha=3.14)
警告
此 API 處於 Beta 階段,未來版本可能會變更。
警告
此 API 僅支援最多 8 個輸入和 8 個輸出