捷徑

GroupNorm

class torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)[source][source]

對一批輸入套用群組正規化。

此層實作的操作如 群組正規化 論文中所述

y=xE[x]Var[x]+ϵγ+βy = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

輸入通道會被分成 num_groups 個群組,每個群組包含 num_channels / num_groups 個通道。num_channels 必須可以被 num_groups 整除。平均值和標準差會分別針對每個群組計算。 γ\gammaβ\beta 是可學習的每通道仿射轉換參數向量,大小為 num_channels,如果 affineTrue。變異數是透過有偏差的估計量計算的,等同於 torch.var(input, unbiased=False)

此層在訓練和評估模式中都使用從輸入資料計算的統計數據。

參數
  • num_groups (int) – 將通道分成的群組數量

  • num_channels (int) – 預期輸入中的通道數量

  • eps (float) – 一個加到分母以實現數值穩定性的值。預設值:1e-5

  • affine (bool) – 一個布林值,當設定為 True 時,此模組具有可學習的每通道仿射參數,初始化為 1(對於權重)和 0(對於偏差)。預設值:True

形狀
  • 輸入: (N,C,)(N, C, *) 其中 C=num_channelsC=\text{num\_channels}

  • 輸出: (N,C,)(N, C, *) (與輸入相同的形狀)

範例

>>> input = torch.randn(20, 6, 10, 10)
>>> # Separate 6 channels into 3 groups
>>> m = nn.GroupNorm(3, 6)
>>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
>>> m = nn.GroupNorm(6, 6)
>>> # Put all 6 channels into a single group (equivalent with LayerNorm)
>>> m = nn.GroupNorm(1, 6)
>>> # Activating the module
>>> output = m(input)

文件

存取 PyTorch 完整的開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深度教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源