torch.nn.utils.weight_norm¶
- torch.nn.utils.weight_norm(module, name='weight', dim=0)[來源][來源]¶
將權重正規化應用於給定模組中的參數。
權重正規化是一種重新參數化的方法,可將權重張量的大小與其方向解耦。這會將由
name
(例如'weight'
) 指定的參數,替換為兩個參數:一個指定大小 (例如'weight_g'
),另一個指定方向 (例如'weight_v'
)。權重正規化是透過一個鉤子 (hook) 實現的,該鉤子會在每次forward()
呼叫之前,根據大小和方向重新計算權重張量。預設情況下,使用
dim=0
時,範數 (norm) 是針對每個輸出通道/平面獨立計算的。若要計算整個權重張量的範數,請使用dim=None
。請參閱 https://arxiv.org/abs/1602.07868
警告
此函數已被棄用。請使用
torch.nn.utils.parametrizations.weight_norm()
,它使用現代化的參數化 API。新的weight_norm
與舊的weight_norm
產生的state_dict
相容。遷移指南
大小 (
weight_g
) 和方向 (weight_v
) 現在分別表示為parametrizations.weight.original0
和parametrizations.weight.original1
。 如果這讓您感到困擾,請在 https://github.com/pytorch/pytorch/issues/102999 上發表評論。若要移除權重正規化重新參數化,請使用
torch.nn.utils.parametrize.remove_parametrizations()
。權重不再於模組正向傳播時重新計算一次; 相反,它會在每次存取時重新計算。 若要恢復舊的行為,請在使用相關模組之前,使用
torch.nn.utils.parametrize.cached()
。
- 參數
- 傳回
具有權重正規化鉤子的原始模組
- 傳回類型
T_module
範例
>>> m = weight_norm(nn.Linear(20, 40), name='weight') >>> m Linear(in_features=20, out_features=40, bias=True) >>> m.weight_g.size() torch.Size([40, 1]) >>> m.weight_v.size() torch.Size([40, 20])