torch.nn.utils.prune.random_unstructured¶
- torch.nn.utils.prune.random_unstructured(module, name, amount)[原始碼][原始碼]¶
透過移除隨機(目前未修剪的)單元來修剪張量。
透過移除指定數量的隨機選擇的(目前未修剪的)單元,修剪
module
中稱為name
的參數對應的張量。透過以下方式就地修改模組(並返回修改後的模組):新增一個名為
name+'_mask'
的命名緩衝區,對應於修剪方法應用於參數name
的二進位遮罩。將參數
name
替換為其修剪後的版本,同時將原始(未修剪)的參數儲存在一個名為name+'_orig'
的新參數中。
- 參數
- 回傳
輸入模組的修改(即修剪)版本
- 回傳類型
module (nn.Module)
範例
>>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1) >>> torch.sum(m.weight_mask == 0) tensor(1)