捷徑

自訂 Python 運算子

建立於:2024 年 6 月 18 日 | 最後更新:2025 年 1 月 6 日 | 最後驗證:2024 年 11 月 5 日

您將學到
  • 如何將以 Python 撰寫的自訂運算子與 PyTorch 整合

  • 如何使用 torch.library.opcheck 測試自訂運算子

先決條件
  • PyTorch 2.4 或更高版本

PyTorch 提供了一個龐大的運算子函式庫,可用於 Tensor (例如 torch.addtorch.sum 等)。 但是,您可能希望在 PyTorch 中使用新的自訂運算子,該運算子可能是由第三方函式庫撰寫的。 本教學課程展示如何包裝 Python 函式,使其行為類似於 PyTorch 原生運算子。 您可能希望在 PyTorch 中建立自訂運算子的原因包括

  • 將任意 Python 函式視為相對於 torch.compile 的不透明可呼叫對象 (也就是,防止 torch.compile 追蹤到該函式中)。

  • 為任意 Python 函式新增訓練支援

使用 torch.library.custom_op() 建立 Python 自訂運算子。 使用 C++ TORCH_LIBRARY API 建立 C++ 自訂運算子 (這些運算子在無 Python 環境中有效)。 請參閱自訂運算子著陸頁以取得更多詳細資訊。

請注意,如果您的運算可以表示為現有 PyTorch 運算子的組合,那麼通常不需要使用自訂運算子 API – 所有內容 (例如 torch.compile、訓練支援) 應該都能正常運作。

範例:將 PIL 的 crop 包裝到自訂運算子中

假設我們正在使用 PIL 的 crop 運算。

import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt

def crop(pic, box):
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return pil_to_tensor(cropped_img).to(pic.device) / 255.

def display(img):
    plt.imshow(img.numpy().transpose((1, 2, 0)))

img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)
python custom ops
cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)
python custom ops

crop 無法有效地被 torch.compile 開箱即用處理:torch.compile 會對它無法處理的函式產生 "graph break",而 graph break 對於效能不利。 以下程式碼透過引發錯誤來示範這一點 (torch.compile 搭配 fullgraph=True 如果發生 graph break,則會引發錯誤)。

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

# The following raises an error. Uncomment the line to see it.
# cropped_img = f(img)

為了將 crop 封裝以與 torch.compile 搭配使用,我們需要做兩件事

  1. 將函式包裝到 PyTorch 自訂運算子中。

  2. 為運算子新增一個 "FakeTensor 核心" (也稱為 "meta 核心")。 給定一些 FakeTensors 輸入 (沒有儲存體的虛擬 Tensor),此函式應傳回您選擇的具有正確 Tensor 元資料 (形狀/步幅/dtype/裝置) 的虛擬 Tensor。

from typing import Sequence

# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

# Use register_fake to add a ``FakeTensor`` kernel for the operator
@crop.register_fake
def _(pic, box):
    channels = pic.shape[0]
    x0, y0, x1, y1 = box
    return pic.new_empty(channels, y1 - y0, x1 - x0)

在此之後,crop 現在可以在沒有 graph break 的情況下運作

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

cropped_img = f(img)
display(img)
python custom ops
display(cropped_img)
python custom ops

為 crop 新增訓練支援

使用 torch.library.register_autograd 為運算子新增訓練支援。 優先選擇此方法,而非直接使用 torch.autograd.Functionautograd.Function 與 PyTorch 運算子註冊 API 的某些組合可能會導致 (並且已經導致) 與 torch.compile 組合時出現靜默的錯誤。

如果您不需要訓練支援,則無需使用 torch.library.register_autograd。 如果您最終使用沒有 autograd 註冊的 custom_op 進行訓練,我們將引發錯誤訊息。

crop 的梯度公式本質上是 PIL.paste (我們將推導過程留給讀者練習)。 讓我們首先將 paste 包裝到自訂運算子中

@torch.library.custom_op("mylib::paste", mutates_args=())
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    im1_pil = to_pil_image(im1.cpu())
    im2_pil = to_pil_image(im2.cpu())
    PIL.Image.Image.paste(im1_pil, im2_pil, coord)
    return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)

@paste.register_fake
def _(im1, im2, coord):
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    return torch.empty_like(im1)

現在讓我們使用 register_autograd 來指定 crop 的梯度公式

def backward(ctx, grad_output):
    grad_input = grad_output.new_zeros(ctx.pic_shape)
    grad_input = paste(grad_input, grad_output, ctx.coords)
    return grad_input, None

def setup_context(ctx, inputs, output):
    pic, box = inputs
    ctx.coords = box[:2]
    ctx.pic_shape = pic.shape

crop.register_autograd(backward, setup_context=setup_context)

請注意,backward 必須是 PyTorch 理解的運算子的組合,這就是我們將 paste 包裝到自訂運算子中,而不是直接使用 PIL 的 paste 的原因。

img = img.requires_grad_()
result = crop(img, (10, 10, 50, 50))
result.sum().backward()
display(img.grad)
python custom ops

這是正確的梯度,裁剪區域為 1 (白色),未使用區域為 0 (黑色)。

測試 Python 自訂運算子

使用 torch.library.opcheck 測試自訂運算子是否已正確註冊。 這不會測試梯度在數學上是否正確; 請為此編寫單獨的測試 (手動測試或 torch.autograd.gradcheck)。

要使用 opcheck,請傳入一組範例輸入來進行測試。如果您的運算子支援訓練,則範例應包含需要梯度 (grad) 的 Tensor。如果您的運算子支援多個裝置,則範例應包含來自每個裝置的 Tensor。

examples = [
    [torch.randn(3, 64, 64), [0, 0, 10, 10]],
    [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
    [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
    [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
]

for example in examples:
    torch.library.opcheck(crop, example)

可變動的 Python 自定義運算子

您也可以將一個會變動其輸入的 Python 函數包裝成一個自定義運算子。變動輸入的函數很常見,因為許多底層核心 (kernel) 都是這樣編寫的;例如,一個計算 sin 的核心可能會接收輸入和輸出 Tensor,並將 input.sin() 寫入到輸出 Tensor 中。

我們將使用 numpy.sin 來示範一個可變動的 Python 自定義運算子的範例。

import numpy as np

@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.device == output.device
    assert input.device.type == "cpu"
    input_np = input.numpy()
    output_np = output.numpy()
    np.sin(input_np, out=output_np)

由於運算子沒有返回任何值,因此無需註冊 FakeTensor 核心(meta kernel)即可使其與 torch.compile 搭配使用。

@torch.compile(fullgraph=True)
def f(x):
    out = torch.empty(3)
    numpy_sin(x, out)
    return out

x = torch.randn(3)
y = f(x)
assert torch.allclose(y, x.sin())

這是 opcheck 的執行結果,告訴我們確實正確註冊了運算子。如果我們忘記將輸出添加到 mutates_argsopcheck 將會報錯。

example_inputs = [
    [torch.randn(3), torch.empty(3)],
    [torch.randn(0, 3), torch.empty(0, 3)],
    [torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
]

for example in example_inputs:
    torch.library.opcheck(numpy_sin, example)

結論

在本教程中,我們學習了如何使用 torch.library.custom_op 在 Python 中創建一個可以與 PyTorch 子系統(例如 torch.compile 和 autograd)搭配使用的自定義運算子。

本教程提供了自定義運算子的基本介紹。有關更詳細的信息,請參閱

腳本總執行時間:(0 分鐘 4.502 秒)

由 Sphinx-Gallery 生成的圖庫

文件

獲取 PyTorch 的全面開發者文件

查看文件

教學文件

獲取針對初學者和高級開發者的深入教程

查看教程

資源

查找開發資源並獲得您的問題解答

查看資源