略過模組參數初始化¶
建立於:2021 年 6 月 17 日 | 最後更新:2021 年 6 月 17 日 | 最後驗證:未驗證
簡介¶
當建立模組時,其可學習的參數會根據與模組類型關聯的預設初始化方案進行初始化。例如,torch.nn.Linear
模組的 weight 參數是從 uniform(-1/sqrt(in_features), 1/sqrt(in_features)) 分佈初始化的。如果需要其他初始化方案,傳統上需要在模組實例化後重新初始化參數
from torch import nn
# Initializes weight from the default distribution: uniform(-1/sqrt(10), 1/sqrt(10)).
m = nn.Linear(10, 5)
# Re-initialize weight from a different distribution.
nn.init.orthogonal_(m.weight)
在這種情況下,建構期間完成的初始化是浪費的計算,如果 weight 參數很大,則可能並非微不足道。
略過初始化¶
現在可以略過模組建構期間的參數初始化,避免浪費的計算。 這可以使用 torch.nn.utils.skip_init()
函數輕鬆完成
from torch import nn
from torch.nn.utils import skip_init
m = skip_init(nn.Linear, 10, 5)
# Example: Do custom, non-default parameter initialization.
nn.init.orthogonal_(m.weight)
這可以應用於滿足下面更新模組以支援略過初始化章節中描述的條件的任何模組。 請注意,torch.nn 提供的所有模組都滿足這些條件,因此支援略過初始化。
更新模組以支援略過初始化¶
由於 torch.nn.utils.skip_init()
的實作方式 (請參閱 實作細節),模組必須滿足兩個要求才能與該函數相容。 您只需遵守這些要求即可為您的自訂模組選擇加入參數初始化略過功能
1. 模組必須在其建構子中接受一個 device kwarg,該 kwarg 會傳遞給建構期間建立的任何參數或緩衝區。
2. 模組不得在其建構子中對參數或緩衝區執行任何計算,除了初始化 (即來自 torch.nn.init 的函數)。
以下範例示範了透過將 device kwarg 傳遞給任何建立的參數、緩衝區或子模組來更新以支援 device kwarg 的模組
import torch
from torch import nn
class MyModule(torch.nn.Module):
def __init__(self, foo, bar, device=None):
super().__init__()
# ==== Case 1: Module creates parameters directly. ====
# Pass device along to any created parameters.
self.param1 = nn.Parameter(torch.empty((foo, bar), device=device))
self.register_parameter('param2', nn.Parameter(torch.empty(bar, device=device)))
# To ensure support for the meta device, avoid using ops except those in
# torch.nn.init on parameters in your module's constructor.
with torch.no_grad():
nn.init.kaiming_uniform_(self.param1)
nn.init.uniform_(self.param2)
# ==== Case 2: Module creates submodules. ====
# Pass device along recursively. All submodules will need to support
# them as well; this is the case for all torch.nn provided modules.
self.fc = nn.Linear(bar, 5, device=device)
# This also works with containers.
self.linears = nn.Sequential(
nn.Linear(5, 5, device=device),
nn.Linear(5, 1, device=device)
)
# ==== Case 3: Module creates buffers. ====
# Pass device along during buffer tensor creation.
self.register_buffer('some_buffer', torch.ones(7, device=device))
...
實作細節¶
在幕後,torch.nn.utils.skip_init()
函數是根據兩步驟模式實作的
# 1. Initialize module on the meta device; all torch.nn.init ops have
# no-op behavior on the meta device.
m = nn.Linear(10, 5, device='meta')
# 2. Materialize an uninitialized (empty) form of the module on the CPU device.
# The result of this is a module instance with uninitialized parameters.
m.to_empty(device='cpu')
它的運作方式是將模組實例化到「meta」裝置上,該裝置具有張量形狀資訊,但不分配任何儲存空間。torch.nn.init 運算針對此 meta 裝置進行特殊實作,使其具有無運算行為。 這導致參數初始化邏輯基本上被略過。
請注意,此模式僅適用於在建構期間正確支援 device kwarg 的模組,如 更新模組以支援略過初始化中所述。