如何撰寫您自己的 TVTensor 類別¶
本指南適用於進階使用者和下游程式庫維護者。我們將說明如何撰寫您自己的 TVTensor 類別,以及如何使其與內建的 Torchvision v2 轉換相容。在繼續之前,請確保您已閱讀TVTensors 常見問題。
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
我們將建立一個非常簡單的類別,僅繼承自基礎 TVTensor
類別。這足以涵蓋您需要了解的內容,以實作更精細的使用案例。如果您需要建立一個攜帶中繼資料的類別,請參考 BoundingBoxes
類別的實作方式。
class MyTVTensor(tv_tensors.TVTensor):
pass
my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])
現在我們已經定義了自訂的 TVTensor 類別,我們希望它與內建的 torchvision 轉換和功能性 API 相容。為此,我們需要實作一個執行轉換核心的 kernel,然後透過 register_kernel()
將其「hook」到我們想要支援的功能。
我們在下面說明此過程:我們為 MyTVTensor 類別的「水平翻轉」操作建立一個 kernel,並將其註冊到功能性 API。
from torchvision.transforms.v2 import functional as F
@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
若要了解為何使用 wrap()
,請參閱 我有一個 TVTensor,但現在我有一個 Tensor。救命!。暫時忽略 *args, **kwargs
,我們將在 參數轉發,以及確保您的 kernel 未來的相容性 中說明。
注意
在我們對 register_kernel
的呼叫中,我們使用字串 functional="hflip"
來參照我們要 hook 的功能。我們也可以使用功能本身,即 @register_kernel(functional=F.hflip, ...)
。
現在我們已經註冊了 kernel,我們可以對 MyTVTensor
實例呼叫功能性 API
my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!
我們也可以使用 RandomHorizontalFlip
轉換,因為它在內部依賴 hflip()
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!
注意
我們無法為轉換類別註冊 kernel,我們只能為功能註冊 kernel。我們無法註冊轉換類別的原因是,一個轉換可能在內部依賴多個功能,因此一般來說,我們無法為給定類別註冊單一 kernel。
參數轉發,以及確保您的 kernel 未來的相容性¶
您要 hook 的功能性 API 是公開的,因此具有向後相容性:我們保證這些功能的參數不會在沒有適當的棄用週期下被移除或重新命名。但是,我們不保證向前相容性,我們可能會在未來新增參數。
假設在未來的版本中,Torchvision 為其 hflip()
功能新增了一個新的 inplace
參數。如果您已經定義並註冊了自己的 kernel,如下所示
def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
那麼呼叫 F.hflip(my_dp)
將會失敗,因為 hflip
會嘗試將新的 inplace
參數傳遞給您的 kernel,但您的 kernel 不接受它。
因此,我們建議始終在 kernel 的簽名中使用 *args, **kwargs
來定義您的 kernel,如上所示。這樣,您的 kernel 將能夠接受我們未來可能新增的任何新參數。(技術上來說,僅新增 **kwargs 就足夠了)。
腳本總執行時間: (0 分鐘 0.003 秒)