torch.nn.utils.prune.custom_from_mask¶
- torch.nn.utils.prune.custom_from_mask(module, name, mask)[source][source]¶
透過在
mask
中套用預先計算的遮罩,修剪module
中稱為name
的參數對應的張量。透過以下方式就地修改模組 (並傳回修改後的模組):
新增一個名為
name+'_mask'
的命名緩衝區,對應於修剪方法套用至參數name
的二元遮罩。將參數
name
替換為其修剪後的版本,同時原始(未修剪)的參數會儲存在名為name+'_orig'
的新參數中。
- 參數
- 回傳
輸入模組的修改後(即修剪後)版本
- 回傳類型
module (nn.Module)
範例
>>> from torch.nn.utils import prune >>> m = prune.custom_from_mask( ... nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0]) ... ) >>> print(m.bias_mask) tensor([0., 1., 0.])