捷徑

torch.overrides

此模組公開了各種用於 __torch_function__ 協議的輔助函數。 有關 __torch_function__ 協議的更多詳細資訊,請參閱擴展 torch Python API

函數

torch.overrides.get_ignored_functions()[原始碼][原始碼]

傳回無法被 __torch_function__ 覆寫的公開函數。

傳回

一個函數的元組,這些函數在 torch API 中是公開可用的,但不能使用 __torch_function__ 覆寫。主要是因為這些函數的參數都不是張量或類似張量的東西。

回傳類型

Set[Callable]

範例

>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
torch.overrides.get_overridable_functions()[原始碼][原始碼]

列出可透過 __torch_function__ 覆寫的函數

傳回

一個字典,將包含可覆寫函數的命名空間對應到該命名空間中可以被覆寫的函數。

回傳類型

Dict[Any, List[Callable]]

torch.overrides.resolve_name(f)[原始碼][原始碼]

取得傳遞給 __torch_function__ 的函數的可讀字串名稱

參數

f (Callable) – 要解析名稱的函數。

傳回

函數的名稱;如果經過 eval 運算,應該會回傳輸入的函數。

回傳類型

str

torch.overrides.get_testing_overrides()[原始碼][原始碼]

回傳一個包含所有可覆寫函數的虛擬覆寫的字典

傳回

一個字典,將 PyTorch API 中可覆寫的函數對應到具有與真實函數相同簽名且無條件回傳 -1 的 lambda 函數。 這些 lambda 函數可用於測試定義 __torch_function__ 的類型的 API 涵蓋率。

回傳類型

Dict[Callable, Callable]

範例

>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[原始碼][原始碼]

實作一個函數,並檢查 __torch_function__ 覆寫。

請參閱 torch::autograd::handle_torch_function,以了解此函數在 C++ 實作中的等效函數。

參數
  • public_api (function) – 最初像 public_api(*args, **kwargs) 這樣呼叫的公開 torch API 公開的函數,現在正在檢查其引數。

  • relevant_args (iterable) – 要檢查 __torch_function__ 方法的引數的可迭代物件。

  • args (tuple) – 最初傳遞給 public_api 的任意位置引數。

  • kwargs (tuple) – 最初傳遞給 public_api 的任意關鍵字引數。

傳回

呼叫 implementation__torch_function__ 方法的結果,視情況而定。

回傳類型

object

:raises TypeError : 如果找不到任何實作。

範例

>>> def func(a):
...     if has_torch_function_unary(a):
...         return handle_torch_function(func, (a,), a)
...     return a + 0
torch.overrides.has_torch_function()

檢查可迭代物件的元素中是否存在 __torch_function__ 實作,或是否已啟用 __torch_function__ 模式。 考慮精確的 TensorParameter 作為不可調度的。 使用此方法來保護對 handle_torch_function() 的呼叫; 請勿使用它來測試某個東西是否為類似張量,請改用 is_tensor_like()。 :param relevant_args: 要檢查 __torch_function__ 方法的引數的可迭代物件。 :type relevant_args: iterable

傳回

如果 relevant_args 的任何元素具有 __torch_function__ 實作,則為 True,否則為 False。

回傳類型

bool

另請參閱

torch.is_tensor_like

檢查某個東西是否為類似張量,包括精確的 Tensor

torch.overrides.is_tensor_like(inp)[原始碼][原始碼]

如果傳入的輸入是類似張量,則回傳 True

目前,只要輸入的類型上存在 __torch_function__ 屬性,就會發生這種情況。

範例

張量的子類通常是類似張量。

>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True

內建或使用者類型通常不是類似張量。

>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False

但是,它們可以透過實作 __torch_function__ 來變成類似張量。

>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides.is_tensor_method_or_property(func)[原始碼][原始碼]

如果傳入的函式是屬於 torch.Tensor 的方法或屬性的處理器,則傳回 True,如同傳入 __torch_function__

注意

對於屬性,必須傳入其 __get__ 方法。

特別是,由於以下原因,可能需要這樣做:

  1. 方法/屬性有時不包含 __module__ 插槽。

  2. 它們要求第一個傳入的參數是 torch.Tensor 的實例。

範例

>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
回傳類型

bool

torch.overrides.wrap_torch_function(dispatcher)[原始碼][原始碼]

使用與 __torch_function__ 相關的功能封裝給定的函式。

參數

dispatcher (Callable) – 一個可呼叫物件,它傳回傳遞到函式的 Tensor-like 物件的可迭代物件。

注意

此裝飾器可能會降低程式碼的效能。 一般來說,將您的程式碼表示為一系列本身支援 __torch_function__ 的函式就足夠了。 如果您發現自己處於罕見的情況,例如,如果您正在封裝一個底層函式庫,並且您也需要它適用於 Tensor-like 物件,則可以使用此函式。

範例

>>> def dispatcher(a):  # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a):  # This will make func dispatchable by __torch_function__
...     return a + 0

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源