捷徑

torch.nn.utils.prune.random_unstructured

torch.nn.utils.prune.random_unstructured(module, name, amount)[原始碼][原始碼]

透過移除隨機(目前未修剪的)單元來修剪張量。

透過移除指定數量的隨機選擇的(目前未修剪的)單元,修剪 module 中稱為 name 的參數對應的張量。透過以下方式就地修改模組(並返回修改後的模組):

  1. 新增一個名為 name+'_mask' 的命名緩衝區,對應於修剪方法應用於參數 name 的二進位遮罩。

  2. 將參數 name 替換為其修剪後的版本,同時將原始(未修剪)的參數儲存在一個名為 name+'_orig' 的新參數中。

參數
  • module (nn.Module) – 包含要修剪之張量的模組

  • name (str) – 將在其上進行修剪的 module 內的參數名稱。

  • amount (intfloat) – 要修剪的參數數量。 如果是 float,應該介於 0.0 和 1.0 之間,並表示要修剪的參數的比例。 如果是 int,則表示要修剪的參數的絕對數量。

回傳

輸入模組的修改(即修剪)版本

回傳類型

module (nn.Module)

範例

>>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
>>> torch.sum(m.weight_mask == 0)
tensor(1)

文件

訪問 PyTorch 的綜合開發者文檔

查看文檔

教學

獲取針對初學者和高級開發者的深度教學

查看教學

資源

尋找開發資源並獲得您的問題的解答

查看資源