torch.nn.utils.parametrizations.weight_norm¶
- torch.nn.utils.parametrizations.weight_norm(module, name='weight', dim=0)[原始碼][原始碼]¶
將權重正規化應用於指定模組中的參數。
權重正規化是一種重新參數化的方法,它將權重張量的大小與其方向解耦。這會將由
name
指定的參數替換為兩個參數:一個指定大小,另一個指定方向。預設情況下,使用
dim=0
,範數是針對每個輸出通道/平面獨立計算的。要計算整個權重張量的範數,請使用dim=None
。請參閱 https://arxiv.org/abs/1602.07868
- 參數
- 回傳
具有權重正規化掛鉤的原始模組
範例
>>> m = weight_norm(nn.Linear(20, 40), name='weight') >>> m ParametrizedLinear( in_features=20, out_features=40, bias=True (parametrizations): ModuleDict( (weight): ParametrizationList( (0): _WeightNorm() ) ) ) >>> m.parametrizations.weight.original0.size() torch.Size([40, 1]) >>> m.parametrizations.weight.original1.size() torch.Size([40, 20])