捷徑

修補批次正規化

發生什麼事?

批次正規化需要對與輸入大小相同的 running_mean 和 running_var 進行原地更新。Functorch 不支援對接收批次張量的常規張量進行原地更新 (即不允許 regular.add_(batched))。因此,當在單一模組的批次輸入上進行 vmapping 時,最終會出現此錯誤

如何修正

最佳支援方式之一是將 BatchNorm 切換為 GroupNorm。選項 1 和選項 2 支援此功能

所有這些選項都假設您不需要執行統計數據。如果您使用的是模組,這表示假設您不會在評估模式中使用批次正規化。如果您有需要在評估模式中使用 vmap 執行批次正規化的使用案例,請提交問題

選項 1:變更 BatchNorm

如果您想要變更為 GroupNorm,請在任何有 BatchNorm 的地方,將其替換為

BatchNorm2d(C, G, track_running_stats=False)

此處 C 與原始 BatchNorm 中的 C 相同。G 是將 C 分組的組數。因此,C % G == 0,作為後備方案,您可以設定 C == G,表示每個通道將被個別處理。

如果您必須使用 BatchNorm 且您已自行建置模組,您可以變更模組以不使用執行統計數據。換句話說,在任何有 BatchNorm 模組的地方,將 track_running_stats 旗標設定為 False

BatchNorm2d(64, track_running_stats=False)

選項 2:torchvision 參數

某些 torchvision 模型 (例如 resnet 和 regnet) 可以接收 norm_layer 參數。如果已預設,這些參數通常會預設為 BatchNorm2d。

您可以改為將其設定為 GroupNorm。

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))

此處再次強調,c % g == 0,因此作為後備方案,請設定 g = c

如果您堅持使用 BatchNorm,請務必使用不使用執行統計數據的版本

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

選項 3:functorch 的修補

functorch 新增了一些功能,允許快速、原地修補模組,使其不使用執行統計數據。變更正規化層更為脆弱,因此我們沒有提供該功能。如果您有一個網路,您希望 BatchNorm 不使用執行統計數據,您可以執行 replace_all_batch_norm_modules_ 以原地更新模組,使其不使用執行統計數據

from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)

選項 4:評估模式

在評估模式下執行時,running_mean 和 running_var 將不會更新。因此,vmap 可以支援此模式

model.eval()
vmap(model)(x)
model.train()

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源