UX 限制¶
functorch,如同 JAX,對於可以轉換的內容有限制。一般來說,JAX 的限制是轉換只能用於純函式:也就是說,輸出完全由輸入決定且不涉及副作用(如變異)的函式。
我們有類似的保證:我們的轉換適用於純函式。但是,我們確實支援某些原地操作。一方面,編寫與 functorch 轉換相容的程式碼可能需要更改您編寫 PyTorch 程式碼的方式,另一方面,您可能會發現我們的轉換讓您可以表達以前在 PyTorch 中難以表達的內容。
一般限制¶
所有 functorch 轉換都有一個限制,即函式不應指派給全域變數。相反地,函式的所有輸出都必須從函式返回。這個限制來自於 functorch 的實現方式:每個轉換都會將張量輸入包裝在特殊的 functorch 張量子類別中,以方便轉換。
因此,請不要使用以下程式碼
import torch
from functorch import grad
# Don't do this
intermediate = None
def f(x):
global intermediate
intermediate = x.sin()
z = intermediate.sin()
return z
x = torch.randn([])
grad_x = grad(f)(x)
請重寫 f
以返回 intermediate
def f(x):
intermediate = x.sin()
z = intermediate.sin()
return z, intermediate
grad_x, intermediate = grad(f, has_aux=True)(x)
torch.autograd API¶
如果您嘗試在被 vmap()
或其中一個 functorch 的 AD 轉換(vjp()
、jvp()
、jacrev()
、jacfwd()
)轉換的函式內使用 torch.autograd
API,例如 torch.autograd.grad
或 torch.autograd.backward
,則轉換可能無法對其進行轉換。如果無法這樣做,您將收到錯誤訊息。
這是 PyTorch 的 AD 支援實現方式的基本設計限制,也是我們設計 functorch 函式庫的原因。請改用 torch.autograd
API 的 functorch 等效項:- torch.autograd.grad
、Tensor.backward
-> functorch.vjp
或 functorch.grad
- torch.autograd.functional.jvp
-> functorch.jvp
- torch.autograd.functional.jacobian
-> functorch.jacrev
或 functorch.jacfwd
- torch.autograd.functional.hessian
-> functorch.hessian
vmap 限制¶
注意
vmap()
是我們限制最多的轉換。與梯度相關的轉換(grad()
、vjp()
、jvp()
)沒有這些限制。jacfwd()
(以及使用 jacfwd()
實現的 hessian()
)是 vmap()
和 jvp()
的組合,因此它也有這些限制。
vmap(func)
是一種轉換,它返回一個函式,該函式將 func
對映到每個輸入張量的新維度上。vmap 的心智模型就像是在執行一個 for 迴圈:對於純函式(即在沒有副作用的情況下),vmap(f)(x)
等價於
torch.stack([f(x_i) for x_i in x.unbind(0)])
變異:Python 資料結構的任意變異¶
在存在副作用的情況下,vmap()
不再像是在執行 for 迴圈。例如,以下函式
def f(x, list):
list.pop()
print("hello!")
return x.sum(0)
x = torch.randn(3, 1)
lst = [0, 1, 2, 3]
result = vmap(f, in_dims=(0, None))(x, lst)
將只列印一次“hello!”,並且只從 lst
中彈出一個元素。
vmap()
只執行一次 f,因此所有副作用只會發生一次。
這是 vmap 實現方式的結果。functorch 有一個特殊的內部 BatchedTensor 類別。vmap(f)(*inputs)
接收所有張量輸入,將它們轉換為 BatchedTensors,並呼叫 f(*batched_tensor_inputs)
。BatchedTensor 覆寫 PyTorch API 以為每個 PyTorch 運算符產生批次(即向量化)行為。
變異:原地 PyTorch 操作¶
如果 vmap()
遇到不支援的 PyTorch 原地操作,它將引發錯誤,否則它將成功。不支援的操作是那些會導致元素較多的張量被寫入元素較少的張量的操作。以下是如何發生這種情況的範例
def f(x, y):
x.add_(y)
return x
x = torch.randn(1)
y = torch.randn(3)
# Raises an error because `y` has fewer elements than `x`.
vmap(f, in_dims=(None, 0))(x, y)
x
是一個具有一個元素的張量,y
是一個具有三個元素的張量。x + y
具有三個元素(由於廣播),但嘗試將三個元素寫回只有一個元素的 x
會引發錯誤,因為嘗試將三個元素寫入只有一個元素的張量。
如果要寫入的張量具有相同數量(或更多)的元素,則沒有問題
def f(x, y):
x.add_(y)
return x
x = torch.randn(3)
y = torch.randn(3)
expected = x + y
# Does not raise an error because x and y have the same number of elements.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)
變異:out= PyTorch 操作¶
vmap()
不支援 PyTorch 操作中的 out=
關鍵字參數。如果在您的程式碼中遇到它,它將會正常地出錯。
這不是一個根本的限制;理論上我們可以在未來支援它,但我們現在選擇不支援。
資料相依的 Python 控制流程¶
我們尚不支援對資料相依的控制流程使用 vmap
。資料相依的控制流程是指 if 語句、while 迴圈或 for 迴圈的條件是一個正在被 vmap
對映的張量。例如,以下程式碼將引發錯誤訊息
def relu(x):
if x > 0:
return x
return 0
x = torch.randn(3)
vmap(relu)(x)
但是,任何不依賴於 vmap
中張量值的控制流程都可以正常運作
def custom_dot(x):
if x.dim() == 1:
return torch.dot(x, x)
return (x * x).sum()
x = torch.randn(3)
vmap(custom_dot)(x)
JAX 支援使用特殊的控制流程運算符(例如 jax.lax.cond
、jax.lax.while_loop
)對 資料相依的控制流程 進行轉換。我們正在研究將這些功能的等效功能添加到 functorch 中(在 GitHub 上建立一個議題以表達您的支援!)。
資料相依操作(.item())¶
我們不支援(未來也不會支援)在呼叫了張量的 .item()
的使用者自訂函數上使用 vmap。例如,以下程式碼將會產生錯誤訊息
def f(x):
return x.item()
x = torch.randn(3)
vmap(f)(x)
請嘗試改寫您的程式碼,避免使用 .item()
呼叫。
您可能也會遇到有關使用 .item()
的錯誤訊息,但您可能沒有使用它。在這些情況下,可能是 PyTorch 內部呼叫了 .item()
– 請在 GitHub 上提交 issue,我們會修復 PyTorch 內部程式碼。
動態形狀操作(nonzero 等)¶
vmap(f)
要求將 f
套用至輸入中的每個「範例」時,傳回的張量形狀必須相同。不支援 torch.nonzero
、torch.is_nonzero
等操作,因此會產生錯誤。
以下範例說明了原因
xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)
torch.nonzero(xs[0])
傳回形狀為 2 的張量;但 torch.nonzero(xs[1])
傳回形狀為 1 的張量。我們無法建構單一形狀的張量作為輸出;輸出需要是不規則張量(而 PyTorch 尚不支援不規則張量的概念)。
隨機性¶
使用者呼叫隨機操作的意圖可能不清楚。具體來說,有些使用者可能希望隨機行為在批次之間保持一致,而其他使用者則可能希望隨機行為在批次之間有所不同。為了處理這個問題,vmap
採用了一個隨機性旗標。
該旗標只能傳遞給 vmap,並且可以採用 3 個值:「error」、「different」或「same」,預設值為 error。在「error」模式下,任何對隨機函數的呼叫都會產生錯誤,要求使用者根據其使用情況使用其他兩個旗標之一。
在「different」隨機性下,批次中的元素會產生不同的隨機值。例如:
def add_noise(x):
y = torch.randn(()) # y will be different across the batch
return x + y
x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x) # we get 3 different values
在「same」隨機性下,批次中的元素會產生相同的隨機值。例如:
def add_noise(x):
y = torch.randn(()) # y will be the same across the batch
return x + y
x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x) # we get the same value, repeated 3 times
警告
我們的系統只會判斷 PyTorch 運算子的隨機性行為,無法控制 numpy 等其他函式庫的行為。這與 JAX 解決方案的限制類似
注意
使用任何一種支援的隨機性多次呼叫 vmap 不會產生相同的結果。與標準 PyTorch 一樣,使用者可以透過在 vmap 之外使用 torch.manual_seed()
或使用產生器來獲得隨機性重現性。
注意
最後,我們的隨機性與 JAX 不同,因為我們沒有使用無狀態 PRNG,部分原因是 PyTorch 不完全支援無狀態 PRNG。相反的,我們引入了一個旗標系統,允許我們看到最常見的隨機性形式。如果您的使用情況不適合這些形式的隨機性,請提交 issue。