快捷方式

torch.nn.utils.parametrizations.weight_norm

torch.nn.utils.parametrizations.weight_norm(module, name='weight', dim=0)[原始碼][原始碼]

將權重正規化應用於指定模組中的參數。

w=gvv\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

權重正規化是一種重新參數化的方法,它將權重張量的大小與其方向解耦。這會將由 name 指定的參數替換為兩個參數:一個指定大小,另一個指定方向。

預設情況下,使用 dim=0,範數是針對每個輸出通道/平面獨立計算的。要計算整個權重張量的範數,請使用 dim=None

請參閱 https://arxiv.org/abs/1602.07868

參數
  • module (Module) – 包含模組

  • name (str, optional) – 權重參數的名稱

  • dim (int, optional) – 計算範數的維度

回傳

具有權重正規化掛鉤的原始模組

範例

>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _WeightNorm()
    )
  )
)
>>> m.parametrizations.weight.original0.size()
torch.Size([40, 1])
>>> m.parametrizations.weight.original1.size()
torch.Size([40, 20])

文件

存取 PyTorch 的完整開發人員文件

查看文件

教學課程

取得初學者和高級開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源