快捷方式

擴展 PyTorch

在本筆記中,我們將涵蓋擴展 torch.nntorch.autogradtorch 的方法,以及撰寫自訂 C++ 擴展。

新增運算子

PyTorch 提供了一個龐大的運算子庫,這些運算子可以處理 Tensor(例如 torch.add()torch.sum() 等)。但是,您可能希望將新的自訂運算引入 PyTorch,並使其行為與 PyTorch 的內建運算子類似。為此,您必須透過 Python torch.library 或 C++ TORCH_LIBRARY API 向 PyTorch 註冊自訂運算。

請參閱 PyTorch 自訂運算子登陸頁面 以取得更多詳細資訊。

擴展 torch.autograd

autograd 新增運算需要為每個運算實作一個新的 Function 子類別。回想一下,Function 是 autograd 用於編碼運算歷史記錄和計算梯度的東西。

本文檔的第一部分重點介紹反向模式 AD,因為它是最廣泛使用的功能。最後一部分討論了前向模式 AD 的擴展。

何時使用

通常,如果您想在模型中執行不可微分或依賴非 PyTorch 函式庫(例如 NumPy)的計算,但仍然希望您的運算與其他運算鏈接並與 autograd 引擎一起工作,請實作一個自訂函數。

在某些情況下,自訂函數也可以用於提高效能和記憶體使用量:如果您使用 C++ 擴展 實作了前向和後向傳遞,您可以將它們包裝在 Function 中以與 autograd 引擎連接。如果您想減少為後向傳遞保存的緩衝區數量,可以使用自訂函數將運算組合在一起。

何時不使用

如果您已經可以使用 PyTorch 的內建運算來編寫函數,那麼 autograd(很可能)已經能夠記錄其後向圖。在這種情況下,您不需要自己實作後向函數。考慮使用普通的 Python 函數。

如果您需要維護狀態,即,可訓練的參數,您應該(也)使用自訂模組。請參閱下面的章節以取得有關擴展 torch.nn 的更多資訊。

如果您想在後向傳遞期間更改梯度或執行副作用,請考慮註冊一個 tensorModule 鉤子。

如何使用

請按照以下步驟操作: 1. 繼承 Function 並實作 forward()、 (可選)setup_context()backward() 方法。 2. 在 ctx 引數上呼叫正確的方法。 3. 宣告您的函數是否支援 double backward。 4. 使用 gradcheck 驗證您的梯度是否正確。

步驟 1: 在繼承 Function 之後,您需要定義 3 個方法

  • forward() 是執行運算的程式碼。它可以接受任意數量的引數,如果您指定預設值,其中一些引數可以是可選的。這裡接受各種 Python 物件。追蹤歷史記錄(即,具有 requires_grad=True)的 Tensor 引數將在呼叫之前轉換為不追蹤歷史記錄的引數,並且它們的使用將在圖表中註冊。請注意,此邏輯不會遍歷列表/字典/任何其他資料結構,並且只會考慮作為呼叫直接引數的張量。您可以傳回單個 Tensor 輸出,或者如果有多個輸出,則傳回 tuple 的張量。此外,請參閱 Function 的文件,以尋找只能從 forward() 呼叫的有用方法的描述。

  • setup_context() (選用)。您可以撰寫一個「組合式」的 forward(),它接受一個 ctx 物件,或者(從 PyTorch 2.0 開始)一個獨立的 forward(),它不接受 ctx 和一個 setup_context() 方法,其中 ctx 的修改發生在這裡。forward() 應該負責計算,而 setup_context() 應該只負責修改 ctx(而不進行任何計算)。一般來說,分離的 forward()setup_context() 更接近 PyTorch 原生運算的操作方式,因此更容易與各種 PyTorch 子系統組合。有關更多詳細資訊,請參閱 組合或分離的 forward() 和 setup_context()

  • backward() (或 vjp()) 定義梯度公式。它會收到與輸出數量相同的 Tensor 引數,每個引數都代表相對於該輸出的梯度。絕對不要就地修改這些引數。它應該回傳與輸入數量相同的 tensors,每個 tensor 包含相對於其對應輸入的梯度。如果您的輸入不需要梯度(needs_input_grad 是一個布林值的元組,指示每個輸入是否需要梯度計算),或者不是 Tensor 物件,您可以回傳 python:None。此外,如果您有 forward() 的可選引數,您可以回傳比輸入更多的梯度,只要它們都是 None 即可。

步驟 2:您有責任正確使用 ctx 中的函數,以確保新的 Function 能夠與 autograd 引擎正常運作。

  • save_for_backward() 必須用於儲存任何要在反向傳播中使用的 tensors。非 tensors 應該直接儲存在 ctx 上。如果既不是輸入也不是輸出的 tensors 被儲存用於反向傳播,您的 Function 可能不支援雙重反向傳播(請參閱步驟 3)。

  • mark_dirty() 必須用於標記任何被前向函數就地修改的輸入。

  • mark_non_differentiable() 必須用於告知引擎,如果輸出不可微分。預設情況下,所有可微分類型的輸出 tensors 都會被設定為需要梯度。不可微分類型(即,整數類型)的 tensors 永遠不會被標記為需要梯度。

  • set_materialize_grads() 可以用於告知 autograd 引擎,在輸出不依賴於輸入的情況下,通過不實例化給 backward 函數的 grad tensors 來優化梯度計算。也就是說,如果設定為 False,則在 Python 中的 None 物件或在 C++ 中的「未定義 tensor」(tensor x,其中 x.defined() 為 False)在調用 backward 之前不會被轉換為填充零的 tensor,因此您的程式碼需要像處理填充零的 tensor 一樣處理這些物件。此設定的預設值為 True。

步驟 3:如果您的 Function 不支援雙重反向傳播,您應該通過使用 once_differentiable() 修飾 backward 來顯式聲明這一點。使用此修飾符,嘗試通過您的函數執行雙重反向傳播將產生錯誤。有關雙重反向傳播的更多資訊,請參閱我們的雙重反向傳播教程。

步驟 4:建議您使用 torch.autograd.gradcheck() 檢查您的 backward 函數是否通過使用您的 backward 函數計算 Jacobian 矩陣,並將該值與使用有限差分法以數值方式計算出的 Jacobian 進行逐元素比較,從而正確計算前向的梯度。

範例

您可以在下面找到 Linear 函式的程式碼,以及額外的註解

# Inherit from Function
class LinearFunction(Function):

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

現在,為了更容易使用這些自訂運算 (custom ops),我們建議您為它們建立別名 (aliasing) 或是將它們包裝在函式中。將它們包裝在函式中可以讓我們支援預設參數 (default arguments) 和關鍵字參數 (keyword arguments)。

# Option 1: alias
linear = LinearFunction.apply

# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias)

在這裡,我們提供一個額外的範例,說明如何建立一個由非 Tensor 參數化的函式。

class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        # ctx is a context object that can be used to stash information
        # for backward computation
        tensor, constant = inputs
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None

在這裡,我們透過呼叫 set_materialize_grads(False) 來最佳化上面的範例。

class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        tensor, constant = inputs
        ctx.set_materialize_grads(False)
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # Here we must handle None grad_output tensor. In this case we
        # can skip unnecessary computations and just return None.
        if grad_output is None:
            return None, None

        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None

如果您需要在 forward() 中計算的任何「中間」Tensor 被儲存,它們必須作為輸出返回,或者結合 forwardsetup_context() (請參閱 組合或分離 forward() 和 setup_context())。請注意,這表示如果您希望梯度流經這些中間值,則需要為它們定義梯度公式(另請參閱 double backward 教學

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # We wish to save dx for backward. In order to do so, it must
        # be returned as an output.
        dx = 3 * x ** 2
        result = x ** 3
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`,
        # which is grad_dx * 6 * x.
        result = grad_output * dx + grad_dx * 6 * x
        return result

# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

注意

backward 的輸入,例如 grad_output,也可以是追蹤歷史記錄的張量。 因此,如果 backward 是使用可微分運算(例如,調用另一個自訂 Function)來實作,則可以使用更高階的導數。 在這種情況下,使用 save_for_backward 儲存的張量也可以在 backward 中使用,並且具有回流的梯度,但是儲存在 ctx 中的張量不會有回流的梯度。 如果您需要儲存在 ctx 中的 Tensor 的梯度回流,則應將其設為自訂 Function 的輸出,並使用 save_for_backward 儲存它。

您可能需要檢查您實作的 backward 方法是否實際計算出您函式的導數。 可以透過使用小的有限差分 (finite differences) 與數值近似值比較來實現。

from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)

有關有限差分梯度比較的更多詳細資訊,請參閱 數值梯度檢查。 如果您的函式用於更高階的導數(微分 backward 傳遞),則可以使用同一個套件中的 gradgradcheck 函式來檢查更高階的導數。

組合或分離 forward()setup_context()

定義 Function 主要有兩種方式。 可以

  • 定義一個 forward(),它將前向計算邏輯 (forward compute logic) 與 setup_context() 結合起來。

  • (從 PyTorch 2.0 開始)定義一個單獨的 forward()setup_context()

我們建議使用第二種選項(單獨的 forward()setup_context()),因為它更接近 PyTorch 原生運算的實作方式,並且與 torch.func 轉換相容。 但是,我們計劃繼續支援這兩種方法;將 forward()setup_context() 結合:可以帶來更大的靈活性,因為您可以儲存中間值,而無需將它們作為輸出返回。

有關如何使用單獨的 forward()setup_context() 定義 Function,請參閱上一節。

這是一個範例,說明如何使用組合的 forward()setup_context() 定義 Function

class LinearFunction(Function):
    @staticmethod
    # ctx is the first argument to forward
    def forward(ctx, input, weight, bias=None):
        # The forward pass can use ctx.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

前向模式 AD

覆寫前向模式 AD 公式具有非常相似的 API,但有一些不同的細微之處。您可以實作 jvp() 函式。

它將會獲得與輸入相同數量的 Tensor 參數,每個參數代表相對於該輸入的梯度。 它應該傳回與輸出相同數量的張量,每個張量包含相對於其相應輸出的梯度。 jvp() 將在 forward() 方法之後,在 apply() 傳回之前立即被呼叫。

jvp()backward() 函式有一些細微的差異

  • 你可以使用 ctx 將任何資料從 forward() 傳遞到 jvp() 函式。如果 backward() 不需要該狀態,你可以在 jvp() 函式的末尾執行 del ctx.foo 來顯式釋放它。

  • jvp() 的實作必須是可反向微分的,或明確檢查是否沒有給定的前向模式梯度設置了 requires_grad

  • jvp() 函式必須符合 forward() 的視圖/原地 (view/inplace) 行為。例如,如果第 i 個輸入被原地修改,則第 i 個梯度必須被原地更新。同樣地,如果第 j 個輸出是第 k 個輸入的視圖,那麼返回的第 j 個輸出梯度必須是給定的第 k 個輸入的梯度視圖。

  • 由於使用者無法指定需要計算哪個梯度,因此 jvp() 函式應始終計算所有輸出的梯度。

  • 前向模式梯度確實會遵守由 set_materialize_grads() 設置的標誌,並且當此功能被禁用時,你可以獲得 None 輸入梯度。

torch.func 轉換和/或 torch.vmap()

有關詳細信息,請參閱 使用 autograd.Function 擴展 torch.func

擴展 torch.nn

nn 導出兩種接口 - 模組及其函數式版本。你可以通過兩種方式擴展它,但我們建議對所有種類的層使用模組,這些層包含任何參數或緩衝區,並且建議對像激活函數、池化等無參數操作使用函數式形式。

在上面的章節中,已經完全涵蓋了添加操作的函數式版本。

新增 Module

由於 nn 大量利用 autograd,因此新增一個 Module 需要實作一個執行操作並能計算梯度的 Function。從現在開始,我們假設我們要實現一個 Linear 模組,並且我們已經按照上面的列表實現了該函式。新增這個模組所需的程式碼非常少。現在,需要實作兩個函式

  • __init__ (可選) - 接收諸如 Kernel size、特徵數量等的參數,並初始化參數和緩衝區。

  • forward() - 實例化一個 Function 並使用它來執行操作。它非常類似於上面顯示的函數式包裝器。

這是一個 Linear 模組的實作方式

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter is a special kind of Tensor, that will get
        # automatically registered as Module's parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional)Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

擴展 torch Python API

你可以通過定義一個具有與 Tensor 相符的方法的自定義類別來建立模擬 Tensor 的自定義類型。但是,如果你希望能夠將這些類型傳遞給頂級 torch 命名空間中接受 Tensor 運算元的 torch.add() 等函式,該怎麼辦?

如果您的自訂 Python 類型定義了一個名為 __torch_function__ 的方法,當您的自訂類別的實例傳遞到 torch 命名空間中的函數時,PyTorch 將會調用您的 __torch_function__ 實作。這使得您可以為 torch 命名空間中的任何函數定義自訂實作,您的 __torch_function__ 實作可以呼叫這些函數,從而允許您的使用者將您的自訂類型與他們已經為 Tensor 撰寫的現有 PyTorch 工作流程一起使用。這適用於與 Tensor 無關的「duck」類型,以及使用者定義的 Tensor 子類別。

使用類似 Tensor 的類型擴展 torch

注意

此功能靈感來自 NumPy __array_function__ 協定。 有關更多詳細資訊,請參閱NumPy 文件NEP-0018

為了更具體,讓我們先從一個簡單的例子開始,來說明 API 調度機制。我們將建立一個自訂類型,表示一個 2D 純量張量,由階數 N 和沿對角線條目的值 value 參數化

class ScalarTensor(object):
   def __init__(self, N, value):
       self._N = N
       self._value = value

   def __repr__(self):
       return "ScalarTensor(N={}, value={})".format(self._N, self._value)

   def tensor(self):
       return self._value * torch.eye(self._N)

這個設計的第一個迭代並不是很實用。 ScalarTensor 的主要功能是提供比基本張量類別更緊湊的純量張量字串表示

>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 0., 0.],
        [0., 0., 0., 2., 0.],
        [0., 0., 0., 0., 2.]])

如果我們嘗試將此物件與 torch API 一起使用,我們會遇到問題

>>> import torch
>>> torch.mean(d)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor

ScalarTensor 添加 __torch_function__ 實作,使得上述操作可以成功。 讓我們重新進行我們的實作,這次添加 __torch_function__ 實作

HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
    def __init__(self, N, value):
        self._N = N
        self._value = value

    def __repr__(self):
        return "ScalarTensor(N={}, value={})".format(self._N, self._value)

    def tensor(self):
        return self._value * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

__torch_function__ 方法採用四個參數:func,對正在覆蓋的 torch API 函數的引用;types,實作 __torch_function__ 的類 Tensor 類型列表;args,傳遞給函數的參數元組;以及 kwargs,傳遞給函數的關鍵字參數字典。它使用一個名為 HANDLED_FUNCTIONS 的全域調度表來儲存自訂實作。此字典的鍵是 torch 命名空間中的函數,值是 ScalarTensor 的實作。

注意

使用全域調度表不是 __torch_function__ API 的強制部分,它只是一個用於結構化覆蓋實作的有用設計模式。

這個類別定義還不足以讓 torch.mean 在我們傳遞給它 ScalarTensor 時做正確的事情——我們還需要為 ScalarTensor 運算元定義 torch.mean 的實作,並將該實作添加到 HANDLED_FUNCTIONS 調度表字典中。 一種方法是定義一個裝飾器

import functools
def implements(torch_function):
    """Register a torch function override for ScalarTensor"""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

可以應用於我們覆蓋的實作

@implements(torch.mean)
def mean(input):
    return float(input._value) / input._N

透過此更改,我們現在可以使用 torch.meanScalarTensor

>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4

當然,torch.mean 是一個最簡單的覆蓋函數的範例,因為它只採用一個運算元。 我們可以使用相同的機制來覆蓋採用多個運算元的函數,其中任何一個運算元都可能是定義 __torch_function__ 的張量或類張量,例如對於 torch.add()

def ensure_tensor(data):
    if isinstance(data, ScalarTensor):
        return data.tensor()
    return torch.as_tensor(data)

@implements(torch.add)
def add(input, other):
   try:
       if input._N == other._N:
           return ScalarTensor(input._N, input._value + other._value)
       else:
           raise ValueError("Shape mismatch!")
   except AttributeError:
       return torch.add(ensure_tensor(input), ensure_tensor(other))

此版本有一個快速路徑,適用於當兩個運算元都是 ScalarTensor 實例時,還有一個較慢的路徑,當任一運算元不是 ScalarTensor 時,會降級為將資料轉換為張量。 這使得覆蓋函數在任一運算元是 ScalarTensor 或常規 Tensor 時都能正確執行

>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
        [1., 3.]])

請注意,我們的 add 實作沒有像 torch.add() 那樣採用 alphaout 作為關鍵字參數

>>> torch.add(s, s, alpha=2)
TypeError: add() got an unexpected keyword argument 'alpha'

為了速度和靈活性,__torch_function__ 調度機制不會檢查覆蓋函數的簽章是否與 torch API 中要覆蓋的函數的簽章相符。 對於某些應用程式來說,忽略可選參數可能沒問題,但為了確保與 Tensor 的完全相容性,torch API 函數的使用者實作應注意完全模擬要覆蓋的函數的 API。

torch API 中,如果函式沒有明確的覆寫 (override),則會從 __torch_function__ 回傳 NotImplemented。如果所有定義了 __torch_function__ 的運算元都回傳 NotImplemented,PyTorch 會拋出 TypeError。這表示大多數情況下,如果某個型別沒有明確的覆寫操作,當傳入該型別的實例時,會拋出 TypeError

>>> torch.mul(s, 3)
TypeError: no implementation found for 'torch.mul' on types that
implement __torch_function__: [ScalarTensor]

實際上,這表示如果您想要使用類似以下的 __torch_function__ 實作來實作您的覆寫,您需要明確地實作完整的 torch API,或者您關心的 API 的整個子集。這可能是一項艱鉅的任務,因為完整的 torch API 相當廣泛。

另一種選擇是,對於未處理的操作,不要回傳 NotImplemented,而是當沒有覆寫可用時,將 Tensor 傳遞給原始的 torch 函式。例如,如果我們將 ScalarTensor__torch_function__ 實作更改為以下內容:

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
        args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
        return func(*args, **kwargs)
    return HANDLED_FUNCTIONS[func](*args, **kwargs)

那麼 torch.mul() 將正確運作,儘管回傳型別始終是 Tensor,而不是 ScalarTensor,即使兩個運算元都是 ScalarTensor 實例。

>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
        [0., 4.]])

另請參閱下面的 MetadataTensor 範例,了解此模式的另一種變體,但始終回傳 MetadataTensor,以便在 torch API 的操作中傳播元數據 (metadata)。

__torch_function__ 協定旨在完全覆蓋 API,部分覆蓋可能會導致不良結果,特別是,某些函式會拋出 TypeError。對於子類別來說尤其如此,即使它們回傳完全相同的結果,也必須覆蓋 torch.addtorch.Tensor.__add__torch.Tensor.add 這三個函式。未能做到這一點也可能導致無限遞迴。如果需要實作來自 torch.Tensor 子類別的函式,則必須在其內部實作中使用 super().__torch_function__

子類別化 torch.Tensor

從 1.7.0 版開始,應用於 torch.Tensor 子類別的 torch.Tensor 方法和公共 torch.* 命名空間中的函式,將回傳子類別實例,而不是 torch.Tensor 實例。

>>> class SubTensor(torch.Tensor):
...     pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'

如果存在多個子類別,則預設會選擇層次結構中最低的那個。如果沒有唯一的方法可以確定這種情況,則會引發 TypeError

>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]

如果希望全域覆寫所有張量 (tensor) 方法,可以使用 __torch_function__。這是一個記錄所有函式/方法調用的範例:

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
        if func is not torch.Tensor.__repr__:
            logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

但是,如果希望覆寫 Tensor 子類別上的方法,可以通過直接覆寫該方法(通過為子類別定義它),或者使用 __torch_function__ 並與 func 匹配來實現。

在子類別的 __torch_function__ 中應該小心,始終調用 super().__torch_function__(func, ...),而不是像 1.7.0 版之前那樣直接調用 func。否則可能導致 func 遞迴回到 __torch_function__,從而導致無限遞迴。

使用 Tensor 包裝器型別擴展 torch

另一個有用的案例是一種包裝 Tensor 的型別,可以作為屬性或通過子類別化。下面我們實作了這種型別的一個特例,一個 MetadataTensor,它將元數據字典附加到 Tensor,該元數據通過 torch 操作傳播。由於這是對整個 torch API 的通用包裝,因此我們不需要單獨實作每個覆寫,因此我們可以使 __torch_function__ 實作對允許的操作更加寬容。

class MetadataTensor(object):
    def __init__(self, data, metadata=None, **kwargs):
        self._t = torch.as_tensor(data, **kwargs)
        self._metadata = metadata

    def __repr__(self):
        return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
        args = [getattr(a, '_t', a) for a in args]
        assert len(metadatas) > 0
        ret = func(*args, **kwargs)
        return MetadataTensor(ret, metadata=metadatas[0])

這個簡單的實作不一定適用於 torch API 中的每個函式,但它足以捕獲大多數常見操作。

>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}

data:
tensor([[2, 4],
        [4, 6]])
>>> torch.mul(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}

data:
tensor([[1, 4],
        [3, 8]])

對定義了 __torch_function__ 的多個型別進行操作

可以使用 torch API 處理多個不同的型別,每個型別都有一個 __torch_function__ 實作,但必須特別小心。在這種情況下,規則是:

  • 分派 (dispatch) 操作會收集每個運算元的所有不同的 __torch_function__ 實作,並按順序調用它們:子類別優先於父類別,否則按照運算式中的從左到右的順序。

  • 如果回傳的值不是 NotImplemented,則該值將作為結果回傳。實作可以通過回傳 NotImplemented 來註冊它們未實作某個操作。

  • 如果所有 __torch_function__ 實作都回傳 NotImplemented,則 PyTorch 會拋出 TypeError

測試 PyTorch API 的覆寫範圍

實作 __torch_function__ 的一個麻煩之處在於,如果某些操作有覆寫 (override),而其他操作沒有,使用者可能會遇到不一致的體驗,最糟的情況是在運行時使用沒有覆寫的函數時會引發錯誤。為了簡化這個過程,PyTorch 提供了一個面向開發者的 API,以確保對 __torch_function__ 覆寫的完整支援。這個 API 是私有的,未來可能會在沒有警告的情況下進行變更。

首先,要取得所有可覆寫函數的清單,請使用 torch.overrides._get_overridable_functions。 這會返回一個字典,其鍵是 PyTorch Python API 中的命名空間,其值是該命名空間中可以覆寫的函數清單。 例如,讓我們印出 torch.nn.functional 中前 5 個可以覆寫的函數名稱。

>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
 'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']

這個函數清單讓您可以迭代所有可覆寫的函數,但是在實務上,這還不足以為所有這些函數編寫測試,而無需費力地手動複製每個函數的簽名以進行每個測試。 為了簡化這個過程,torch.overrides._get_testing_overrides 函數會返回一個字典,該字典將 PyTorch API 中的可覆寫函數映射到虛擬的 lambda 函數,這些函數具有與原始函數相同的簽名,但無條件地返回 -1。 這些函數最適合與 inspect 一起使用,以分析原始 PyTorch 函數的函數簽名。

>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>

最後,torch.overrides.get_ignored_functions 會返回一個函數元組,這些函數明確地不能被 __torch_function__ 覆寫。 此清單可用於確認 get_overridable_functions 返回的字典中不存在的函數是否無法被覆寫。

擴展 torch 原生 API

雖然 __torch_function__ 允許有效擴展 PyTorch 純 Python 組件的行為,但它不允許擴展以 C++ 實現的 PyTorch 部分。 為此,Tensor 子類別也可以定義 __torch_dispatch__,它將能夠在 C++ 層級覆寫行為。

要有效使用此功能,重要的是要知道 PyTorch 原生部分的實作方式。 其中最重要的組件是我們稱之為「dispatcher」的東西(最好的描述可以在這篇 部落格文章 中找到,即使它稍微過時了)。 顧名思義,它負責為特定函數呼叫呼叫正確的後端函數。 例如,當呼叫 torch.add(a, b) 時,dispatcher 將檢查這兩個參數,找出哪個「功能」(autograd、autocast、functionalization 等)和哪個「後端」(CPU、CUDA、MPS 等)應該用於此特定呼叫,最後呼叫所有正確的 kernel。 kernel 所做的一個非常常見的事情是「redispatch」。 例如,當在 GPU 上使用 autocast 運行您的神經網路時,第一次呼叫將是 autocast kernel,它將處理任何潛在的 autocast 邏輯並向下 redispatch。 下一個功能將是 autograd,它將正確地建立 autograd 圖形,然後向下 redispatch。 最後,我們到達 CUDA 的後端 kernel,它將啟動正確的 CUDA kernel 並返回最終結果。 在退出的過程中,autograd 會將圖形附加到輸出,最後,autocast 將有機會在退出時進行任何更新。

dispatcher 的一個配置是呼叫所有這些功能和後端鍵的順序。最新的清單及其順序可以在 DispatchKey.h 裡面的 DispatchKey 列舉中找到。為了擴展 torch 的目的,本次討論中重要的訂單子集是

vmap -> Autocast -> Autograd -> ZeroTensor -> Neg/Conj -> Functionalize -> Python -> Backends

對於本次討論而言,最重要的鍵是 Python,因為每個定義了 __torch_dispatch__ 方法的 Tensor 子類別都會呼叫此功能。 正是從那裡呼叫了使用者定義的方法,並且可以在那裡任意覆寫行為。 從那裡,再次呼叫提供的 func 將執行「redispatch」。

此實作的一些重要含義是

  • 此程式碼在「所有功能之下」運行。 因此,它僅像常規後端一樣,負責產生每個 Tensor 的輸出值(並且可以而且應該忽略所有高級功能,如 autograd、autocast 等)。

  • 如果任何高階功能在沒有 redispatch 的情況下實作了給定的函數,它將永遠不會到達 Python 鍵,因此永遠不會觸發 __torch_dispatch__ 回呼。 這尤其發生在 CompositeImplicitAutograd 函數中,這些函數在 Autograd 層級評估而不進行 redispatch。 這是因為 CompositeImplicitAutograd 函數通過隱式呼叫其他原生操作來指定其 autograd 公式,因此在 Autograd 層級,該函數被分解為其原生操作,並評估這些操作。

  • 當回呼到 Python 並包裝結果時,使用與常規 PyTorch Python/C++ 綁定相同的轉換。 特別是,某些物件無法在 Python 中表示,需要特殊處理(例如,未定義的 Tensor 變成 None)。

  • 我們的原生函數會被延遲填充為 torch.ops.{namespace}.{func_name}.{overload_name} 作為可呼叫的 Python 物件,以便於從 Python 與它們互動。 給 __torch_dispatch__func 物件始終是來自此命名空間的條目。 此命名空間可用於直接呼叫原生操作並繞過常用的 Python API 和綁定程式碼。

如同 __torch_function__ 能夠介入所有 torch 的 Python API 和 Tensor 方法,__torch_dispatch__ 也能攔截所有進入 aten native API 的呼叫。請注意,Tensor 上的所有方法在進入 dispatcher 之前都會轉換為函式呼叫,因此在這裡會顯示為函式呼叫:torch.add(a, 2)a + 2 將會導致完全相同的 aten 呼叫。這些函式大多數定義在 native_functions.yaml 中,其中指定了這些函式的屬性以及它們的後端實作。它們的實作以及指定的特性會透過程式碼自動產生來註冊。一些更特殊的函式或特性也會在 C++ 程式碼庫的其他地方或使用者定義的 C++ 擴充中註冊。

也可以使用 torch.library 來新增 新的 native 函式。這個 Python 功能允許定義和/或新增新的實作到 native 函式。這可以用於新增缺少的 kernel、替換現有的 kernel 或定義全新的 native 函式。

您可以在 subclass zoo repo 中找到許多基於 __torch_dispatch__ 的子類別範例。

__torch_dispatch__ 呼叫慣例

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    pass

當使用者呼叫一個帶有具有 __torch_dispatch__ 的輸入的運算符時,該呼叫可能會轉發到 __torch_dispatch__。 args 和 kwargs 在呼叫 __torch_dispatch__ 之前會進行正規化,也就是說

  • kwargs 包含運算符架構中的僅限關鍵字的參數。如果一個 kwarg 等於它的預設值(在架構中),它將不會被傳遞。

  • args 包含所有其他參數,無論它們是如何傳遞給運算符的(位置 vs 關鍵字)。如果一個 arg 等於它的預設值,並且它是最右邊的位置參數,或者它右邊的所有參數都沒有被傳遞,它將不會被傳遞。

使用模式 (Modes) 擴充所有 torch API

不幸的是,有些函式不接受 Tensor 輸入。這意味著上述的子類別方法不能用於覆蓋所有 PyTorch 函式的行為。此外,如果使用案例需要攔截每個函式呼叫,則將每個 Tensor 更改為子類別可能會過於侵入性。

為了處理這種使用案例,我們引入了“模式 (Mode)”的概念。這些模式存在於 __torch_function____torch_dispatch__ 覆蓋中,分別透過子類別化 torch.overrides.TorchFunctionModetorch.utils._python_dispatch.TorchDispatchMode 來建立,並且用作上下文管理器。

為了簡化它與子類別和其他模式互動的描述,每當進入模式的上下文管理器時,每個函式的行為都好像參數列表的開頭有一個額外的 Tensor 參數,並且該模式作為子類別。 這尤其意味著所有模式處理器將在任何子類別處理器之前被呼叫,並且對應於內部上下文管理器的模式將始終首先執行。

同樣重要的是要注意,在給定的模式處理器中,此特定模式已停用,並且可以透過執行 with self: 手動重新啟用。

這是一個顯示每種類型記錄模式的範例

import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode

class FunctionLog(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))

class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))

def f():
    a = torch.rand(10, requires_grad=True)
    b = a * 2
    b.sum().backward()

print("TorchFunctionMode logging:")
with FunctionLog():
    f()

print("TorchDispatchMode logging:")
with DispatchLog():
    f()

它會印出以下內容,帶有額外的註解

TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
        0.5970], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
        1.1939], grad_fn=<MulBackward0>),), **None)
# Note that at the python level, we only see the call to backward but not what happens in the autograd engine.
Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})

TorchDispatchMode logging:
# Here the requires_grad flag from autograd is removed while default arguments were populated.
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
        0.6023], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
        1.2046], grad_fn=<MulBackward0>),), **{})
# Here we don't see the call to backward itself, but its constituents. Starting here with the factory function that creates the initial gradient.
Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
# This is the backward of the sum
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發人員的深入教學課程

檢視教學

資源

尋找開發資源並取得您的問題解答

檢視資源