捷徑

torch.nn.utils.prune.global_unstructured

torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)[原始碼][原始碼]

透過應用指定的 pruning_method,全域地修剪與 parameters 中的所有參數相對應的張量。

透過以下方式就地修改模組

  1. 新增一個名為 name+'_mask' 的具名緩衝區,對應於由修剪方法應用於參數 name 的二元遮罩。

  2. 將參數 name 替換為其修剪後的版本,同時原始(未修剪)的參數儲存在一個名為 name+'_orig' 的新參數中。

參數
  • parameters (Iterable of (module, name) tuples) – 要以全域方式修剪的模型參數,也就是在決定要修剪哪些權重之前,先匯總所有權重。module 必須是 nn.Module 類型,name 必須是字串。

  • pruning_method (function) – 此模組中有效的修剪函式,或是使用者實作的自訂函式,該函式符合實作指南且具有 PRUNING_TYPE='unstructured'

  • importance_scores (dict) – 一個字典,將 (module, name) 元組映射到對應參數的重要性分數張量。該張量應與參數的形狀相同,並用於計算修剪的遮罩。如果未指定或為 None,則將使用參數代替其重要性分數。

  • kwargs – 其他關鍵字參數,例如:amount(int 或 float):要在指定的參數中修剪的參數數量。如果 float,則應介於 0.0 和 1.0 之間,並表示要修剪的參數比例。如果 int,則表示要修剪的參數的絕對數量。

引發

TypeError – 如果 PRUNING_TYPE != 'unstructured'

注意

由於除非範數按參數的大小進行正規化,否則全域結構化修剪沒有太大意義,因此我們現在將全域修剪的範圍限制為非結構化方法。

範例

>>> from torch.nn.utils import prune
>>> from collections import OrderedDict
>>> net = nn.Sequential(OrderedDict([
...     ('first', nn.Linear(10, 4)),
...     ('second', nn.Linear(4, 1)),
... ]))
>>> parameters_to_prune = (
...     (net.first, 'weight'),
...     (net.second, 'weight'),
... )
>>> prune.global_unstructured(
...     parameters_to_prune,
...     pruning_method=prune.L1Unstructured,
...     amount=10,
... )
>>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
tensor(10)

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources