RMSNorm¶
- class torch.nn.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[source][source]¶
對輸入的小批量資料應用均方根層歸一化 (Root Mean Square Layer Normalization)。
此層實作了論文 Root Mean Square Layer Normalization 中描述的運算。
RMS 是在最後
D
個維度上計算的,其中D
是normalized_shape
的維度。 例如,如果normalized_shape
是(3, 5)
(一個二維形狀),則 RMS 是在輸入的最後 2 個維度上計算的。- 參數
- Shape (形狀)
Input (輸入):
Output (輸出): (與輸入相同的形狀)
Examples (範例)
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)