torch.cuda.jiterator._create_jit_fn¶
- torch.cuda.jiterator._create_jit_fn(code_string, **kwargs)[來源][來源]¶
為元素級運算建立 jiterator 產生的 CUDA 核心。
程式碼字串必須是有效的 CUDA 函式,用於描述單一元素的計算。程式碼字串必須遵循 C++ 範本模式,如下範例所示。此函式將內嵌到元素級核心範本中,並即時編譯。編譯後的核心將快取在記憶體中,以及本機暫存目錄中。
Jiterator 產生的核心接受非連續張量,並支援廣播和類型提升。
- 參數
code_string (str) – 要由 jiterator 編譯的 CUDA 程式碼字串。輸入函式物件必須依值傳回。
kwargs (Dict, optional) – 產生函式的關鍵字引數
- 傳回類型
範例
code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }" jitted_fn = create_jit_fn(code_string, alpha=1.0) a = torch.rand(3, device='cuda') b = torch.rand(3, device='cuda') # invoke jitted function like a regular python function result = jitted_fn(a, b, alpha=3.14)
code_string 也允許多個函數定義,並且最後一個函數會被視為入口函數。
範例
code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }" code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }" jitted_fn = create_jit_fn(code_string, val=0.0) a = torch.rand(3, device='cuda') b = torch.rand(3, device='cuda') # invoke jitted function like a regular python function result = jitted_fn(a, b) # using default val=0.0
Jiterator 可以與 Python 註冊一起使用,以覆寫運算子的 CUDA 核心。以下範例是使用 ReLU 覆寫 Gelu 的 CUDA 核心。
範例
code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }" my_gelu = create_jit_fn(code_string) my_lib = torch.library.Library("aten", "IMPL") my_lib.impl('aten::gelu', my_gelu, "CUDA") # torch.nn.GELU and torch.nn.function.gelu are now overridden a = torch.rand(3, device='cuda') torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))
警告
此 API 處於 Beta 階段,未來版本可能會變更。
警告
此 API 僅支援最多 8 個輸入和 1 個輸出
警告
所有輸入張量必須位於 CUDA 裝置中