PruningContainer¶
- class torch.nn.utils.prune.PruningContainer(*args)[原始碼][原始碼]¶
用於迭代修剪的修剪方法序列的容器。
追蹤修剪方法的應用順序,並處理組合連續的修剪呼叫。
接受 BasePruningMethod 的實例或它們的可迭代物件作為參數。
- add_pruning_method(method)[原始碼][原始碼]¶
將子剪枝
method
加入到容器中。- 參數
method (subclass of BasePruningMethod) – 要加入到容器中的子剪枝方法。
- classmethod apply(module, name, *args, importance_scores=None, **kwargs)[source]¶
動態地增加剪枝並重新參數化張量。
加入前向 pre-hook,以啟用動態剪枝,並根據原始張量和剪枝遮罩重新參數化張量。
- 參數
module (nn.Module) – 包含要剪枝的張量的模組
name (str) –
module
內將進行剪枝的參數名稱。args – 傳遞給
BasePruningMethod
子類別的引數importance_scores (torch.Tensor) – 重要性分數的張量(與模組參數的形狀相同),用於計算剪枝的遮罩。此張量中的值指示正在剪枝的參數中相應元素的重要性。如果未指定或為 None,將使用該參數代替。
kwargs – 傳遞給
BasePruningMethod
子類別的關鍵字引數
- apply_mask(module)[source]¶
簡單地處理正在剪枝的參數和產生的遮罩之間的乘法。
從模組中取得遮罩和原始張量,並傳回張量的剪枝版本。
- 參數
module (nn.Module) – 包含要剪枝的張量的模組
- 傳回
輸入張量的剪枝版本
- 傳回類型
pruned_tensor (torch.Tensor)
- compute_mask(t, default_mask)[source][source]¶
透過計算新的部分遮罩並傳回其與
default_mask
的組合來套用最新的method
。新的部分遮罩應在未被
default_mask
歸零的條目或通道上計算。從張量t
計算新遮罩的部分取決於PRUNING_TYPE
(由類型處理程式處理)對於 'unstructured',遮罩將從未遮罩條目的平坦化清單中計算;
對於 'structured',遮罩將從張量中未遮罩的通道計算;
對於 'global',遮罩將在所有條目中計算。
- 參數
t (torch.Tensor) – 代表要剪枝的參數的張量(與
default_mask
的維度相同)。default_mask (torch.Tensor) – 來自先前剪枝迭代的遮罩。
- 傳回
結合了
default_mask
和來自當前剪枝method
的新遮罩效果的新遮罩(與default_mask
和t
的維度相同)。- 傳回類型
mask (torch.Tensor)
- prune(t, default_mask=None, importance_scores=None)[source]¶
計算並傳回輸入張量
t
的剪枝版本。根據
compute_mask()
中指定的剪枝規則。- 參數
t (torch.Tensor) – 要剪枝的張量(與
default_mask
的維度相同)。importance_scores (torch.Tensor) – 重要性分數的張量(與
t
的形狀相同),用於計算剪枝t
的遮罩。此張量中的值指示正在剪枝的t
中相應元素的重要性。如果未指定或為 None,將使用張量t
代替。default_mask (torch.Tensor, optional) – 來自先前剪枝迭代的遮罩(如果有的話)。在確定應該對張量的哪一部分進行剪枝時要考慮。如果為 None,則預設為全為 1 的遮罩。
- 傳回
張量
t
的剪枝版本。