torch.nn.utils.skip_init¶
- torch.nn.utils.skip_init(module_cls, *args, **kwargs)[source][source]¶
給定一個模組類別物件和 args / kwargs,實例化該模組而不初始化參數/緩衝區。
如果初始化過程較慢,或者需要執行自定義初始化,而不需要預設初始化,這會很有用。但由於此函數的實作方式,有一些需要注意的地方。
1. 模組的建構子必須接受一個 device 參數,並將其傳遞給建構過程中建立的任何參數或緩衝區。
2. 除了初始化(即來自
torch.nn.init
的函數)之外,模組的建構子不得對參數執行任何計算。如果滿足這些條件,則可以使用未初始化的參數/緩衝區值來實例化模組,就像使用
torch.empty()
建立的一樣。- 參數
module_cls – 類別物件;應該是
torch.nn.Module
的子類別args – 傳遞給模組建構子的 args
kwargs – 傳遞給模組建構子的 kwargs
- 返回
具有未初始化參數/緩衝區的實例化模組
範例
>>> import torch >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1) >>> m.weight Parameter containing: tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]], requires_grad=True) >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1) >>> m2.weight Parameter containing: tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24, 4.5915e-41]], requires_grad=True)