LnStructured¶
- class torch.nn.utils.prune.LnStructured(amount, n, dim=-1)[來源][來源]¶
根據張量的 L
n
範數,修剪整個(目前未修剪的)通道。- 參數
amount (int 或 float) – 要修剪的通道數量。如果
float
,則應介於 0.0 和 1.0 之間,並表示要修剪的參數比例。如果int
,則表示要修剪的參數的絕對數量。n (int, float, inf, -inf, 'fro', 'nuc') – 請參閱
torch.norm()
中參數p
的有效條目說明文件。dim (int, optional) – 定義要剪枝之通道的維度索引。預設值:-1。
- classmethod apply(module, name, amount, n, dim, importance_scores=None)[原始碼][原始碼]¶
動態新增剪枝和張量的重新參數化。
新增 forward pre-hook,使其能夠動態剪枝,並根據原始張量和剪枝遮罩重新參數化張量。
- 參數
module (nn.Module) – 包含要剪枝之張量的模組
name (str) –
module
中將執行剪枝的參數名稱。amount (int 或 float) – 要剪枝的參數數量。如果
float
,則應介於 0.0 和 1.0 之間,並表示要剪枝的參數比例。如果int
,則表示要剪枝的參數絕對數量。n (int, float, inf, -inf, 'fro', 'nuc') – 請參閱
torch.norm()
中參數p
的有效條目說明文件。dim (int) – 定義要剪枝之通道的維度索引。
importance_scores (torch.Tensor) – 重要性分數的張量(與模組參數形狀相同),用於計算剪枝的遮罩。此張量中的值表示要剪枝之參數中對應元素的重要性。如果未指定或為 None,則將使用模組參數代替。
- apply_mask(module)[原始碼]¶
僅處理正在剪枝的參數與產生的遮罩之間的乘法運算。
從模組中提取遮罩和原始張量,並傳回張量的剪枝版本。
- 參數
module (nn.Module) – 包含要剪枝之張量的模組
- 傳回
輸入張量的剪枝版本
- 傳回類型
pruned_tensor (torch.Tensor)
- compute_mask(t, default_mask)[原始碼][原始碼]¶
計算並傳回輸入張量
t
的遮罩。從基本
default_mask
開始(如果張量尚未被剪枝,則應為一個由 1 組成的遮罩),生成一個遮罩以應用在default_mask
之上,方法是將指定維度上具有最低 Ln
-norm 的通道歸零。- 參數
t (torch.Tensor) – 表示要剪枝之參數的張量
default_mask (torch.Tensor) – 來自先前剪枝疊代的基礎遮罩,在新遮罩應用後需要遵守。與
t
具有相同的維度。
- 傳回
要應用於
t
的遮罩,與t
具有相同的維度- 傳回類型
mask (torch.Tensor)
- 引發
IndexError – 如果
self.dim >= len(t.shape)
- prune(t, default_mask=None, importance_scores=None)[原始碼]¶
計算並傳回輸入張量
t
的剪枝版本。根據
compute_mask()
中指定的剪枝規則。- 參數
t (torch.Tensor) – 要剪枝的張量(與
default_mask
具有相同的維度)。importance_scores (torch.Tensor) – 重要性分數的張量(與
t
的形狀相同),用於計算修剪t
的遮罩。此張量中的值表示t
中對應元素的重要性,t
正在被修剪。 如果未指定或為 None,則將使用張量t
代替。default_mask (torch.Tensor, optional) – 先前修剪迭代的遮罩(如果有的話)。 在確定修剪應作用於張量的哪個部分時,應考慮此遮罩。 如果為 None,則預設為全 1 的遮罩。
- 傳回
張量
t
的修剪版本。