torch.nn.utils.prune.identity¶
- torch.nn.utils.prune.identity(module, name)[source][source]¶
套用修剪重新參數化,而不修剪任何單元。
將修剪重新參數化應用於與
module
中名為name
的參數相對應的張量,而實際上不修剪任何單元。透過以下方式就地修改模組(並傳回修改後的模組):新增一個名為
name+'_mask'
的具名緩衝區,對應於修剪方法應用於參數name
的二元遮罩。將參數
name
替換為其修剪後的版本,同時將原始(未修剪)參數儲存在一個名為name+'_orig'
的新參數中。
注意
遮罩是一個全為 1 的張量。
- 參數
- 回傳
輸入模組的修改(即修剪後)版本
- 回傳類型
module (nn.Module)
範例
>>> m = prune.identity(nn.Linear(2, 3), 'bias') >>> print(m.bias_mask) tensor([1., 1., 1.])