快捷方式

如何撰寫您自己的 TVTensor 類別

注意

Colab 上試用,或前往結尾下載完整範例程式碼。

本指南適用於進階使用者和下游程式庫維護者。我們將說明如何撰寫您自己的 TVTensor 類別,以及如何使其與內建的 Torchvision v2 轉換相容。在繼續之前,請確保您已閱讀TVTensors 常見問題

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

我們將建立一個非常簡單的類別,僅繼承自基礎 TVTensor 類別。這足以涵蓋您需要了解的內容,以實作更精細的使用案例。如果您需要建立一個攜帶中繼資料的類別,請參考 BoundingBoxes 類別的實作方式

class MyTVTensor(tv_tensors.TVTensor):
    pass


my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])

現在我們已經定義了自訂的 TVTensor 類別,我們希望它與內建的 torchvision 轉換和功能性 API 相容。為此,我們需要實作一個執行轉換核心的 kernel,然後透過 register_kernel() 將其「hook」到我們想要支援的功能。

我們在下面說明此過程:我們為 MyTVTensor 類別的「水平翻轉」操作建立一個 kernel,並將其註冊到功能性 API。

from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)

若要了解為何使用 wrap(),請參閱 我有一個 TVTensor,但現在我有一個 Tensor。救命!。暫時忽略 *args, **kwargs,我們將在 參數轉發,以及確保您的 kernel 未來的相容性 中說明。

注意

在我們對 register_kernel 的呼叫中,我們使用字串 functional="hflip" 來參照我們要 hook 的功能。我們也可以使用功能本身,即 @register_kernel(functional=F.hflip, ...)

現在我們已經註冊了 kernel,我們可以對 MyTVTensor 實例呼叫功能性 API

my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!

我們也可以使用 RandomHorizontalFlip 轉換,因為它在內部依賴 hflip()

t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!

注意

我們無法為轉換類別註冊 kernel,我們只能為功能註冊 kernel。我們無法註冊轉換類別的原因是,一個轉換可能在內部依賴多個功能,因此一般來說,我們無法為給定類別註冊單一 kernel。

參數轉發,以及確保您的 kernel 未來的相容性

您要 hook 的功能性 API 是公開的,因此具有向後相容性:我們保證這些功能的參數不會在沒有適當的棄用週期下被移除或重新命名。但是,我們不保證向前相容性,我們可能會在未來新增參數。

假設在未來的版本中,Torchvision 為其 hflip() 功能新增了一個新的 inplace 參數。如果您已經定義並註冊了自己的 kernel,如下所示

def hflip_my_tv_tensor(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)

那麼呼叫 F.hflip(my_dp) 將會失敗,因為 hflip 會嘗試將新的 inplace 參數傳遞給您的 kernel,但您的 kernel 不接受它。

因此,我們建議始終在 kernel 的簽名中使用 *args, **kwargs 來定義您的 kernel,如上所示。這樣,您的 kernel 將能夠接受我們未來可能新增的任何新參數。(技術上來說,僅新增 **kwargs 就足夠了)。

腳本總執行時間: (0 分鐘 0.003 秒)

由 Sphinx-Gallery 產生圖庫

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源