捷徑

PruningContainer

class torch.nn.utils.prune.PruningContainer(*args)[原始碼][原始碼]

用於迭代修剪的修剪方法序列的容器。

追蹤修剪方法的應用順序,並處理組合連續的修剪呼叫。

接受 BasePruningMethod 的實例或它們的可迭代物件作為參數。

add_pruning_method(method)[原始碼][原始碼]

將子剪枝 method 加入到容器中。

參數

method (subclass of BasePruningMethod) – 要加入到容器中的子剪枝方法。

classmethod apply(module, name, *args, importance_scores=None, **kwargs)[source]

動態地增加剪枝並重新參數化張量。

加入前向 pre-hook,以啟用動態剪枝,並根據原始張量和剪枝遮罩重新參數化張量。

參數
  • module (nn.Module) – 包含要剪枝的張量的模組

  • name (str) – module 內將進行剪枝的參數名稱。

  • args – 傳遞給 BasePruningMethod 子類別的引數

  • importance_scores (torch.Tensor) – 重要性分數的張量(與模組參數的形狀相同),用於計算剪枝的遮罩。此張量中的值指示正在剪枝的參數中相應元素的重要性。如果未指定或為 None,將使用該參數代替。

  • kwargs – 傳遞給 BasePruningMethod 子類別的關鍵字引數

apply_mask(module)[source]

簡單地處理正在剪枝的參數和產生的遮罩之間的乘法。

從模組中取得遮罩和原始張量,並傳回張量的剪枝版本。

參數

module (nn.Module) – 包含要剪枝的張量的模組

傳回

輸入張量的剪枝版本

傳回類型

pruned_tensor (torch.Tensor)

compute_mask(t, default_mask)[source][source]

透過計算新的部分遮罩並傳回其與 default_mask 的組合來套用最新的 method

新的部分遮罩應在未被 default_mask 歸零的條目或通道上計算。從張量 t 計算新遮罩的部分取決於 PRUNING_TYPE(由類型處理程式處理)

  • 對於 'unstructured',遮罩將從未遮罩條目的平坦化清單中計算;

  • 對於 'structured',遮罩將從張量中未遮罩的通道計算;

  • 對於 'global',遮罩將在所有條目中計算。

參數
  • t (torch.Tensor) – 代表要剪枝的參數的張量(與 default_mask 的維度相同)。

  • default_mask (torch.Tensor) – 來自先前剪枝迭代的遮罩。

傳回

結合了 default_mask 和來自當前剪枝 method 的新遮罩效果的新遮罩(與 default_maskt 的維度相同)。

傳回類型

mask (torch.Tensor)

prune(t, default_mask=None, importance_scores=None)[source]

計算並傳回輸入張量 t 的剪枝版本。

根據 compute_mask() 中指定的剪枝規則。

參數
  • t (torch.Tensor) – 要剪枝的張量(與 default_mask 的維度相同)。

  • importance_scores (torch.Tensor) – 重要性分數的張量(與 t 的形狀相同),用於計算剪枝 t 的遮罩。此張量中的值指示正在剪枝的 t 中相應元素的重要性。如果未指定或為 None,將使用張量 t 代替。

  • default_mask (torch.Tensor, optional) – 來自先前剪枝迭代的遮罩(如果有的話)。在確定應該對張量的哪一部分進行剪枝時要考慮。如果為 None,則預設為全為 1 的遮罩。

傳回

張量 t 的剪枝版本。

remove(module)[原始碼]

從模組中移除剪枝重新參數化。

名為 name 的已剪枝參數會永久保持剪枝狀態,且名為 name+'_orig' 的參數會從參數列表中移除。 類似地,名為 name+'_mask' 的緩衝區也會從緩衝區中移除。

注意

剪枝本身不會被取消或還原!

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並取得您的問題解答

檢視資源