修補批次標準化¶
發生了什麼事?¶
批次標準化需要對與輸入大小相同的 running_mean 和 running_var 進行原地更新。 Functorch 不支援對採用批次張量的規則張量進行原地更新(即不允許 regular.add_(batched)
)。 因此,當對單個模組的輸入批次進行 vmaping 時,我們會遇到此錯誤
如何修復¶
所有這些選項都假設您不需要運行統計數據。 如果您正在使用模組,這意味著假設您不會在評估模式下使用批次標準化。 如果您的用例涉及在評估模式下使用 vmap 運行批次標準化,請提交問題
選項 1:變更批次標準化¶
如果您是自己構建模組的,則可以變更模組以不使用運行統計數據。 換句話說,在任何有批次標準化模組的地方,將 track_running_stats
旗標設置為 False
BatchNorm2d(64, track_running_stats=False)
選項 2:torchvision 參數¶
某些 torchvision 模型(如 resnet 和 regnet)可以採用 norm_layer
參數。 如果已將它們默認為 BatchNorm2d,則它們通常默認為 BatchNorm2d。 相反,您可以將其設置為不使用運行統計數據的批次標準化
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
選項 3:functorch 的修補¶
functorch 添加了一些功能,允許快速、原地修補模組。 如果您有一個想要變更的網路,則可以運行 replace_all_batch_norm_modules_
以原地更新模組以不使用運行統計數據
from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)