GroupNorm¶
- class torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)[source][source]¶
對一批輸入套用群組正規化。
此層實作的操作如 群組正規化 論文中所述
輸入通道會被分成
num_groups
個群組,每個群組包含num_channels / num_groups
個通道。num_channels
必須可以被num_groups
整除。平均值和標準差會分別針對每個群組計算。 和 是可學習的每通道仿射轉換參數向量,大小為num_channels
,如果affine
為True
。變異數是透過有偏差的估計量計算的,等同於 torch.var(input, unbiased=False)。此層在訓練和評估模式中都使用從輸入資料計算的統計數據。
- 參數
- 形狀
輸入: 其中
輸出: (與輸入相同的形狀)
範例
>>> input = torch.randn(20, 6, 10, 10) >>> # Separate 6 channels into 3 groups >>> m = nn.GroupNorm(3, 6) >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm) >>> m = nn.GroupNorm(6, 6) >>> # Put all 6 channels into a single group (equivalent with LayerNorm) >>> m = nn.GroupNorm(1, 6) >>> # Activating the module >>> output = m(input)