捷徑

SyncBatchNorm

class torch.nn.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, process_group=None, device=None, dtype=None)[原始碼][原始碼]

對 N 維輸入套用批次正規化。

N-D 輸入是一個 [N-2]D 輸入的 mini-batch (小批量),帶有額外的通道維度,如論文 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 中所述。

y=xE[x]Var[x]+ϵγ+β

平均值和標準差是針對相同處理程序群組的所有 mini-batch,依據每個維度計算。γ\gammaβ\beta 是大小為 C (其中 C 是輸入大小) 的可學習參數向量。 預設情況下,γ\gamma 的元素取樣自 U(0,1)\mathcal{U}(0, 1),而 β\beta 的元素設定為 0。標準差透過有偏差的估算器計算,相當於 torch.var(input, unbiased=False)

同樣地,預設情況下,在訓練期間,此層會保留其計算出的平均值和變異數的 running estimates (滾動估計值),然後在評估期間使用這些值進行正規化。 滾動估計值會以預設的 momentum (動量) 0.1 維持。

如果 track_running_stats 設定為 False,則此層不會保留 running estimates (滾動估計值),而是在評估期間也改為使用批次統計資訊。

注意

momentum 參數與最佳化器類別中使用的動量參數以及傳統的動量概念不同。 在數學上,此處 running statistics 的更新規則為 x^new=(1momentum)×x^+momentum×xt\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t,其中 x^\hat{x} 是估計的統計量,而 xtx_t 是新的觀測值。

由於 Batch Normalization 是針對 C 維度中的每個通道完成的,因此在 (N, +) 切片上計算統計量,通常將其稱為 Volumetric Batch Normalization 或 Spatio-temporal Batch Normalization。

目前 SyncBatchNorm 僅支援每個進程單一 GPU 的 DistributedDataParallel (DDP)。 在使用 DDP 包裝網路之前,請使用 torch.nn.SyncBatchNorm.convert_sync_batchnorm()BatchNorm*D 層轉換為 SyncBatchNorm

參數
  • num_features (int) – 來自預期大小為 (N,C,+)(N, C, +) 的輸入中的 CC

  • eps (float) – 為數值穩定性而添加到分母的值。預設值: 1e-5

  • momentum (Optional[float]) – 用於 running_mean 和 running_var 計算的值。 可以設定為 None 以進行累積移動平均 (即簡單平均)。 預設值:0.1

  • affine (bool) – 一個布林值,當設定為 True 時,此模組具有可學習的 affine 參數。 預設值: True

  • track_running_stats ( bool ) – 一個布林值,當設定為 True 時,此模組會追蹤 running mean 和 variance;當設定為 False 時,此模組不會追蹤這些統計數據,並且會將統計緩衝區 running_meanrunning_var 初始化為 None。 當這些緩衝區為 None 時,此模組始終使用批次統計數據,無論是在訓練模式還是評估模式。 預設值:True

  • process_group ( Optional[Any] ) – 統計數據的同步發生在每個 process group 內。 預設行為是在整個 world 範圍內同步

形狀
  • 輸入: (N,C,+)(N, C, +)

  • 輸出: (N,C,+)(N, C, +) (與輸入相同的形狀)

注意

batchnorm 統計數據的同步僅在訓練時發生,也就是說,當設定 model.eval() 或如果 self.training 否則為 False 時,同步會被停用。

範例

>>> # With Learnable Parameters
>>> m = nn.SyncBatchNorm(100)
>>> # creating process group (optional)
>>> # ranks is a list of int identifying rank ids.
>>> ranks = list(range(8))
>>> r1, r2 = ranks[:4], ranks[4:]
>>> # Note: every rank calls into new_group for every
>>> # process group created, even if that rank is not
>>> # part of the group.
>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> output = m(input)

>>> # network is nn.BatchNorm layer
>>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
>>> # only single gpu per process is currently supported
>>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
>>>                         sync_bn_network,
>>>                         device_ids=[args.local_rank],
>>>                         output_device=args.local_rank)
classmethod convert_sync_batchnorm(module, process_group=None)[source][source]

將模型中的所有 BatchNorm*D 層轉換為 torch.nn.SyncBatchNorm 層。

參數
  • module ( nn.Module ) – 包含一個或多個 BatchNorm*D 層的模組

  • process_group (optional) – 用於限定同步範圍的 process group,預設為整個 world

返回

原始 module,其中包含轉換後的 torch.nn.SyncBatchNorm 層。 如果原始 module 是一個 BatchNorm*D 層,則會返回一個新的 torch.nn.SyncBatchNorm 層物件。

範例

>>> # Network with nn.BatchNorm layer
>>> module = torch.nn.Sequential(
>>>            torch.nn.Linear(20, 100),
>>>            torch.nn.BatchNorm1d(100),
>>>          ).cuda()
>>> # creating process group (optional)
>>> # ranks is a list of int identifying rank ids.
>>> ranks = list(range(8))
>>> r1, r2 = ranks[:4], ranks[4:]
>>> # Note: every rank calls into new_group for every
>>> # process group created, even if that rank is not
>>> # part of the group.
>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

文件

取得 PyTorch 的全面開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源