捷徑

torch.nn.utils.weight_norm

torch.nn.utils.weight_norm(module, name='weight', dim=0)[來源][來源]

將權重正規化應用於給定模組中的參數。

w=gvv\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

權重正規化是一種重新參數化的方法,可將權重張量的大小與其方向解耦。這會將由 name (例如 'weight') 指定的參數,替換為兩個參數:一個指定大小 (例如 'weight_g'),另一個指定方向 (例如 'weight_v')。權重正規化是透過一個鉤子 (hook) 實現的,該鉤子會在每次 forward() 呼叫之前,根據大小和方向重新計算權重張量。

預設情況下,使用 dim=0 時,範數 (norm) 是針對每個輸出通道/平面獨立計算的。若要計算整個權重張量的範數,請使用 dim=None

請參閱 https://arxiv.org/abs/1602.07868

警告

此函數已被棄用。請使用 torch.nn.utils.parametrizations.weight_norm(),它使用現代化的參數化 API。新的 weight_norm 與舊的 weight_norm 產生的 state_dict 相容。

遷移指南

參數
  • module (Module) – 包含模組

  • name (str, optional) – 權重參數的名稱

  • dim (int, optional) – 計算範數的維度

傳回

具有權重正規化鉤子的原始模組

傳回類型

T_module

範例

>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_g.size()
torch.Size([40, 1])
>>> m.weight_v.size()
torch.Size([40, 20])

文件

存取 PyTorch 的全面開發人員文件

檢視文件

教學課程

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得問題的解答

檢視資源