捷徑

修補批次標準化

發生了什麼事?

批次標準化需要對與輸入大小相同的 running_mean 和 running_var 進行原地更新。 Functorch 不支援對採用批次張量的規則張量進行原地更新(即不允許 regular.add_(batched))。 因此,當對單個模組的輸入批次進行 vmaping 時,我們會遇到此錯誤

如何修復

所有這些選項都假設您不需要運行統計數據。 如果您正在使用模組,這意味著假設您不會在評估模式下使用批次標準化。 如果您的用例涉及在評估模式下使用 vmap 運行批次標準化,請提交問題

選項 1:變更批次標準化

如果您是自己構建模組的,則可以變更模組以不使用運行統計數據。 換句話說,在任何有批次標準化模組的地方,將 track_running_stats 旗標設置為 False

BatchNorm2d(64, track_running_stats=False)

選項 2:torchvision 參數

某些 torchvision 模型(如 resnet 和 regnet)可以採用 norm_layer 參數。 如果已將它們默認為 BatchNorm2d,則它們通常默認為 BatchNorm2d。 相反,您可以將其設置為不使用運行統計數據的批次標準化

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

選項 3:functorch 的修補

functorch 添加了一些功能,允許快速、原地修補模組。 如果您有一個想要變更的網路,則可以運行 replace_all_batch_norm_modules_ 以原地更新模組以不使用運行統計數據

from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)

文件

存取 PyTorch 的完整開發人員文件

查看文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源