修補批次正規化¶
發生什麼事?¶
批次正規化需要對與輸入大小相同的 running_mean 和 running_var 進行原地更新。Functorch 不支援對接收批次張量的常規張量進行原地更新 (即不允許 regular.add_(batched)
)。因此,當在單一模組的批次輸入上進行 vmapping 時,最終會出現此錯誤
如何修正¶
最佳支援方式之一是將 BatchNorm 切換為 GroupNorm。選項 1 和選項 2 支援此功能
所有這些選項都假設您不需要執行統計數據。如果您使用的是模組,這表示假設您不會在評估模式中使用批次正規化。如果您有需要在評估模式中使用 vmap 執行批次正規化的使用案例,請提交問題
選項 1:變更 BatchNorm¶
如果您想要變更為 GroupNorm,請在任何有 BatchNorm 的地方,將其替換為
BatchNorm2d(C, G, track_running_stats=False)
此處 C
與原始 BatchNorm 中的 C
相同。G
是將 C
分組的組數。因此,C % G == 0
,作為後備方案,您可以設定 C == G
,表示每個通道將被個別處理。
如果您必須使用 BatchNorm 且您已自行建置模組,您可以變更模組以不使用執行統計數據。換句話說,在任何有 BatchNorm 模組的地方,將 track_running_stats
旗標設定為 False
BatchNorm2d(64, track_running_stats=False)
選項 2:torchvision 參數¶
某些 torchvision 模型 (例如 resnet 和 regnet) 可以接收 norm_layer
參數。如果已預設,這些參數通常會預設為 BatchNorm2d。
您可以改為將其設定為 GroupNorm。
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))
此處再次強調,c % g == 0
,因此作為後備方案,請設定 g = c
。
如果您堅持使用 BatchNorm,請務必使用不使用執行統計數據的版本
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
選項 3:functorch 的修補¶
functorch 新增了一些功能,允許快速、原地修補模組,使其不使用執行統計數據。變更正規化層更為脆弱,因此我們沒有提供該功能。如果您有一個網路,您希望 BatchNorm 不使用執行統計數據,您可以執行 replace_all_batch_norm_modules_
以原地更新模組,使其不使用執行統計數據
from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
選項 4:評估模式¶
在評估模式下執行時,running_mean 和 running_var 將不會更新。因此,vmap 可以支援此模式
model.eval()
vmap(model)(x)
model.train()