torch.autograd.gradcheck.gradcheck¶
- torch.autograd.gradcheck.gradcheck(func, inputs, *, eps=1e-06, atol=1e-05, rtol=0.001, raise_exception=True, nondet_tol=0.0, check_undefined_grad=True, check_grad_dtypes=False, check_batched_grad=False, check_batched_forward_grad=False, check_forward_ad=False, check_backward_ad=True, fast_mode=False, masked=None)[原始碼][原始碼]¶
針對
inputs
中浮點數或複數型別且requires_grad=True
的 tensors,使用小幅有限差分法 (small finite differences) 驗證計算出的梯度與解析梯度是否一致。數值梯度和解析梯度之間的檢查使用
allclose()
。對於我們出於優化目的而考慮的大多數複數函數,不存在 Jacobian 的概念。相反,gradcheck 驗證 Wirtinger 導數和共軛 Wirtinger 導數的數值和解析值是否一致。由於梯度計算是在整體函數具有實數值輸出的假設下完成的,因此我們以特殊的方式處理具有複數輸出的函數。對於這些函數,gradcheck 應用於兩個實數值函數,第一個對應於取複數輸出的實數分量,第二個對應於取複數輸出的虛數分量。有關更多詳細資訊,請查看 複數的 Autograd。
注意
預設值是為雙精度
input
設計的。如果input
的精度較低,例如FloatTensor
,則此檢查很可能會失敗。注意
在不可微分點上評估時,Gradcheck 可能會失敗,因為通過有限差分法數值計算的梯度可能與解析計算的梯度不同(不一定是因為其中任何一個不正確)。有關更多上下文,請參閱 不可微分函數的梯度。
警告
如果
input
中任何經過檢查的 tensor 具有重疊的記憶體,即指向相同記憶體位址的不同索引(例如,來自torch.Tensor.expand()
),則此檢查很可能會失敗,因為通過在此類索引處的點擾動 (point perturbation) 計算的數值梯度將更改共享相同記憶體位址的所有其他索引的值。- 參數
func (function) – 接受 Tensor 輸入並返回 Tensor 或 Tensors 元組的 Python 函數
eps (float, optional) – 用於有限差分法的擾動值
atol (float, optional) – 絕對容差
rtol (float, optional) – 相對容差
raise_exception (bool, optional) – 指示檢查失敗時是否引發異常。 該異常提供有關失敗的確切性質的更多資訊。 這在調試 gradchecks 時很有用。
nondet_tol (float, optional) – 非確定性的容差。 通過微分運行相同的輸入時,結果必須完全匹配(預設值,0.0)或在此容差範圍內。
check_undefined_grad (bool, optional) – 如果
True
,則檢查是否支援未定義的輸出梯度並將其視為零,對於Tensor
輸出。check_batched_grad (bool, optional) – 如果
True
,檢查我們是否可以使用 prototype vmap 支援計算批次梯度。 預設為 False。check_batched_forward_grad (bool, optional) – 如果
True
,檢查我們是否可以使用 forward ad 和 prototype vmap 支援計算批次前向梯度。 預設為False
。check_forward_ad (bool, optional) – 若為
True
,則檢查使用前向模式 AD 計算的梯度是否與數值梯度匹配。預設值為False
。check_backward_ad (bool, optional) – 若為
False
,則不執行任何依賴反向模式 AD 實現的檢查。預設值為True
。fast_mode (bool, optional) – gradcheck 和 gradgradcheck 的快速模式目前僅適用於 R 到 R 的函數。如果沒有任何輸入和輸出是複數,則執行一個更快的 gradcheck 實現,不再計算整個雅可比矩陣;否則,我們將退回到較慢的實現。
masked (bool, optional) – 若為
True
,則忽略稀疏張量未指定元素的梯度。預設值為False
。
- 回傳
如果所有差異都滿足 allclose 條件,則回傳
True
- 回傳類型