快捷方式

Autograd 機制

本說明將概述 autograd 的工作方式和記錄操作的方式。雖然不一定需要完全理解這些,但我們建議您熟悉它,因為它將幫助您編寫更有效率、更簡潔的程式,並且可以幫助您進行除錯。

autograd 如何編碼歷史記錄

Autograd 是一個反向自動微分系統。從概念上講,autograd 會記錄一個圖,紀錄你在執行運算時,所有建立資料的運算,產生一個有向無環圖,其葉節點是輸入張量,而根節點是輸出張量。透過從根節點到葉節點追蹤這個圖,你可以使用鏈式法則自動計算梯度。

在內部,autograd 將這個圖表示為 Function 物件(實際上是表達式)的圖,這些物件可以被 apply() 以計算評估該圖的結果。當計算前向傳遞時,autograd 同時執行請求的計算並建立一個圖,該圖表示計算梯度的函數(每個 torch.Tensor.grad_fn 屬性是進入該圖的入口點)。當前向傳遞完成時,我們在反向傳遞中評估這個圖以計算梯度。

需要注意的重要一點是,該圖在每次迭代時都會從頭開始重建,而這正是允許使用任意 Python 控制流程語句的原因,這些語句可以在每次迭代時更改圖的整體形狀和大小。你無需在啟動訓練之前對所有可能的路徑進行編碼 - 你所運行的是你所微分的。

儲存的張量

某些運算需要在前向傳遞期間儲存中間結果,以便執行反向傳遞。例如,函數 xx2x\mapsto x^2 會儲存輸入 xx 以計算梯度。

在定義自定義 Python Function 時,你可以使用 save_for_backward() 在前向傳遞期間儲存張量,並使用 saved_tensors 在反向傳遞期間檢索它們。 有關更多信息,請參閱 擴展 PyTorch

對於 PyTorch 定義的運算(例如 torch.pow()),張量會根據需要自動儲存。 你可以探索(為了教育或除錯目的)哪些張量被特定的 grad_fn 儲存,方法是查找以字首 _saved 開頭的屬性。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self))  # True
print(x is y.grad_fn._saved_self)  # True

在先前的程式碼中,y.grad_fn._saved_self 指的是與 x 相同的 Tensor 物件。 但情況可能並非總是如此。 例如

x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result))  # True
print(y is y.grad_fn._saved_result)  # False

在底層,為了防止參考循環,PyTorch 在儲存時封裝了張量,並在讀取時將其解封裝到不同的張量中。 在此,你從存取 y.grad_fn._saved_result 獲得的張量是一個與 y 不同的張量物件(但它們仍然共享相同的儲存空間)。

張量是否會被封裝到不同的張量物件中取決於它是否是其自身 grad_fn 的輸出,這是一個實現細節,可能會發生變化,使用者不應依賴它。

你可以使用 儲存的張量的 Hooks 來控制 PyTorch 如何進行封裝/解封裝。

不可微分函數的梯度

僅當使用的每個基本函數都可微分時,使用自動微分計算的梯度才有效。 不幸的是,我們在實踐中使用的許多函數不具有此屬性(例如,relusqrt0 處)。 為了嘗試減少不可微分函數的影響,我們透過按順序應用以下規則來定義基本運算的梯度

  1. 如果函數可微分,因此在目前點存在梯度,請使用它。

  2. 如果函數是凸函數(至少在局部),請使用最小範數的次梯度(它是最速下降方向)。

  3. 如果函數是凹函數(至少在局部),請使用最小範數的超梯度(考慮 -f(x) 並應用前一點)。

  4. 如果定義了函數,請透過連續性定義目前點的梯度(請注意,此處可能出現 inf,例如對於 sqrt(0))。 如果有多個可能的值,請任意選擇一個。

  5. 如果未定義函數(例如 sqrt(-1)log(-1) 或輸入為 NaN 時的大多數函數),則用作梯度的值是任意的(我們也可能會引發錯誤,但不保證)。 大多數函數將使用 NaN 作為梯度,但出於效能原因,某些函數將使用其他值(例如 log(-1))。

  6. 如果函數不是確定性映射(即,它不是一個 數學函數),它將被標記為不可微分。 如果在需要 grad 的張量上,在 no_grad 環境之外使用它,這將導致反向傳遞中出錯。

局部禁用梯度計算

有多種機制可以從 Python 局部禁用梯度計算

要跨整個程式碼區塊禁用梯度,可以使用上下文管理器,例如 no-grad 模式和推理模式。 對於從梯度計算中更精細地排除子圖,可以設定張量的 requires_grad 欄位。

除了討論上述機制之外,以下我們也會說明評估模式 (evaluation mode)(nn.Module.eval()),這個方法並非用來停用梯度計算,但因為其名稱,常常會與上述三種方法混淆。

設定 requires_grad

requires_grad 是一個旗標,預設為 false,*除非被包裝在* nn.Parameter 中,它允許從梯度計算中精細地排除子圖。它在正向和反向傳播中都會生效

在正向傳播期間,只有當至少有一個輸入張量需要梯度時,操作才會被記錄在反向圖中。在反向傳播(.backward())期間,只有 requires_grad=True 的葉節點張量,其梯度才會累積到它們的 .grad 欄位中。

很重要的一點是,即使每個張量都有這個旗標,*設定* 它只對葉節點張量(沒有 grad_fn 的張量,例如 nn.Module 的參數)才有意義。非葉節點張量(有 grad_fn 的張量)是與之關聯的反向圖的張量。因此,它們的梯度將被需要作為中間結果,來計算需要梯度的葉節點張量的梯度。從這個定義可以清楚地看出,所有非葉節點張量將自動具有 require_grad=True

設定 requires_grad 應該是你控制模型哪些部分參與梯度計算的主要方式,例如,如果你需要在模型微調期間凍結預訓練模型的部分。

要凍結模型的部分,只需將 .requires_grad_(False) 應用到你不希望更新的參數。如上所述,由於使用這些參數作為輸入的計算不會記錄在正向傳播中,因此它們的 .grad 欄位也不會在反向傳播中更新,因為它們從一開始就不會是反向圖的一部分,正如預期的那樣。

因為這是一個非常常見的模式,所以 requires_grad 也可以在模組層級使用 nn.Module.requires_grad_() 來設定。當應用於模組時,.requires_grad_() 對模組的所有參數生效(這些參數預設具有 requires_grad=True)。

梯度模式 (Grad Modes)

除了設定 requires_grad 之外,還有三種可以從 Python 中選擇的梯度模式,這些模式會影響 PyTorch 在內部如何透過 autograd 處理計算:預設模式 (grad mode)、no-grad 模式和 inference 模式,所有這些模式都可以透過上下文管理器和裝飾器進行切換。

模式

將操作從反向圖的記錄中排除

跳過額外的 autograd 追蹤開銷

在啟用模式時建立的張量可以在之後的 grad-mode 中使用

範例

預設

正向傳播

no-grad

優化器更新

inference

資料處理、模型評估

預設模式 (梯度模式) (Default Mode (Grad Mode))

「預設模式」是指當沒有啟用其他模式(如 no-grad 和 inference 模式)時,我們隱含處於的模式。為了與「no-grad 模式」形成對比,預設模式有時也稱為「梯度模式」。

關於預設模式,最重要的是要知道它是 requires_grad 生效的唯一模式。在其他兩種模式中,requires_grad 總是會被覆寫為 False

No-grad 模式

No-grad 模式中的計算行為就像沒有任何輸入需要梯度一樣。換句話說,即使存在具有 require_grad=True 的輸入,no-grad 模式中的計算也永遠不會記錄在反向圖中。

當你需要執行不應被 autograd 記錄的操作,但你仍然希望稍後在梯度模式中使用這些計算的輸出時,請啟用 no-grad 模式。這個上下文管理器可以方便地停用程式碼塊或函式的梯度,而無需暫時將張量設定為 requires_grad=False,然後再設定回 True

例如,no-grad 模式可能在編寫優化器時很有用:當執行訓練更新時,你希望就地更新參數,而無需讓 autograd 記錄更新。你也打算在下一個正向傳播中將更新後的參數用於梯度模式中的計算。

torch.nn.init 中的實作也依賴於 no-grad 模式,以便在初始化參數時避免 autograd 追蹤,從而在原地更新初始化的參數。

Inference 模式

Inference 模式是 no-grad 模式的極端版本。就像在 no-grad 模式中一樣,inference 模式中的計算不會記錄在反向圖中,但啟用 inference 模式將使 PyTorch 能夠更快地加速你的模型。這個更好的運行時有一個缺點:在 inference 模式中建立的張量將無法在退出 inference 模式後用於將被 autograd 記錄的計算中。

當你執行不與 autograd 互動的計算,*且* 你不打算在任何稍後將被 autograd 記錄的計算中使用在 inference 模式中建立的張量時,請啟用 inference 模式。

建議您在程式碼中不需要自動微分追蹤的部分(例如:資料處理和模型評估)嘗試使用推論模式。如果它能直接適用於您的使用情境,這將會是一個免費的效能提升。如果在啟用推論模式後遇到錯誤,請檢查您是否在退出推論模式後,於自動微分紀錄的計算中使用在推論模式中建立的 tensors。如果無法避免這種使用方式,您可以隨時切換回 no-grad 模式。

有關推論模式的詳細資訊,請參閱 推論模式

有關推論模式的實作細節,請參閱 RFC-0011-InferenceMode

評估模式(nn.Module.eval()

評估模式並非用於在本機停用梯度計算的機制。這裡包含它只是因為有時它會被誤認為是這樣的機制。

從功能上來說,module.eval()(或等效的 module.train(False))與 no-grad 模式和推論模式完全正交。model.eval() 如何影響您的模型完全取決於模型中使用的特定模組,以及它們是否定義了任何訓練模式特定的行為。

如果您的模型依賴於諸如 torch.nn.Dropouttorch.nn.BatchNorm2d 之類的模組,這些模組可能會根據訓練模式而有不同的行為,例如避免在驗證資料上更新 BatchNorm 的 running statistics,那麼您有責任呼叫 model.eval()model.train()

建議您在訓練時始終使用 model.train(),在評估模型(驗證/測試)時使用 model.eval(),即使您不確定您的模型是否具有訓練模式特定的行為,因為您使用的模組可能會更新為在訓練和評估模式下有不同的行為。

使用 autograd 的 In-place 操作

在 autograd 中支援 in-place 操作是一件困難的事情,我們不鼓勵在大多數情況下使用它們。Autograd 的積極緩衝區釋放和重複使用使其非常有效率,並且很少有 in-place 操作可以顯著降低記憶體使用量。除非您在嚴重的記憶體壓力下運作,否則您可能永遠不需要使用它們。

有兩個主要原因限制了 in-place 操作的適用性

  1. In-place 操作可能會覆寫計算梯度所需的值。

  2. 每個 in-place 操作都需要實作來重寫計算圖。Out-of-place 版本只是分配新的物件並保留對舊圖的引用,而 in-place 操作則需要變更所有輸入的建立者到表示此操作的 Function。這可能很棘手,特別是如果有很多 Tensors 引用相同的儲存空間(例如,通過索引或轉置建立),如果修改後的輸入的儲存空間被任何其他 Tensor 引用,in-place 函數將會引發錯誤。

In-place 正確性檢查

每個 tensor 都會保留一個版本計數器,每次在任何操作中將其標記為 dirty 時,該計數器都會遞增。當 Function 為 backward 儲存任何 tensors 時,也會儲存其包含 Tensor 的版本計數器。一旦您存取 self.saved_tensors,就會進行檢查,如果它大於儲存的值,則會引發錯誤。這確保了如果您正在使用 in-place 函數並且沒有看到任何錯誤,您可以確定計算出的梯度是正確的。

多執行緒 Autograd

autograd 引擎負責執行所有必要的 backward 操作,以計算 backward pass。本節將描述所有細節,以幫助您在多執行緒環境中充分利用它。(這僅適用於 PyTorch 1.6+,因為先前版本的行為不同。)

使用者可以使用多執行緒程式碼(例如 Hogwild 訓練)訓練其模型,並且不會阻塞並行的 backward 計算,範例程式碼可能如下所示:

# Define a train function to be used in different threads
def train_fn():
    x = torch.ones(5, 5, requires_grad=True)
    # forward
    y = (x + 3) * (x + 4) * 0.5
    # backward
    y.sum().backward()
    # potential optimizer update


# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
    p = threading.Thread(target=train_fn, args=())
    p.start()
    threads.append(p)

for p in threads:
    p.join()

請注意,使用者應注意以下一些行為:

CPU 上的並行性

當您在 CPU 上的多個執行緒中通過 python 或 C++ API 執行 backward()grad() 時,您期望看到額外的並行性,而不是在執行期間以特定順序序列化所有 backward 呼叫(PyTorch 1.6 之前的行為)。

非確定性

如果您從多個執行緒並行呼叫 backward() 並且具有共享輸入(即 Hogwild CPU 訓練),則應預期存在非確定性。發生這種情況的原因是參數會自動在執行緒之間共享,因此,多個執行緒可能會在梯度累積期間存取並嘗試累積相同的 .grad 屬性。從技術上講,這是不安全的,並且可能導致競爭條件,並且結果可能無效。

開發具有共享參數的多執行緒模型的使用者應牢記執行緒模型,並應了解上述問題。

可以使用 functional API torch.autograd.grad() 來計算梯度,而不是使用 backward() 以避免非確定性。

圖形保留

如果 autograd 圖形的一部分在執行緒之間共享,即首先以單執行緒執行 forward 的第一部分,然後在多個執行緒中執行第二部分,則共享圖形的第一部分。在這種情況下,不同的執行緒在同一圖形上執行 grad()backward() 可能會導致一個執行緒即時破壞圖形的問題,並且另一個執行緒會因此崩潰。Autograd 將會向使用者報告錯誤,類似於在沒有 retain_graph=True 的情況下呼叫 backward() 兩次,並告知使用者應使用 retain_graph=True

Autograd 節點上的執行緒安全性

由於 Autograd 允許呼叫者執行緒驅動其 backward 執行以實現潛在的平行性,因此對於共享部分/全部 GraphTask 的平行 backward() 呼叫,我們確保 CPU 上的執行緒安全性非常重要。

由於 GIL,自定義 Python autograd.Functions 會自動具有執行緒安全性。對於內建 C++ Autograd 節點(例如 AccumulateGrad、CopySlices)和自定義 autograd::Functions,Autograd 引擎使用執行緒互斥鎖定來確保可能具有狀態寫入/讀取的 autograd 節點上的執行緒安全性。

C++ Hooks 不具備執行緒安全性

Autograd 仰賴使用者撰寫具備執行緒安全性的 C++ Hooks。如果您希望 Hook 能夠在多執行緒環境中正確運作,您需要撰寫適當的執行緒鎖定程式碼,以確保 Hooks 具備執行緒安全性。

複數的 Autograd

簡短版本

  • 當您使用 PyTorch 對任何函數 f(z)f(z) 進行微分,而該函數具有複數的定義域和/或值域時,梯度的計算基於一個假設,即該函數是一個更大的實數值損失函數 g(input)=Lg(input)=L 的一部分。計算出的梯度為 Lz\frac{\partial L}{\partial z^*} (請注意 z 的共軛),其負值正是梯度下降演算法中使用的最速下降方向。因此,現有的優化器有機會可以開箱即用,直接與複數參數搭配使用。

  • 此慣例與 TensorFlow 的複數微分慣例相符,但與 JAX 不同(JAX 計算 Lz\frac{\partial L}{\partial z})。

  • 如果您有一個從實數到實數的函數,但在內部使用了複數運算,那麼此慣例無關緊要:您將始終獲得與僅使用實數運算實現時相同的結果。

如果您對數學細節感到好奇,或者想知道如何在 PyTorch 中定義複數導數,請繼續閱讀。

什麼是複數導數?

複數可微性的數學定義採用導數的極限定義,並將其推廣到複數。考慮函數 f:CCf: ℂ → ℂ,

f(z=x+yj)=u(x,y)+v(x,y)jf(z=x+yj) = u(x, y) + v(x, y)j

其中 uuvv 是兩個變數的實值函數,而 jj 是虛數單位。

使用導數的定義,我們可以寫成

f(z)=limh0,hCf(z+h)f(z)hf'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h}

為了使此極限存在,不僅 uuvv 必須是實可微的,而且 ff 還必須滿足柯西-黎曼 方程式。 換句話說:對於實數和虛數的步長 (hh) 計算出的極限必須相等。 這是一個更嚴格的條件。

複數可微函數通常被稱為全純函數。 它們表現良好,具有您從實數可微函數中所看到的所有良好特性,但實際上在優化領域中沒有用處。 對於優化問題,研究社群中只使用實值目標函數,因為複數不屬於任何有序域,因此具有複數值的損失沒有多大意義。

此外,沒有有趣的實值目標函數滿足柯西-黎曼方程式。 因此,全純函數的理論不能用於優化,因此大多數人使用 Wirtinger 微積分。

Wirtinger 微積分開始發揮作用…

所以,我們有非常棒的複可微性 (complex differentiability) 和全純函數 (holomorphic functions) 理論,但我們根本沒辦法使用它,因為許多常用的函數都不是全純函數。可憐的數學家該怎麼辦呢?嗯,魏廷格 (Wirtinger) 觀察到,即使 f(z)f(z) 不是全純的,仍然可以將它重寫成一個雙變數函數 f(z,z)f(z, z*),它總是全純的。這是因為 zz 的實部和虛部可以用 zzzz^* 來表示,如下:

Re(z)=z+z2Im(z)=zz2j\begin{aligned} \mathrm{Re}(z) &= \frac {z + z^*}{2} \\ \mathrm{Im}(z) &= \frac {z - z^*}{2j} \end{aligned}

Wirtinger 微積分建議研究 f(z,z)f(z, z^*),如果 ff 是實可微的,則保證是全純的 (另一種思考方式是將其視為座標系統的變換,從 f(x,y)f(x, y)f(z,z)f(z, z^*))。此函數具有偏導數 z\frac{\partial }{\partial z}z\frac{\partial}{\partial z^{*}}。 我們可以使用連鎖律來建立這些偏導數與 zz 的實部和虛部的偏導數之間的關係。

x=zxz+zxz=z+zy=zyz+zyz=1j(zz)\begin{aligned} \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\ &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\ \\ \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\ &= 1j * \left(\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}\right) \end{aligned}

從上面的方程式,我們得到

z=1/2(x1jy)z=1/2(x+1jy)\begin{aligned} \frac{\partial }{\partial z} &= 1/2 * \left(\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}\right) \\ \frac{\partial }{\partial z^*} &= 1/2 * \left(\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}\right) \end{aligned}

這就是你在維基百科上會找到的 Wirtinger 微積分的經典定義。

這種改變有很多美好的結果。

  • 首先,柯西-黎曼方程式可以簡化為 fz=0\frac{\partial f}{\partial z^*} = 0 (也就是說,函數 ff 可以完全用 zz 表示,而不需要參考 zz^*)。

  • 另一個重要的(且有些違反直覺的)結果,我們稍後會看到,當我們對實數值的損失函數進行優化時,在更新變數時應該採取的步驟是由 Lossz\frac{\partial Loss}{\partial z^*} 給出的 (不是 Lossz\frac{\partial Loss}{\partial z})。

更多閱讀資料請參考:https://arxiv.org/pdf/0906.4835.pdf

Wirtinger Calculus 在最佳化中有什麼用處?

音訊和其他領域的研究人員更常使用梯度下降來優化具有複數變數的實數值損失函數。 通常,這些人將實數值和虛數值視為可以更新的單獨通道。 對於步長 α/2\alpha/2 和損失 LL,我們可以在 R2ℝ^2 中寫出以下方程式

xn+1=xn(α/2)Lxyn+1=yn(α/2)Ly\begin{aligned} x_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} \\ y_{n+1} &= y_n - (\alpha/2) * \frac{\partial L}{\partial y} \end{aligned}

這些方程式如何轉換到複數空間 C 裡?

zn+1=xn(α/2)Lx+1j(yn(α/2)Ly)=znα1/2(Lx+jLy)=znαLz\begin{aligned} z_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (\alpha/2) * \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * 1/2 * \left(\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}\right) \\ &= z_n - \alpha * \frac{\partial L}{\partial z^*} \end{aligned}

發生了一些非常有趣的事情:Wirtinger 微積分告訴我們,我們可以簡化上面的複數變數更新公式,使其僅參考共軛 Wirtinger 導數 Lz\frac{\partial L}{\partial z^*},這給了我們在最佳化中採取的步驟。

由於共軛 Wirtinger 導數給了我們實數值損失函數的正確步驟,因此當您對具有實數值損失的函數進行微分時,PyTorch 會給您這個導數。

PyTorch 如何計算共軛 Wirtinger 導數?

通常,我們的導數公式會將 grad_output 作為輸入,代表我們已經計算過的傳入向量-雅可比行列式乘積,也就是 Ls\frac{\partial L}{\partial s^*},其中 LL 是整個計算過程的損失(產生一個實數損失),而 ss 是我們函數的輸出。這裡的目標是計算 Lz\frac{\partial L}{\partial z^*},其中 zz 是函數的輸入。事實證明,在實數損失的情況下,我們可以*僅*計算 Ls\frac{\partial L}{\partial s^*} 就可以,即使鏈式法則意味著我們也需要存取 Ls\frac{\partial L}{\partial s}。如果您想跳過這個推導,請查看本節中的最後一個方程式,然後跳到下一節。

讓我們繼續使用定義為 f:CCf: ℂ → ℂf(z)=f(x+yj)=u(x,y)+v(x,y)jf(z) = f(x+yj) = u(x, y) + v(x, y)j。 如上所述,autograd 的梯度慣例以實數值損失函數的優化為中心,因此假設 ff 是較大的實數值損失函數 gg 的一部分。 使用鏈式法則,我們可以寫成

(1)Lz=Luuz+Lvvz\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}

現在使用 Wirtinger 導數的定義,我們可以寫成

Ls=1/2(LuLvj)Ls=1/2(Lu+Lvj)\begin{aligned} \frac{\partial L}{\partial s} = 1/2 * \left(\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j\right) \\ \frac{\partial L}{\partial s^*} = 1/2 * \left(\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j\right) \end{aligned}

這裡應該注意到,由於 uuvv 是實函數,且根據我們假設 ff 是一個實值函數的一部分,LL 是實數,所以我們有:

(2)(Ls)=Ls\left( \frac{\partial L}{\partial s} \right)^* = \frac{\partial L}{\partial s^*}

也就是說,Ls\frac{\partial L}{\partial s} 等於 grad_outputgrad\_output^*

解以上方程式以求得 Lu\frac{\partial L}{\partial u}Lv\frac{\partial L}{\partial v},我們可以得到

(3)Lu=Ls+LsLv=1j(LsLs)\begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ \frac{\partial L}{\partial v} = 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) \end{aligned}

(3) 代入 (1),我們得到

Lz=(Ls+Ls)uz+1j(LsLs)vz=Ls(uz+vzj)+Ls(uzvzj)=Ls(u+vj)z+Ls(u+vj)z=Lssz+Lssz\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} + 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \left(\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j\right) + \frac{\partial L}{\partial s^*} * \left(\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j\right) \\ &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned}

使用 (2),我們得到

(4)Lz=(Ls)sz+Ls(sz)=(grad_output)sz+grad_output(sz)\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s^*}\right)^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \left(\frac{\partial s}{\partial z}\right)^* \\ &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * \left(\frac{\partial s}{\partial z}\right)^* } \\ \end{aligned}

最後這個方程式對於編寫你自己的梯度非常重要,因為它將我們的導數公式分解成一個更簡單的公式,可以很容易地手動計算。

我該如何為複雜的函數編寫自己的導數公式?

以上方框中的方程式給出了複數函數上所有導數的通用公式。但是,我們仍然需要計算 sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*}。 你可以用兩種方法來做到這一點

  • 第一種方法是直接使用 Wirtinger 導數的定義來計算 sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*},透過使用 sx\frac{\partial s}{\partial x}sy\frac{\partial s}{\partial y} (你可以用正常的方式計算)。

  • 第二種方式是使用變數變換技巧,將 f(z)f(z) 重寫為雙變數函數 f(z,z)f(z, z^*),並透過將 zzzz^* 視為獨立變數來計算共軛 Wirtinger 導數。 這通常更容易;例如,如果所討論的函數是全純函數,則只會使用 zz(且 sz\frac{\partial s}{\partial z^*} 將為零)。

讓我們考慮函數 f(z=x+yj)=cz=c(x+yj)f(z = x + yj) = c * z = c * (x+yj) 作為範例,其中 cRc \in ℝ

使用第一種方法計算 Wirtinger 導數,我們得到。

sz=1/2(sxsyj)=1/2(c(c1j)1j)=csz=1/2(sx+syj)=1/2(c+(c1j)1j)=0\begin{aligned} \frac{\partial s}{\partial z} &= 1/2 * \left(\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c - (c * 1j) * 1j) \\ &= c \\ \\ \\ \frac{\partial s}{\partial z^*} &= 1/2 * \left(\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c + (c * 1j) * 1j) \\ &= 0 \\ \end{aligned}

使用 (4)grad_output = 1.0 (這是當在 PyTorch 中對純量輸出調用 backward() 時使用的預設 grad 輸出值),我們得到

Lz=10+1c=c\frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c

使用第二種方法計算 Wirtinger 導數,我們可以立即得到

sz&=(cz)z&=csz&=(cz)z&=0\begin{aligned} \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ &= c \\ \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ &= 0 \end{aligned}

再次使用(4),我們得到 Lz=c\frac{\partial L}{\partial z^*} = c。 如你所見,第二種方法涉及較少的計算,並且對於更快的計算來說更方便。

跨域函數呢?

有些函數會將複數輸入映射到實數輸出,反之亦然。 這些函數構成(4)的一個特例,我們可以利用鏈鎖律推導出來。

  • 對於 f:CRf: ℂ → ℝ, 我們得到

    Lz=2grad_outputsz\frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}
  • 對於 f:RCf: ℝ → ℂ, 我們得到

    Lz=2Re(grad_outputsz)\frac{\partial L}{\partial z^*} = 2 * \mathrm{Re}(grad\_output^* * \frac{\partial s}{\partial z^{*}})

儲存張量的 Hook

您可以透過定義一組 pack_hook / unpack_hook hook 來控制如何封裝/解封裝已儲存的張量pack_hook 函式應該只接受一個張量作為其單一引數,但可以傳回任何 Python 物件 (例如,另一個張量、一個 Tuple,甚至是一個包含檔名的字串)。unpack_hook 函式只接受 pack_hook 的輸出作為其單一引數,並且應該傳回一個將在反向傳播中使用的張量。unpack_hook 傳回的張量只需要與傳遞給 pack_hook 作為輸入的張量具有相同的內容。特別是,任何與 autograd 相關的中繼資料都可以忽略,因為它們將在解封裝期間被覆蓋。

這樣的一對 Hook 的範例是

class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name)

請注意,unpack_hook 不應刪除暫存檔,因為它可能會被多次呼叫:只要傳回的 SelfDeletingTempFile 物件還存在,暫存檔就應該存在。在上面的範例中,我們透過在不再需要時 (在刪除 SelfDeletingTempFile 物件時) 關閉暫存檔來防止洩漏暫存檔。

注意

我們保證 pack_hook 只會被呼叫一次,但 unpack_hook 可以根據反向傳播的需要被多次呼叫,並且我們期望它每次都傳回相同的資料。

警告

禁止對任何函式的輸入執行原地 (inplace) 操作,因為它們可能導致意想不到的副作用。如果對封裝 Hook 的輸入進行了原地修改,PyTorch 將會拋出錯誤,但不會捕捉到對解封裝 Hook 的輸入進行原地修改的情況。

註冊已儲存張量的 Hook

您可以透過在 SavedTensor 物件上呼叫 register_hooks() 方法來註冊一對已儲存張量的 Hook。這些物件會以 grad_fn 的屬性公開,並以 _raw_saved_ 字首開頭。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)

pack_hook 方法會在註冊該對 Hook 後立即被呼叫。unpack_hook 方法會在每次需要存取已儲存張量時被呼叫,可以透過 y.grad_fn._saved_self 或在反向傳播期間進行存取。

警告

如果在已儲存的張量被釋放後 (即在呼叫反向傳播之後) 仍然維持對 SavedTensor 的引用,則禁止呼叫其 register_hooks()。PyTorch 大多數時候會拋出錯誤,但在某些情況下可能無法做到,並且可能會出現未定義的行為。

註冊已儲存張量的預設 Hook

或者,您可以使用 context-manager saved_tensors_hooks 來註冊一對 Hook,它們將被應用於在該 context 中建立的*所有*已儲存的張量。

範例

# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name)

class Model(nn.Module):
    def forward(self, x):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
          # ... compute output
          output = x
        return output

model = Model()
net = nn.DataParallel(model)

使用此 context manager 定義的 Hook 是 thread-local 的。因此,以下程式碼不會產生預期的效果,因為 Hook 不會經過 DataParallel

# Example what NOT to do

net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    output = net(input)

請注意,使用這些 Hook 會停用所有用於減少張量物件建立的最佳化。例如

with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
    x = torch.randn(5, requires_grad=True)
    y = x * x

如果沒有 Hook,xy.grad_fn._saved_selfy.grad_fn._saved_other 都會引用相同的張量物件。有了 Hook,PyTorch 會將 x 封裝和解封裝到兩個與原始 x 共享相同儲存體的新張量物件中 (不執行複製)。

反向 Hook 執行

本節將討論不同 hooks 觸發或不觸發的時機。然後,將討論它們觸發的順序。將涵蓋的 hooks 包括:透過 torch.Tensor.register_hook() 註冊到 Tensor 的 backward hooks、透過 torch.Tensor.register_post_accumulate_grad_hook() 註冊到 Tensor 的 post-accumulate-grad hooks、透過 torch.autograd.graph.Node.register_hook() 註冊到 Node 的 post-hooks,以及透過 torch.autograd.graph.Node.register_prehook() 註冊到 Node 的 pre-hooks。

特定 hook 是否會被觸發

透過 torch.Tensor.register_hook() 註冊到 Tensor 的 hooks 在計算該 Tensor 的梯度時執行。(請注意,這不需要執行 Tensor 的 grad_fn。例如,如果 Tensor 作為 inputs 參數傳遞給 torch.autograd.grad(),則可能不會執行 Tensor 的 grad_fn,但註冊到該 Tensor 的 hook 將始終被執行。)

透過 torch.Tensor.register_post_accumulate_grad_hook() 註冊到 Tensor 的 hooks 在該 Tensor 的梯度被累加之後執行,這意味著 Tensor 的 grad 欄位已被設定。然而,透過 torch.Tensor.register_hook() 註冊的 hooks 在計算梯度時執行,而透過 torch.Tensor.register_post_accumulate_grad_hook() 註冊的 hooks 僅在反向傳播結束時由 autograd 更新 Tensor 的 grad 欄位後觸發。因此,post-accumulate-grad hooks 只能為 leaf Tensors 註冊。即使您呼叫 backward(retain_graph=True),在非 leaf Tensor 上透過 torch.Tensor.register_post_accumulate_grad_hook() 註冊 hook 也會出錯。

使用 torch.autograd.graph.Node.register_hook()torch.autograd.graph.Node.register_prehook() 註冊到 torch.autograd.graph.Node 的 hooks 僅在它所註冊的 Node 被執行時觸發。

特定 Node 是否被執行可能取決於是否使用 torch.autograd.grad()torch.autograd.backward() 呼叫了反向傳播。具體來說,當您在對應於您要作為 inputs 參數傳遞給 torch.autograd.grad()torch.autograd.backward() 的 Tensor 的 Node 上註冊 hook 時,您應該注意這些差異。

如果您正在使用 torch.autograd.backward(),則上述所有 hooks 都將被執行,無論您是否指定了 inputs 參數。這是因為 .backward() 執行所有 Nodes,即使它們對應於指定為輸入的 Tensor。(請注意,執行對應於作為 inputs 傳遞的 Tensors 的這個額外 Node 通常是不必要的,但仍然會執行。此行為可能會更改;您不應依賴它。)

另一方面,如果您正在使用 torch.autograd.grad(),則註冊到對應於傳遞給 input 的 Tensors 的 Nodes 的 backward hooks 可能不會被執行,因為除非有另一個輸入依賴於此 Node 的梯度結果,否則這些 Nodes 將不會被執行。

不同 hooks 觸發的順序

事件發生的順序是

  1. 執行註冊到 Tensor 的 hooks

  2. 執行註冊到 Node 的 pre-hooks(如果 Node 被執行)。

  3. 為 retain_grad 的 Tensors 更新 .grad 欄位

  4. 執行 Node(取決於上述規則)

  5. 對於累積了 .grad 的 leaf Tensors,執行 post-accumulate-grad hooks

  6. 執行註冊到 Node 的 post-hooks(如果 Node 被執行)

如果同一 Tensor 或 Node 上註冊了多個相同類型的 hooks,則它們會按照註冊的順序執行。稍後執行的 hooks 可以觀察到先前 hooks 對梯度所做的修改。

特殊 hooks

torch.autograd.graph.register_multi_grad_hook() 是使用註冊到 Tensors 的 hooks 實現的。每個單獨的 Tensor hook 都遵循上面定義的 Tensor hook 順序觸發,並且當計算出最後一個 Tensor 梯度時,會呼叫已註冊的 multi-grad hook。

torch.nn.modules.module.register_module_full_backward_hook() 是使用註冊到 Node 的 hooks 實作的。在計算 forward 時,hooks 會註冊到與 module 的輸入和輸出對應的 grad_fn。由於一個 module 可能會有多個輸入並傳回多個輸出,因此在 forward 之前,會先將一個虛擬的自定義 autograd Function 應用於 module 的輸入,並在 forward 的輸出傳回之前應用於 module 的輸出,以確保這些 Tensors 共享一個 grad_fn,然後我們可以將我們的 hooks 附加到該 grad_fn 上。

當 Tensor 被就地 (in-place) 修改時,Tensor hooks 的行為

通常,註冊到 Tensor 的 hooks 接收輸出相對於該 Tensor 的梯度,其中 Tensor 的值被認為是計算 backward 時的值。

然而,如果您將 hooks 註冊到一個 Tensor,然後就地修改該 Tensor,則在就地修改之前註冊的 hooks 也會類似地接收輸出相對於該 Tensor 的梯度,但 Tensor 的值被認為是就地修改之前的值。

如果您更喜歡前一種情況的行為,則應在對 Tensor 進行所有就地修改後再將它們註冊到 Tensor。 例如

t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()

此外,了解底層運作方式可能會有所幫助,當 hooks 註冊到 Tensor 時,它們實際上會永久綁定到該 Tensor 的 grad_fn,因此,如果該 Tensor 隨後被就地修改,即使 Tensor 現在有了一個新的 grad_fn,但在修改之前註冊的 hooks 將繼續與舊的 grad_fn 相關聯,例如,當 autograd 引擎在圖表中到達該 Tensor 的舊 grad_fn 時,它們將會觸發。

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得初學者和進階開發人員的深入教學

查看教學

資源

尋找開發資源並獲得您的問題解答

查看資源