注意
點擊這裡下載完整的範例程式碼
剪枝教學¶
建立於:2019 年 7 月 22 日 | 最後更新:2023 年 11 月 02 日 | 最後驗證:2024 年 11 月 05 日
作者: Michela Paganini
最先進的深度學習技術依賴於過度參數化的模型,這些模型難以部署。 相反,眾所周知,生物神經網路使用高效的稀疏連接。 識別壓縮模型(透過減少其中的參數數量)的最佳技術非常重要,以便減少記憶體、電池和硬體消耗,同時又不犧牲準確性。 這反過來使您能夠在設備上部署輕量級模型,並透過私有設備上運算來保證隱私。 在研究方面,剪枝用於研究過度參數化和欠參數化網路之間學習動態的差異,研究幸運稀疏子網路和初始化(“彩票”)作為一種破壞性神經架構搜尋技術的角色,等等。
在本教學中,您將學習如何使用 torch.nn.utils.prune
來稀疏化您的神經網路,以及如何擴展它以實作您自己的自訂剪枝技術。
需求¶
"torch>=1.4.0a0+8e8a5e0"
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
建立模型¶
在本教學中,我們使用來自 LeCun et al., 1998 的 LeNet 架構。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
檢查模組¶
讓我們檢查 LeNet 模型中的(未剪枝的)conv1
層。 它將包含兩個參數 weight
和 bias
,目前沒有緩衝區。
module = model.conv1
print(list(module.named_parameters()))
[('weight', Parameter containing:
tensor([[[[ 0.1529, 0.1660, -0.0469, 0.1837, -0.0438],
[ 0.0404, -0.0974, 0.1175, 0.1763, -0.1467],
[ 0.1738, 0.0374, 0.1478, 0.0271, 0.0964],
[-0.0282, 0.1542, 0.0296, -0.0934, 0.0510],
[-0.0921, -0.0235, -0.0812, 0.1327, -0.1579]]],
[[[-0.0922, -0.0565, -0.1203, 0.0189, -0.1975],
[ 0.1806, -0.1699, 0.1544, 0.0333, -0.0649],
[ 0.1236, 0.0312, 0.1616, 0.0219, -0.0631],
[ 0.0537, -0.0542, 0.0842, 0.1786, 0.1156],
[-0.0874, 0.1155, 0.0358, 0.1016, -0.1219]]],
[[[-0.1980, -0.0773, -0.1534, 0.1641, 0.0576],
[ 0.0828, 0.0633, -0.0035, 0.1565, -0.1421],
[ 0.0126, -0.1365, 0.0617, -0.0689, 0.0613],
[-0.0417, 0.1659, -0.1185, -0.1193, -0.1193],
[ 0.1799, 0.0667, 0.1925, -0.1651, -0.1984]]],
[[[-0.1565, -0.1345, 0.0810, 0.0716, 0.1662],
[-0.1033, -0.1363, 0.1061, -0.0808, 0.1214],
[-0.0475, 0.1144, -0.1554, -0.1009, 0.0610],
[ 0.0423, -0.0510, 0.1192, 0.1360, -0.1450],
[-0.1068, 0.1831, -0.0675, -0.0709, -0.1935]]],
[[[-0.1145, 0.0500, -0.0264, -0.1452, 0.0047],
[-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
[ 0.1999, 0.0378, 0.0616, -0.1865, -0.1314],
[-0.0666, 0.0313, -0.1760, -0.0862, -0.1197],
[ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],
[[[-0.1167, -0.0685, -0.1579, 0.1677, -0.0397],
[ 0.1721, 0.0623, -0.1694, 0.1384, -0.0550],
[-0.0767, -0.1660, -0.1988, 0.0572, -0.0437],
[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],
[ 0.0418, 0.1033, 0.1615, 0.1822, -0.1586]]]], device='cuda:0',
requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497, 0.1822, -0.1468], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[]
剪枝模組¶
要剪枝模組(在本範例中,是 LeNet 架構的 conv1
層),首先從 torch.nn.utils.prune
中可用的剪枝技術中選擇一種(或透過繼承 BasePruningMethod
來 實作您自己的)。 然後,指定模組以及要剪枝的參數在該模組中的名稱。 最後,使用所選剪枝技術所需的適當關鍵字引數,指定剪枝參數。
在本範例中,我們將隨機剪枝 conv1
層中名為 weight
的參數中 30% 的連接。 該模組作為函數的第一個引數傳遞; name
使用其字串識別碼識別該模組中的參數; amount
指示要剪枝的連接的百分比(如果是介於 0. 和 1. 之間的浮點數),或要剪枝的連接的絕對數量(如果是非負整數)。
prune.random_unstructured(module, name="weight", amount=0.3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
剪枝透過從參數中移除 weight
並將其替換為名為 weight_orig
的新參數(即,將 "_orig"
附加到初始參數 name
)來執行。 weight_orig
儲存張量的未剪枝版本。 bias
未被剪枝,因此它將保持完整。
print(list(module.named_parameters()))
[('bias', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497, 0.1822, -0.1468], device='cuda:0',
requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1529, 0.1660, -0.0469, 0.1837, -0.0438],
[ 0.0404, -0.0974, 0.1175, 0.1763, -0.1467],
[ 0.1738, 0.0374, 0.1478, 0.0271, 0.0964],
[-0.0282, 0.1542, 0.0296, -0.0934, 0.0510],
[-0.0921, -0.0235, -0.0812, 0.1327, -0.1579]]],
[[[-0.0922, -0.0565, -0.1203, 0.0189, -0.1975],
[ 0.1806, -0.1699, 0.1544, 0.0333, -0.0649],
[ 0.1236, 0.0312, 0.1616, 0.0219, -0.0631],
[ 0.0537, -0.0542, 0.0842, 0.1786, 0.1156],
[-0.0874, 0.1155, 0.0358, 0.1016, -0.1219]]],
[[[-0.1980, -0.0773, -0.1534, 0.1641, 0.0576],
[ 0.0828, 0.0633, -0.0035, 0.1565, -0.1421],
[ 0.0126, -0.1365, 0.0617, -0.0689, 0.0613],
[-0.0417, 0.1659, -0.1185, -0.1193, -0.1193],
[ 0.1799, 0.0667, 0.1925, -0.1651, -0.1984]]],
[[[-0.1565, -0.1345, 0.0810, 0.0716, 0.1662],
[-0.1033, -0.1363, 0.1061, -0.0808, 0.1214],
[-0.0475, 0.1144, -0.1554, -0.1009, 0.0610],
[ 0.0423, -0.0510, 0.1192, 0.1360, -0.1450],
[-0.1068, 0.1831, -0.0675, -0.0709, -0.1935]]],
[[[-0.1145, 0.0500, -0.0264, -0.1452, 0.0047],
[-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
[ 0.1999, 0.0378, 0.0616, -0.1865, -0.1314],
[-0.0666, 0.0313, -0.1760, -0.0862, -0.1197],
[ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],
[[[-0.1167, -0.0685, -0.1579, 0.1677, -0.0397],
[ 0.1721, 0.0623, -0.1694, 0.1384, -0.0550],
[-0.0767, -0.1660, -0.1988, 0.0572, -0.0437],
[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],
[ 0.0418, 0.1033, 0.1615, 0.1822, -0.1586]]]], device='cuda:0',
requires_grad=True))]
由上述選擇的剪枝技術產生的剪枝遮罩儲存為名為 weight_mask
的模組緩衝區(即,將 "_mask"
附加到初始參數 name
)。
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
[1., 0., 1., 1., 1.],
[1., 0., 0., 1., 1.],
[1., 0., 1., 1., 1.],
[1., 0., 0., 1., 1.]]],
[[[1., 1., 1., 0., 1.],
[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 0.],
[1., 1., 0., 1., 0.],
[0., 1., 0., 1., 1.]]],
[[[1., 0., 0., 0., 1.],
[1., 0., 1., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 0., 1., 1., 0.]]],
[[[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 0.],
[1., 1., 1., 0., 1.],
[0., 0., 1., 1., 1.],
[1., 1., 0., 1., 1.]]],
[[[1., 0., 1., 1., 1.],
[1., 1., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[0., 1., 1., 0., 1.],
[1., 0., 0., 0., 1.]]],
[[[1., 0., 1., 0., 1.],
[0., 1., 1., 1., 1.],
[1., 1., 0., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 0., 0., 1., 1.]]]], device='cuda:0'))]
為了使前向傳遞在不修改的情況下運作,weight
屬性需要存在。torch.nn.utils.prune
中實作的剪枝技術會計算權重的剪枝版本(通過將遮罩與原始參數結合),並將其儲存在 weight
屬性中。請注意,這不再是 module
的參數,它現在只是一個屬性。
print(module.weight)
tensor([[[[ 0.1529, 0.1660, -0.0469, 0.1837, -0.0438],
[ 0.0404, -0.0000, 0.1175, 0.1763, -0.1467],
[ 0.1738, 0.0000, 0.0000, 0.0271, 0.0964],
[-0.0282, 0.0000, 0.0296, -0.0934, 0.0510],
[-0.0921, -0.0000, -0.0000, 0.1327, -0.1579]]],
[[[-0.0922, -0.0565, -0.1203, 0.0000, -0.1975],
[ 0.1806, -0.1699, 0.1544, 0.0333, -0.0649],
[ 0.0000, 0.0312, 0.1616, 0.0219, -0.0000],
[ 0.0537, -0.0542, 0.0000, 0.1786, 0.0000],
[-0.0000, 0.1155, 0.0000, 0.1016, -0.1219]]],
[[[-0.1980, -0.0000, -0.0000, 0.0000, 0.0576],
[ 0.0828, 0.0000, -0.0035, 0.1565, -0.0000],
[ 0.0126, -0.1365, 0.0617, -0.0689, 0.0613],
[-0.0417, 0.1659, -0.1185, -0.1193, -0.1193],
[ 0.1799, 0.0000, 0.1925, -0.1651, -0.0000]]],
[[[-0.1565, -0.1345, 0.0810, 0.0716, 0.1662],
[-0.0000, -0.1363, 0.1061, -0.0808, 0.0000],
[-0.0475, 0.1144, -0.1554, -0.0000, 0.0610],
[ 0.0000, -0.0000, 0.1192, 0.1360, -0.1450],
[-0.1068, 0.1831, -0.0000, -0.0709, -0.1935]]],
[[[-0.1145, 0.0000, -0.0264, -0.1452, 0.0047],
[-0.1366, -0.1697, -0.0000, -0.0000, -0.0000],
[ 0.1999, 0.0378, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0313, -0.1760, -0.0000, -0.1197],
[ 0.0006, -0.0000, -0.0000, -0.0000, -0.1373]]],
[[[-0.1167, -0.0000, -0.1579, 0.0000, -0.0397],
[ 0.0000, 0.0623, -0.1694, 0.1384, -0.0550],
[-0.0767, -0.1660, -0.0000, 0.0572, -0.0000],
[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],
[ 0.0418, 0.0000, 0.0000, 0.1822, -0.1586]]]], device='cuda:0',
grad_fn=<MulBackward0>)
最後,在每次前向傳遞之前,使用 PyTorch 的 forward_pre_hooks
應用剪枝。具體來說,當 module
被剪枝時(就像我們在這裡所做的那樣),它會為每個與其關聯的、被剪枝的參數獲取一個 forward_pre_hook
。在這種情況下,由於到目前為止我們只剪枝了名為 weight
的原始參數,因此只會存在一個 hook。
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fd608603490>)])
為了完整性,我們現在也可以剪枝 bias
,以查看 module
的參數、緩衝區、hook 和屬性如何變化。為了嘗試另一種剪枝技術,這裡我們使用 l1_unstructured
剪枝函數中實作的 L1 範數,剪枝 bias 中最小的 3 個條目。
prune.l1_unstructured(module, name="bias", amount=3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
現在我們期望具名參數包含 weight_orig
(來自之前)和 bias_orig
。緩衝區將包含 weight_mask
和 bias_mask
。兩個張量的剪枝版本將作為 module 屬性存在,並且 module 現在將有兩個 forward_pre_hooks
。
print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[ 0.1529, 0.1660, -0.0469, 0.1837, -0.0438],
[ 0.0404, -0.0974, 0.1175, 0.1763, -0.1467],
[ 0.1738, 0.0374, 0.1478, 0.0271, 0.0964],
[-0.0282, 0.1542, 0.0296, -0.0934, 0.0510],
[-0.0921, -0.0235, -0.0812, 0.1327, -0.1579]]],
[[[-0.0922, -0.0565, -0.1203, 0.0189, -0.1975],
[ 0.1806, -0.1699, 0.1544, 0.0333, -0.0649],
[ 0.1236, 0.0312, 0.1616, 0.0219, -0.0631],
[ 0.0537, -0.0542, 0.0842, 0.1786, 0.1156],
[-0.0874, 0.1155, 0.0358, 0.1016, -0.1219]]],
[[[-0.1980, -0.0773, -0.1534, 0.1641, 0.0576],
[ 0.0828, 0.0633, -0.0035, 0.1565, -0.1421],
[ 0.0126, -0.1365, 0.0617, -0.0689, 0.0613],
[-0.0417, 0.1659, -0.1185, -0.1193, -0.1193],
[ 0.1799, 0.0667, 0.1925, -0.1651, -0.1984]]],
[[[-0.1565, -0.1345, 0.0810, 0.0716, 0.1662],
[-0.1033, -0.1363, 0.1061, -0.0808, 0.1214],
[-0.0475, 0.1144, -0.1554, -0.1009, 0.0610],
[ 0.0423, -0.0510, 0.1192, 0.1360, -0.1450],
[-0.1068, 0.1831, -0.0675, -0.0709, -0.1935]]],
[[[-0.1145, 0.0500, -0.0264, -0.1452, 0.0047],
[-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
[ 0.1999, 0.0378, 0.0616, -0.1865, -0.1314],
[-0.0666, 0.0313, -0.1760, -0.0862, -0.1197],
[ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],
[[[-0.1167, -0.0685, -0.1579, 0.1677, -0.0397],
[ 0.1721, 0.0623, -0.1694, 0.1384, -0.0550],
[-0.0767, -0.1660, -0.1988, 0.0572, -0.0437],
[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],
[ 0.0418, 0.1033, 0.1615, 0.1822, -0.1586]]]], device='cuda:0',
requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497, 0.1822, -0.1468], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
[1., 0., 1., 1., 1.],
[1., 0., 0., 1., 1.],
[1., 0., 1., 1., 1.],
[1., 0., 0., 1., 1.]]],
[[[1., 1., 1., 0., 1.],
[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 0.],
[1., 1., 0., 1., 0.],
[0., 1., 0., 1., 1.]]],
[[[1., 0., 0., 0., 1.],
[1., 0., 1., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 0., 1., 1., 0.]]],
[[[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 0.],
[1., 1., 1., 0., 1.],
[0., 0., 1., 1., 1.],
[1., 1., 0., 1., 1.]]],
[[[1., 0., 1., 1., 1.],
[1., 1., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[0., 1., 1., 0., 1.],
[1., 0., 0., 0., 1.]]],
[[[1., 0., 1., 0., 1.],
[0., 1., 1., 1., 1.],
[1., 1., 0., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 0., 0., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 0., 1., 1., 1.], device='cuda:0'))]
print(module.bias)
tensor([ 0.0000, -0.0000, -0.0000, -0.1497, 0.1822, -0.1468], device='cuda:0',
grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fd608603490>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7fd608602d70>)])
迭代剪枝¶
可以對 module 中的相同參數進行多次剪枝,各種剪枝調用的效果等於串聯應用的各種遮罩的組合。新遮罩與舊遮罩的組合由 PruningContainer
的 compute_mask
方法處理。
例如,假設我們現在要進一步剪枝 module.weight
,這次使用沿張量的第 0 軸(第 0 軸對應於卷積層的輸出通道,對於 conv1
而言,維度為 6)的結構化剪枝,基於通道的 L2 範數。這可以使用 ln_structured
函數來實現,其中 n=2
且 dim=0
。
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)
tensor([[[[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, 0.0000],
[-0.0000, -0.0000, -0.0000, 0.0000, -0.0000]]],
[[[-0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, 0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, -0.0000]]],
[[[-0.1980, -0.0000, -0.0000, 0.0000, 0.0576],
[ 0.0828, 0.0000, -0.0035, 0.1565, -0.0000],
[ 0.0126, -0.1365, 0.0617, -0.0689, 0.0613],
[-0.0417, 0.1659, -0.1185, -0.1193, -0.1193],
[ 0.1799, 0.0000, 0.1925, -0.1651, -0.0000]]],
[[[-0.1565, -0.1345, 0.0810, 0.0716, 0.1662],
[-0.0000, -0.1363, 0.1061, -0.0808, 0.0000],
[-0.0475, 0.1144, -0.1554, -0.0000, 0.0610],
[ 0.0000, -0.0000, 0.1192, 0.1360, -0.1450],
[-0.1068, 0.1831, -0.0000, -0.0709, -0.1935]]],
[[[-0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, -0.0000, -0.0000, -0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000]]],
[[[-0.1167, -0.0000, -0.1579, 0.0000, -0.0397],
[ 0.0000, 0.0623, -0.1694, 0.1384, -0.0550],
[-0.0767, -0.1660, -0.0000, 0.0572, -0.0000],
[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],
[ 0.0418, 0.0000, 0.0000, 0.1822, -0.1586]]]], device='cuda:0',
grad_fn=<MulBackward0>)
相應的 hook 現在將屬於 torch.nn.utils.prune.PruningContainer
類型,並將儲存應用於 weight
參數的剪枝歷史記錄。
[<torch.nn.utils.prune.RandomUnstructured object at 0x7fd608603490>, <torch.nn.utils.prune.LnStructured object at 0x7fd608602cb0>]
序列化剪枝後的模型¶
所有相關的張量,包括遮罩緩衝區和用於計算剪枝張量的原始參數,都儲存在模型的 state_dict
中,因此可以輕鬆地序列化和儲存,如果需要的話。
print(model.state_dict().keys())
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
移除剪枝的重新參數化¶
為了使剪枝永久化,移除以 weight_orig
和 weight_mask
表示的重新參數化,並移除 forward_pre_hook
,我們可以使用 torch.nn.utils.prune
中的 remove
功能。請注意,這不會撤銷剪枝,就好像它從未發生過一樣。它只是透過將參數 weight
重新分配給模型的參數(以其剪枝版本),使其永久化。
在移除重新參數化之前
print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[ 0.1529, 0.1660, -0.0469, 0.1837, -0.0438],
[ 0.0404, -0.0974, 0.1175, 0.1763, -0.1467],
[ 0.1738, 0.0374, 0.1478, 0.0271, 0.0964],
[-0.0282, 0.1542, 0.0296, -0.0934, 0.0510],
[-0.0921, -0.0235, -0.0812, 0.1327, -0.1579]]],
[[[-0.0922, -0.0565, -0.1203, 0.0189, -0.1975],
[ 0.1806, -0.1699, 0.1544, 0.0333, -0.0649],
[ 0.1236, 0.0312, 0.1616, 0.0219, -0.0631],
[ 0.0537, -0.0542, 0.0842, 0.1786, 0.1156],
[-0.0874, 0.1155, 0.0358, 0.1016, -0.1219]]],
[[[-0.1980, -0.0773, -0.1534, 0.1641, 0.0576],
[ 0.0828, 0.0633, -0.0035, 0.1565, -0.1421],
[ 0.0126, -0.1365, 0.0617, -0.0689, 0.0613],
[-0.0417, 0.1659, -0.1185, -0.1193, -0.1193],
[ 0.1799, 0.0667, 0.1925, -0.1651, -0.1984]]],
[[[-0.1565, -0.1345, 0.0810, 0.0716, 0.1662],
[-0.1033, -0.1363, 0.1061, -0.0808, 0.1214],
[-0.0475, 0.1144, -0.1554, -0.1009, 0.0610],
[ 0.0423, -0.0510, 0.1192, 0.1360, -0.1450],
[-0.1068, 0.1831, -0.0675, -0.0709, -0.1935]]],
[[[-0.1145, 0.0500, -0.0264, -0.1452, 0.0047],
[-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
[ 0.1999, 0.0378, 0.0616, -0.1865, -0.1314],
[-0.0666, 0.0313, -0.1760, -0.0862, -0.1197],
[ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],
[[[-0.1167, -0.0685, -0.1579, 0.1677, -0.0397],
[ 0.1721, 0.0623, -0.1694, 0.1384, -0.0550],
[-0.0767, -0.1660, -0.1988, 0.0572, -0.0437],
[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],
[ 0.0418, 0.1033, 0.1615, 0.1822, -0.1586]]]], device='cuda:0',
requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497, 0.1822, -0.1468], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[1., 0., 0., 0., 1.],
[1., 0., 1., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 0., 1., 1., 0.]]],
[[[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 0.],
[1., 1., 1., 0., 1.],
[0., 0., 1., 1., 1.],
[1., 1., 0., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[1., 0., 1., 0., 1.],
[0., 1., 1., 1., 1.],
[1., 1., 0., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 0., 0., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 0., 1., 1., 1.], device='cuda:0'))]
print(module.weight)
tensor([[[[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, 0.0000],
[-0.0000, -0.0000, -0.0000, 0.0000, -0.0000]]],
[[[-0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, 0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, -0.0000]]],
[[[-0.1980, -0.0000, -0.0000, 0.0000, 0.0576],
[ 0.0828, 0.0000, -0.0035, 0.1565, -0.0000],
[ 0.0126, -0.1365, 0.0617, -0.0689, 0.0613],
[-0.0417, 0.1659, -0.1185, -0.1193, -0.1193],
[ 0.1799, 0.0000, 0.1925, -0.1651, -0.0000]]],
[[[-0.1565, -0.1345, 0.0810, 0.0716, 0.1662],
[-0.0000, -0.1363, 0.1061, -0.0808, 0.0000],
[-0.0475, 0.1144, -0.1554, -0.0000, 0.0610],
[ 0.0000, -0.0000, 0.1192, 0.1360, -0.1450],
[-0.1068, 0.1831, -0.0000, -0.0709, -0.1935]]],
[[[-0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, -0.0000, -0.0000, -0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000]]],
[[[-0.1167, -0.0000, -0.1579, 0.0000, -0.0397],
[ 0.0000, 0.0623, -0.1694, 0.1384, -0.0550],
[-0.0767, -0.1660, -0.0000, 0.0572, -0.0000],
[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],
[ 0.0418, 0.0000, 0.0000, 0.1822, -0.1586]]]], device='cuda:0',
grad_fn=<MulBackward0>)
在移除重新參數化之後
prune.remove(module, 'weight')
print(list(module.named_parameters()))
[('bias_orig', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497, 0.1822, -0.1468], device='cuda:0',
requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, 0.0000],
[-0.0000, -0.0000, -0.0000, 0.0000, -0.0000]]],
[[[-0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, 0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, -0.0000]]],
[[[-0.1980, -0.0000, -0.0000, 0.0000, 0.0576],
[ 0.0828, 0.0000, -0.0035, 0.1565, -0.0000],
[ 0.0126, -0.1365, 0.0617, -0.0689, 0.0613],
[-0.0417, 0.1659, -0.1185, -0.1193, -0.1193],
[ 0.1799, 0.0000, 0.1925, -0.1651, -0.0000]]],
[[[-0.1565, -0.1345, 0.0810, 0.0716, 0.1662],
[-0.0000, -0.1363, 0.1061, -0.0808, 0.0000],
[-0.0475, 0.1144, -0.1554, -0.0000, 0.0610],
[ 0.0000, -0.0000, 0.1192, 0.1360, -0.1450],
[-0.1068, 0.1831, -0.0000, -0.0709, -0.1935]]],
[[[-0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, -0.0000, -0.0000, -0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000]]],
[[[-0.1167, -0.0000, -0.1579, 0.0000, -0.0397],
[ 0.0000, 0.0623, -0.1694, 0.1384, -0.0550],
[-0.0767, -0.1660, -0.0000, 0.0572, -0.0000],
[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],
[ 0.0418, 0.0000, 0.0000, 0.1822, -0.1586]]]], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[('bias_mask', tensor([0., 0., 0., 1., 1., 1.], device='cuda:0'))]
剪枝模型中的多個參數¶
透過指定所需的剪枝技術和參數,我們可以輕鬆地剪枝網路中的多個張量,也許可以根據它們的類型,正如我們將在本範例中看到的那樣。
new_model = LeNet()
for name, module in new_model.named_modules():
# prune 20% of connections in all 2D-conv layers
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# prune 40% of connections in all linear layers
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
全域剪枝¶
到目前為止,我們只關注通常所說的「局部」剪枝,也就是說,透過將每個條目的統計數據(權重幅度、激活、梯度等)僅與該張量中的其他條目進行比較,來逐一剪枝模型中的張量。然而,一個常見且可能更強大的技術是同時剪枝整個模型,例如,移除整個模型中最低 20% 的連接,而不是移除每一層中最低 20% 的連接。這可能會導致每層不同的剪枝百分比。讓我們看看如何使用 torch.nn.utils.prune
中的 global_unstructured
來實現這一點。
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
現在我們可以檢查每個剪枝參數中產生的稀疏性,它們在每一層中都不會等於 20%。但是,全域稀疏性將(近似)為 20%。
print(
"Sparsity in conv1.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
)
)
print(
"Sparsity in conv2.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
)
)
print(
"Sparsity in fc1.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
)
)
print(
"Sparsity in fc2.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
)
)
print(
"Sparsity in fc3.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
)
)
print(
"Global sparsity: {:.2f}%".format(
100. * float(
torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0)
)
/ float(
model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement()
)
)
)
Sparsity in conv1.weight: 4.67%
Sparsity in conv2.weight: 13.92%
Sparsity in fc1.weight: 22.16%
Sparsity in fc2.weight: 12.10%
Sparsity in fc3.weight: 11.31%
Global sparsity: 20.00%
使用自訂剪枝函數擴充 torch.nn.utils.prune
¶
若要實作您自己的剪枝函式,您可以擴充 nn.utils.prune
模組,方法是繼承 BasePruningMethod
基礎類別,這也是所有其他剪枝方法所用的方式。基礎類別會為您實作以下方法: __call__
、 apply_mask
、 apply
、 prune
和 remove
。除了某些特殊情況,您不應該需要為您的新剪枝技術重新實作這些方法。然而,您將必須實作 __init__
(建構函式) 和 compute_mask
(根據您的剪枝技術邏輯,關於如何計算給定張量遮罩的指示)。此外,您還必須指定此技術實作的剪枝類型 (支援的選項為 global
、 structured
和 unstructured
)。這是為了確定在迭代應用剪枝的情況下,如何組合遮罩。換句話說,當剪枝一個預先剪枝過的參數時,目前的剪枝技術預期會作用於該參數的未剪枝部分。指定 PRUNING_TYPE
將使 PruningContainer
(處理剪枝遮罩的迭代應用) 能夠正確識別要剪枝的參數切片。
舉例來說,假設您想要實作一種剪枝技術,剪掉張量中的每隔一個條目 (或者,如果張量先前已被剪枝,則剪掉張量中剩餘的未剪枝部分)。這將屬於 PRUNING_TYPE='unstructured'
,因為它作用於層中的個別連接,而不是作用於整個單元/通道 ('structured'
) 或跨不同參數 ('global'
)。
class FooBarPruningMethod(prune.BasePruningMethod):
"""Prune every other entry in a tensor
"""
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = default_mask.clone()
mask.view(-1)[::2] = 0
return mask
現在,若要將此應用於 nn.Module
中的參數,您還應該提供一個簡單的函式來實例化該方法並應用它。
def foobar_unstructured(module, name):
"""Prunes tensor corresponding to parameter called `name` in `module`
by removing every other entry in the tensors.
Modifies module in place (and also return the modified module)
by:
1) adding a named buffer called `name+'_mask'` corresponding to the
binary mask applied to the parameter `name` by the pruning method.
The parameter `name` is replaced by its pruned version, while the
original (unpruned) parameter is stored in a new parameter named
`name+'_orig'`.
Args:
module (nn.Module): module containing the tensor to prune
name (string): parameter name within `module` on which pruning
will act.
Returns:
module (nn.Module): modified (i.e. pruned) version of the input
module
Examples:
>>> m = nn.Linear(3, 4)
>>> foobar_unstructured(m, name='bias')
"""
FooBarPruningMethod.apply(module, name)
return module
讓我們試試看!
model = LeNet()
foobar_unstructured(model.fc3, name='bias')
print(model.fc3.bias_mask)
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
腳本總執行時間: (0 分鐘 0.295 秒)