SyncBatchNorm¶
- class torch.nn.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, process_group=None, device=None, dtype=None)[原始碼][原始碼]¶
對 N 維輸入套用批次正規化。
N-D 輸入是一個 [N-2]D 輸入的 mini-batch (小批量),帶有額外的通道維度,如論文 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 中所述。
平均值和標準差是針對相同處理程序群組的所有 mini-batch,依據每個維度計算。 和 是大小為 C (其中 C 是輸入大小) 的可學習參數向量。 預設情況下, 的元素取樣自 ,而 的元素設定為 0。標準差透過有偏差的估算器計算,相當於 torch.var(input, unbiased=False)。
同樣地,預設情況下,在訓練期間,此層會保留其計算出的平均值和變異數的 running estimates (滾動估計值),然後在評估期間使用這些值進行正規化。 滾動估計值會以預設的
momentum
(動量) 0.1 維持。如果
track_running_stats
設定為False
,則此層不會保留 running estimates (滾動估計值),而是在評估期間也改為使用批次統計資訊。注意
此
momentum
參數與最佳化器類別中使用的動量參數以及傳統的動量概念不同。 在數學上,此處 running statistics 的更新規則為 ,其中 是估計的統計量,而 是新的觀測值。由於 Batch Normalization 是針對
C
維度中的每個通道完成的,因此在(N, +)
切片上計算統計量,通常將其稱為 Volumetric Batch Normalization 或 Spatio-temporal Batch Normalization。目前
SyncBatchNorm
僅支援每個進程單一 GPU 的DistributedDataParallel
(DDP)。 在使用 DDP 包裝網路之前,請使用torch.nn.SyncBatchNorm.convert_sync_batchnorm()
將BatchNorm*D
層轉換為SyncBatchNorm
。- 參數
num_features (int) – 來自預期大小為 的輸入中的
eps (float) – 為數值穩定性而添加到分母的值。預設值:
1e-5
momentum (Optional[float]) – 用於 running_mean 和 running_var 計算的值。 可以設定為
None
以進行累積移動平均 (即簡單平均)。 預設值:0.1affine (bool) – 一個布林值,當設定為
True
時,此模組具有可學習的 affine 參數。 預設值:True
track_running_stats ( bool ) – 一個布林值,當設定為
True
時,此模組會追蹤 running mean 和 variance;當設定為False
時,此模組不會追蹤這些統計數據,並且會將統計緩衝區running_mean
和running_var
初始化為None
。 當這些緩衝區為None
時,此模組始終使用批次統計數據,無論是在訓練模式還是評估模式。 預設值:True
process_group ( Optional[Any] ) – 統計數據的同步發生在每個 process group 內。 預設行為是在整個 world 範圍內同步
- 形狀
輸入:
輸出: (與輸入相同的形狀)
注意
batchnorm 統計數據的同步僅在訓練時發生,也就是說,當設定
model.eval()
或如果self.training
否則為False
時,同步會被停用。範例
>>> # With Learnable Parameters >>> m = nn.SyncBatchNorm(100) >>> # creating process group (optional) >>> # ranks is a list of int identifying rank ids. >>> ranks = list(range(8)) >>> r1, r2 = ranks[:4], ranks[4:] >>> # Note: every rank calls into new_group for every >>> # process group created, even if that rank is not >>> # part of the group. >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> # Without Learnable Parameters >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group) >>> input = torch.randn(20, 100, 35, 45, 10) >>> output = m(input) >>> # network is nn.BatchNorm layer >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group) >>> # only single gpu per process is currently supported >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel( >>> sync_bn_network, >>> device_ids=[args.local_rank], >>> output_device=args.local_rank)
- classmethod convert_sync_batchnorm(module, process_group=None)[source][source]¶
將模型中的所有
BatchNorm*D
層轉換為torch.nn.SyncBatchNorm
層。- 參數
module ( nn.Module ) – 包含一個或多個
BatchNorm*D
層的模組process_group (optional) – 用於限定同步範圍的 process group,預設為整個 world
- 返回
原始
module
,其中包含轉換後的torch.nn.SyncBatchNorm
層。 如果原始module
是一個BatchNorm*D
層,則會返回一個新的torch.nn.SyncBatchNorm
層物件。
範例
>>> # Network with nn.BatchNorm layer >>> module = torch.nn.Sequential( >>> torch.nn.Linear(20, 100), >>> torch.nn.BatchNorm1d(100), >>> ).cuda() >>> # creating process group (optional) >>> # ranks is a list of int identifying rank ids. >>> ranks = list(range(8)) >>> r1, r2 = ranks[:4], ranks[4:] >>> # Note: every rank calls into new_group for every >>> # process group created, even if that rank is not >>> # part of the group. >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)