torch.func¶
torch.func,先前稱為「functorch」,是類似 JAX 的 PyTorch 可組合函式轉換。
注意
這個函式庫目前處於 beta 階段。 這表示功能通常可以運作(除非另有說明),並且我們(PyTorch 團隊)致力於推進這個函式庫。 但是,API 可能會根據使用者回饋進行變更,並且我們沒有涵蓋 PyTorch 的所有操作。
如果您對 API 有任何建議,或想要涵蓋的使用案例,請開啟 GitHub issue 或與我們聯繫。 我們很樂意了解您如何使用這個函式庫。
什麼是可組合函式轉換?¶
「函式轉換」是一個高階函式,它接受一個數值函式並傳回一個新的函式,該函式計算不同的量。
torch.func
具有自動微分轉換(grad(f)
傳回一個計算f
梯度的函式)、向量化/批次處理轉換(vmap(f)
傳回一個計算f
在輸入批次上的函式)等等。這些函數轉換可以任意組合。例如,組合
vmap(grad(f))
會計算一個稱為「每個樣本梯度 (per-sample-gradients)」的量,這是目前標準 PyTorch 無法有效率計算的。