捷徑

functorch.functionalize

functorch.functionalize(func, *, remove='mutations')[原始碼]

functionalize 是一種轉換,可用於從函數中移除(中間)突變和別名,同時保留函數的語義。

functionalize(func) 會回傳一個與 func 語義相同的新函式,但所有中間的突變操作皆已移除。所有在中間張量上執行的原地操作:intermediate.foo_() 都會被其非原地等效操作取代:intermediate_updated = intermediate.foo()

當需要將 PyTorch 程式交付給不易表示突變或別名操作的後端或編譯器時,functionalize 非常有用。

參數
  • func (Callable) – 接受一個或多個引數的 Python 函式。

  • remove (str) – 一個可選的字串引數,其值可以是 ‘mutations’ 或 ‘mutations_and_views’。如果傳入 ‘mutations’,則所有突變操作將會被其非突變等效操作取代。如果傳入 ‘mutations_and_views’,則此外,所有別名操作也將被其非別名等效操作取代。預設值:’mutations’。

回傳值

回傳一個新的「函式化」函式。它接受與 func 相同的輸入,並且具有相同的行為,但函式中對中間張量執行的任何突變(以及可選的別名)都將被移除。

functionalize 也會移除對函式輸入執行的突變(和視圖)。然而,為了保持語義,functionalize 會在轉換完成運行後「修正」突變,方法是檢測是否有任何張量輸入「應該」被突變,並在必要時將新數據複製回輸入。

範例

>>> # xdoctest: +SKIP
>>> import torch
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch.func import functionalize
>>>
>>> # A function that uses mutations and views, but only on intermediate tensors.
>>> def f(a):
...     b = a + 1
...     c = b.view(-1)
...     c.add_(1)
...     return b
...
>>> inpt = torch.randn(2)
>>>
>>> out1 = f(inpt)
>>> out2 = functionalize(f)(inpt)
>>>
>>> # semantics are the same (outputs are equivalent)
>>> print(torch.allclose(out1, out2))
True
>>>
>>> f_traced = make_fx(f)(inpt)
>>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>>
>>> print(f_traced.code)



def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view = torch.ops.aten.view(add, [-1])
    add_ = torch.ops.aten.add_(view, 1);  view = None
    return add

>>> print(f_no_mutations_traced.code)



def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view = torch.ops.aten.view(add, [-1]);  add = None
    add_1 = torch.ops.aten.add(view, 1);  view = None
    view_1 = torch.ops.aten.view(add_1, [2]);  add_1 = None
    return view_1

>>> print(f_no_mutations_and_views_traced.code)



def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view_copy = torch.ops.aten.view_copy(add, [-1]);  add = None
    add_1 = torch.ops.aten.add(view_copy, 1);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy(add_1, [2]);  add_1 = None
    return view_copy_1


>>> # A function that mutates its input tensor
>>> def f(a):
...     b = a.view(-1)
...     b.add_(1)
...     return a
...
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>> #
>>> # All mutations and views have been removed,
>>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
>>> # after the function has completed.
>>> print(f_no_mutations_and_views_traced.code)



def forward(self, a_1):
    view_copy = torch.ops.aten.view_copy(a_1, [-1])
    add = torch.ops.aten.add(view_copy, 1);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy(add, [2]);  add = None
    copy_ = torch.ops.aten.copy_(a_1, view_copy_1);  a_1 = None
    return view_copy_1
functionalize 有一些值得一提的「失敗模式」
  1. 如同其他 torch.func 轉換,functionalize() 不適用於直接使用 .backward() 的函式。torch.autograd.grad 也是如此。如果您想使用 autograd,您可以直接使用 functionalize(grad(f)) 計算梯度。

  2. 如同其他 torch.func 轉換,functionalize() 不適用於全域狀態。如果您在使用非區域狀態的視圖/突變的函式上呼叫 functionalize(f),函式化將只會無效操作,並將視圖/突變呼叫直接傳遞給後端。一種解決方法是確保將任何非區域狀態建立包裝到一個更大的函式中,然後再對其呼叫 functionalize。

  3. resize_() 有一些限制:只要正在調整大小的張量不是視圖,functionalize 就適用於使用 resize_() 的程式。

  4. as_strided() 有一些限制:functionalize 不適用於導致張量記憶體重疊的 as_strided() 呼叫。

最後,理解函式化的一個有用的心智模型是,大多數使用者 PyTorch 程式都是使用公開的 torch API 編寫的。執行時,torch 運算子通常會分解成我們內部的 C++「ATen」API。函式化的邏輯完全發生在 ATen 層級。函式化知道如何將 ATen 中的每個別名運算子映射到其非別名等效運算子(例如 tensor.view({-1}) -> at::view_copy(tensor, {-1})),以及如何將 ATen 中的每個突變運算子映射到其非突變等效運算子(例如 tensor.add_(1) -> at::add(tensor, -1)),同時離線追蹤別名和突變,以了解何時需要修正。關於哪些 ATen 運算子是別名或突變的資訊都來自 https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml

警告

我們已將 functorch 整合到 PyTorch 中。作為整合的最後一步,functorch.functionalize 從 PyTorch 2.0 開始已被棄用,並將在 PyTorch 2.3 以後的版本中刪除。請改用 torch.func.functionalize;有關更多詳細資訊,請參閱 PyTorch 2.0 版本說明和/或 torch.func 遷移指南 https://pytorch.dev.org.tw/docs/master/func.migrating.html

文件

存取 PyTorch 的完整開發人員文件

查看文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源