捷徑

torch.nn.utils.prune.custom_from_mask

torch.nn.utils.prune.custom_from_mask(module, name, mask)[source][source]

透過在 mask 中套用預先計算的遮罩,修剪 module 中稱為 name 的參數對應的張量。

透過以下方式就地修改模組 (並傳回修改後的模組):

  1. 新增一個名為 name+'_mask' 的命名緩衝區,對應於修剪方法套用至參數 name 的二元遮罩。

  2. 將參數 name 替換為其修剪後的版本,同時原始(未修剪)的參數會儲存在名為 name+'_orig' 的新參數中。

參數
  • module (nn.Module) – 包含要修剪之張量的模組

  • name (str) – module 內將進行修剪的參數名稱。

  • mask (Tensor) – 要應用於參數的二元遮罩。

回傳

輸入模組的修改後(即修剪後)版本

回傳類型

module (nn.Module)

範例

>>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask(
...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學課程

取得針對初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並取得您的問題解答

檢視資源