快捷方式

UX 限制

JAX 一樣,torch.func 對於可以轉換的內容存在限制。 一般來說,JAX 的限制是轉換僅適用於純函數:也就是說,輸出完全由輸入決定且不涉及副作用(例如突變)的函數。

我們有類似的保證:我們的轉換與純函數搭配良好。但是,我們確實支援某些原地(in-place)操作。一方面,編寫與函數轉換相容的程式碼可能需要改變您編寫 PyTorch 程式碼的方式,另一方面,您可能會發現我們的轉換讓您可以表達以前在 PyTorch 中難以表達的事情。

一般限制

所有 `torch.func` 轉換都有一個共同的限制,即函數不應賦值給全域變數。相反,函數的所有輸出都必須從函數中返回。這個限制來自 `torch.func` 的實作方式:每個轉換都會將 Tensor 輸入包裝在特殊的 `torch.func` Tensor 子類中,以方便轉換。

所以,不要像下面這樣做:

import torch
from torch.func import grad

# Don't do this
intermediate = None

def f(x):
  global intermediate
  intermediate = x.sin()
  z = intermediate.sin()
  return z

x = torch.randn([])
grad_x = grad(f)(x)

請將 f 重寫為返回 intermediate

def f(x):
  intermediate = x.sin()
  z = intermediate.sin()
  return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs

如果您嘗試在被 vmap() 轉換的函數中,或是在 `torch.func` 的自動微分(AD)轉換(vjp()jvp()jacrev()jacfwd())中,使用 torch.autograd API,例如 torch.autograd.gradtorch.autograd.backward,轉換可能無法對其進行轉換。如果無法做到,您將收到錯誤訊息。

這是 PyTorch 的自動微分支援實作方式的一個根本設計限制,也是我們設計 `torch.func` 函式庫的原因。請改為使用 `torch.autograd` API 的 `torch.func` 等效項: - torch.autograd.gradTensor.backward -> torch.func.vjptorch.func.grad - torch.autograd.functional.jvp -> torch.func.jvp - torch.autograd.functional.jacobian -> torch.func.jacrevtorch.func.jacfwd - torch.autograd.functional.hessian -> torch.func.hessian

vmap 限制

注意

vmap() 是我們限制最多的轉換。與梯度相關的轉換(grad()vjp()jvp())沒有這些限制。jacfwd() (以及 hessian(),它是使用 jacfwd() 實現的) 是 vmap()jvp() 的組合,因此它也具有這些限制。

vmap(func) 是一個轉換,它返回一個函數,該函數將 func 映射到每個輸入 Tensor 的一些新維度上。`vmap` 的心智模型(mental model)是它就像運行一個 `for` 迴圈:對於純函數(即,在沒有副作用的情況下),vmap(f)(x) 等同於

torch.stack([f(x_i) for x_i in x.unbind(0)])

變更:Python 資料結構的任意變更

在存在副作用的情況下,vmap() 不再像運行 `for` 迴圈那樣。例如,以下函數

def f(x, list):
  list.pop()
  print("hello!")
  return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]

result = vmap(f, in_dims=(0, None))(x, lst)

將會列印 "hello!" 一次,並且只從 lst 中彈出一個元素。

vmap() 執行 f 一次,因此所有副作用只發生一次。

這是 `vmap` 實作方式的結果。`torch.func` 有一個特殊的內部 `BatchedTensor` 類別。vmap(f)(*inputs) 取得所有 Tensor 輸入,將它們轉換為 `BatchedTensor`,並呼叫 f(*batched_tensor_inputs)。`BatchedTensor` 覆寫(override)了 PyTorch API,以產生每個 PyTorch 運算子的批次(即向量化)行為。

變更:原地(in-place) PyTorch 操作

您可能因為收到關於 `vmap` 不相容的原地操作的錯誤而來到這裡。vmap() 如果遇到不受支援的 PyTorch 原地操作,將會引發錯誤,否則將會成功。不受支援的操作是指那些會導致將具有更多元素的 Tensor 寫入具有較少元素的 Tensor 的操作。以下是一個可能發生這種情況的例子:

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(1)
y = torch.randn(3, 1)  # When vmapped over, looks like it has shape [1]

# Raises an error because `x` has fewer elements than `y`.
vmap(f, in_dims=(None, 0))(x, y)

x 是一個包含單一元素的 Tensor,y 是一個包含三個元素的 Tensor。x + y 由於廣播 (broadcasting) 會有三個元素,但嘗試將這三個元素寫回只有單一元素的 x 時,會因為試圖將三個元素寫入只有一個元素的 Tensor 而引發錯誤。

如果正在寫入的 Tensor 是在 vmap() 下進行批次處理的 (batched) (也就是說,它正在被 vmap over),則沒有問題。

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(3, 1)
y = torch.randn(3, 1)
expected = x + y

# Does not raise an error because x is being vmapped over.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)

一個常見的解決方案是用它們的 "new_*" 等效函數來取代對 factory functions 的呼叫。例如:

為了理解這為什麼有幫助,請考慮以下情況。

def diag_embed(vec):
  assert vec.dim() == 1
  result = torch.zeros(vec.shape[0], vec.shape[0])
  result.diagonal().copy_(vec)
  return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])

# RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
vmap(diag_embed)(vecs)

vmap() 內部,result 是一個形狀為 [3, 3] 的 Tensor。 然而,雖然 vec 看起來形狀為 [3],但 vec 實際上具有底層形狀 [2, 3]。 無法將 vec 複製到形狀為 [3] 的 result.diagonal() 中,因為它有太多元素。

def diag_embed(vec):
  assert vec.dim() == 1
  result = vec.new_zeros(vec.shape[0], vec.shape[0])
  result.diagonal().copy_(vec)
  return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
vmap(diag_embed)(vecs)

torch.zeros() 替換為 Tensor.new_zeros() 使得 result 具有底層形狀為 [2, 3, 3] 的 Tensor,因此現在可以將具有底層形狀 [2, 3] 的 vec 複製到 result.diagonal() 中。

Mutation: out= PyTorch Operations

vmap() 不支援 PyTorch 運算中的 out= 關鍵字引數。 如果在您的程式碼中遇到它,它會優雅地引發錯誤。

這不是一個根本的限制; 我們理論上可以在未來支援它,但我們目前選擇不支援。

Data-dependent Python control flow

我們目前還不支援在資料相依的控制流程上使用 vmap。 資料相依的控制流程是指 if 語句、while 迴圈或 for 迴圈的條件是一個正在被 vmap 處理的 Tensor。 例如,以下程式碼會引發錯誤訊息:

def relu(x):
  if x > 0:
    return x
  return 0

x = torch.randn(3)
vmap(relu)(x)

然而,任何不依賴於 vmap 處理的 tensors 中的值的控制流程都可以正常工作:

def custom_dot(x):
  if x.dim() == 1:
    return torch.dot(x, x)
  return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)

JAX 支援使用特殊的控制流程運算符(例如 jax.lax.condjax.lax.while_loop)在 資料相依的控制流程上進行轉換。 我們正在研究將這些運算符的等效項新增到 PyTorch 中。

Data-dependent operations (.item())

我們不支援(也不會支援)在使用者定義的函數上使用 vmap,該函數對 Tensor 呼叫 .item()。 例如,以下程式碼會引發錯誤訊息:

def f(x):
  return x.item()

x = torch.randn(3)
vmap(f)(x)

請嘗試重寫您的程式碼,以避免使用 .item() 呼叫。

您也可能會遇到關於使用 .item() 的錯誤訊息,但您可能沒有使用它。 在這些情況下,PyTorch 內部可能正在呼叫 .item() - 請在 GitHub 上提交 issue,我們會修復 PyTorch 內部。

Dynamic shape operations (nonzero and friends)

vmap(f) 要求 f 應用於輸入中的每個 "example" 時,都會傳回一個具有相同形狀的 Tensor。 不支援 torch.nonzerotorch.is_nonzero 等操作,因此會產生錯誤。

為了理解原因,請考慮以下範例:

xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)

torch.nonzero(xs[0]) 傳回一個形狀為 2 的 Tensor; 但是 torch.nonzero(xs[1]) 傳回一個形狀為 1 的 Tensor。 我們無法建構一個單一的 Tensor 作為輸出; 輸出將需要是一個參差不齊的 (ragged) Tensor(而 PyTorch 尚未有參差不齊的 Tensor 的概念)。

Randomness

使用者在呼叫隨機運算時的意圖可能不明確。 具體來說,某些使用者可能希望批次之間的隨機行為相同,而另一些使用者可能希望批次之間的隨機行為不同。 為了處理這個問題,vmap 採用一個隨機性標誌。

該標誌只能傳遞給 vmap,並且可以採用 3 個值:"error"、"different" 或 "same",預設為 error。 在 "error" 模式下,對隨機函數的任何呼叫都會產生一個錯誤,要求使用者根據其使用案例使用其他兩個標誌之一。

在 "different" 隨機性下,批次中的元素會產生不同的隨機值。 例如:

def add_noise(x):
  y = torch.randn(())  # y will be different across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

在 "same" 隨機性下,批次中的元素會產生相同的隨機值。 例如:

def add_noise(x):
  y = torch.randn(())  # y will be the same across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

Warning

我們的系統僅確定 PyTorch 運算符的隨機性行為,無法控制其他庫(如 numpy)的行為。 這與 JAX 在其解決方案上的限制類似。

注意

使用任一種支援的隨機性進行多次 vmap 呼叫不會產生相同的結果。 與標準 PyTorch 一樣,使用者可以透過在 vmap 外部使用 torch.manual_seed() 或使用 generators 來獲得隨機性的可重複性。

注意

最後,我們的隨機性與 JAX 不同,因為我們沒有使用無狀態 PRNG,部分原因是 PyTorch 不完全支援無狀態 PRNG。 相反,我們引入了一個標誌系統,以允許我們看到的最常見的隨機性形式。 如果您的使用案例不符合這些隨機性形式,請提交 issue。

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者與進階開發者的深入教學

查看教學

資源

尋找開發資源並取得您問題的解答

查看資源