捷徑

模型整合

建立於:2023 年 3 月 15 日 | 上次更新:2024 年 1 月 16 日 | 上次驗證:2024 年 11 月 05 日

本教學說明如何使用 torch.vmap 向量化模型整合。

什麼是模型整合?

模型整合將多個模型的預測結果結合在一起。傳統上,這是透過分別在一些輸入上執行每個模型,然後結合預測來完成的。但是,如果您正在運行具有相同架構的模型,則可以使用 torch.vmap 將它們組合在一起。vmap 是一個函數轉換,它將函數映射到輸入張量的維度上。它的一個用例是透過向量化消除 for 迴圈並加速它們。

讓我們演示如何使用簡單的 MLP 集合來做到這一點。

注意

本教學需要 PyTorch 2.0.0 或更高版本。

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple MLP
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

讓我們產生一批虛擬數據,並假裝我們正在使用 MNIST 資料集。因此,虛擬影像為 28x28,並且我們有一個大小為 64 的小批量。此外,假設我們要結合來自 10 個不同模型的預測。

device = 'cuda'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

我們有幾種產生預測的選擇。也許我們想給每個模型一個不同的隨機小批量資料。或者,也許我們想透過每個模型運行相同的小批量資料(例如,如果我們正在測試不同模型初始化的效果)。

選項 1:每個模型使用不同的小批量

minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

選項 2:使用相同的小批量

minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

使用 vmap 向量化整合

讓我們使用 vmap 來加速 for 迴圈。我們必須首先準備好模型以與 vmap 一起使用。

首先,透過堆疊每個參數來將模型的狀態組合在一起。例如,model[i].fc1.weight 的形狀為 [784, 128];我們將堆疊每個 10 個模型的 .fc1.weight 以產生形狀為 [10, 784, 128] 的大權重。

PyTorch 提供 torch.func.stack_module_state 方便函數來執行此操作。

from torch.func import stack_module_state

params, buffers = stack_module_state(models)

接下來,我們需要定義一個函數來透過 vmap 進行映射。該函數應該在給定參數、緩衝區和輸入的情況下,使用這些參數、緩衝區和輸入來運行模型。我們將使用 torch.func.functional_call 來提供幫助

from torch.func import functional_call
import copy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

選項 1:使用每個模型不同的 minibatch 取得預測。

預設情況下,vmap 會將函數映射到傳遞給函數的所有輸入的第一個維度。在使用 stack_module_state 之後,每個 params 和緩衝區在前面都有一個大小為 'num_models' 的額外維度,並且 minibatches 有一個大小為 'num_models' 的維度。

print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the ``vmap`` predictions match the
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
[10, 10, 10, 10, 10, 10]

選項 2:使用相同的資料 minibatch 取得預測。

vmap 有一個 in_dims 參數,用於指定要映射的維度。透過使用 None,我們告訴 vmap 我們希望將相同的小批量應用於所有 10 個模型。

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)

快速提示:關於可以透過 vmap 轉換的函數類型存在限制。最好轉換的函數是純函數:輸出僅由沒有副作用的輸入決定的函數(例如,突變)。 vmap 無法處理任意 Python 資料結構的突變,但它可以處理許多就地 PyTorch 操作。

效能

想知道效能數據嗎?以下是相關的數字。

from torch.utils.benchmark import Timer
without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals())
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')
Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fc7149c5930>
[model(minibatch) for model, minibatch in zip(models, minibatches)]
  2.88 ms
  1 measurement, 100 runs , 1 thread
Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fc7149c5cf0>
vmap(fmodel)(params, buffers, minibatches)
  905.20 us
  1 measurement, 100 runs , 1 thread

使用 vmap 可以大幅加速!

通常來說,使用 vmap 進行向量化應該比在 for 迴圈中執行函式更快,並且能與手動批次處理相媲美。但也有一些例外情況,例如我們尚未為特定操作實作 vmap 規則,或者底層的 Kernel 沒有針對較舊的硬體(GPU)進行最佳化。如果您發現任何這些情況,請在 GitHub 上開啟 Issue 讓我們知道。

腳本總運行時間: ( 0 分鐘 0.939 秒)

由 Sphinx-Gallery 產生

文件

獲取 PyTorch 的完整開發人員文檔

查看文檔

教學課程

獲取適合初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得您的問題解答

查看資源