注意
點擊這裡下載完整的範例程式碼
TorchRec 簡介¶
建立於:2024 年 10 月 02 日 | 最後更新:2024 年 10 月 10 日 | 最後驗證:2024 年 10 月 02 日
TorchRec 是一個 PyTorch 函式庫,專為使用嵌入 (embeddings) 建立可擴展且有效率的推薦系統而設計。本教學將引導您完成安裝過程,介紹嵌入的概念,並強調其在推薦系統中的重要性。它提供了在 PyTorch 和 TorchRec 中實作嵌入的實用示範,重點在於透過分散式訓練和進階最佳化來處理大型嵌入表。
嵌入的基本原理及其在推薦系統中的作用
如何設定 TorchRec 以在 PyTorch 環境中管理和實作嵌入
探索跨多個 GPU 分散大型嵌入表的進階技術
PyTorch v2.5 或更新版本,搭配 CUDA 11.8 或更新版本
Python 3.9 或更新版本
安裝相依性¶
在 Google Colab 或其他環境中執行本教學之前,請安裝以下相依性
!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121
注意
如果您在 Google Colab 中執行此程式碼,請確保切換到 GPU 執行階段類型。如需更多資訊,請參閱啟用 CUDA
嵌入 (Embeddings)¶
在建立推薦系統時,類別特徵通常具有大量的基數,例如貼文、使用者、廣告等等。
為了表示這些實體並對這些關係進行建模,會使用嵌入。在機器學習中,嵌入是高維空間中的實數向量,用於表示複雜資料中的意義,例如文字、圖像或使用者。
RecSys 中的嵌入¶
現在您可能想知道,這些嵌入最初是如何產生的? 嗯,嵌入表示為嵌入表中的個別列,也稱為嵌入權重。這樣做的原因是,嵌入或嵌入表權重就像模型的所有其他權重一樣,透過梯度下降進行訓練!
嵌入表只是一個用於儲存嵌入的大型矩陣,具有兩個維度 (B, N),其中
B 是表儲存的嵌入數量
N 是每個嵌入的維度數量(N 維嵌入)。
嵌入表的輸入表示嵌入查詢,用於檢索特定索引或列的嵌入。在推薦系統中,例如許多大型系統中使用的那些,唯一的 ID 不僅用於特定使用者,還用於貼文和廣告等實體,以作為各個嵌入表的查詢索引!
嵌入透過以下流程在 RecSys 中進行訓練
輸入/查詢索引會作為唯一 ID 輸入到模型中。ID 會雜湊到嵌入表的總大小,以防止 ID > 列數時出現問題
然後檢索嵌入並進行池化,例如取嵌入的總和或平均值。這是必需的,因為每個範例的嵌入數量可能不同,而模型期望一致的形狀。
嵌入會與模型的其餘部分結合使用以產生預測,例如廣告的點擊率 (CTR)。
損失是根據預測和範例的標籤計算的,並且模型的所有權重都會透過梯度下降和反向傳播進行更新,包括與範例相關聯的嵌入權重。
這些嵌入對於表示類別特徵(例如使用者、貼文和廣告)至關重要,以便捕獲關係並做出良好的推薦。深度學習推薦模型 (DLRM) 論文更詳細地討論了在 RecSys 中使用嵌入表的技術細節。
本教學介紹了嵌入的概念,展示了 TorchRec 特定的模組和資料類型,並描述了分散式訓練如何與 TorchRec 一起運作。
import torch
PyTorch 中的嵌入¶
在 PyTorch 中,我們有以下類型的嵌入
torch.nn.Embedding
:一個嵌入表,其中正向傳遞會按原樣傳回嵌入本身。torch.nn.EmbeddingBag
:一個嵌入表,其中正向傳遞會傳回然後池化的嵌入,例如總和或平均值,也稱為池化嵌入。
在本節中,我們將簡要介紹透過將索引傳遞到表中來執行嵌入查詢。
num_embeddings, embedding_dim = 10, 4
# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)
# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
num_embeddings, embedding_dim, _weight=weights
)
# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)
# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)
embeddings = embedding_collection(ids)
# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)
# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)
print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)
# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))
Weights: tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]])
Embedding Collection Table: Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Embedding Bag Collection Table: Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Input row IDS: tensor([[1, 3]])
Embedding Collection Results:
tensor([[[0.3904, 0.6009, 0.2566, 0.7936],
[0.8694, 0.5677, 0.7411, 0.4294]]], grad_fn=<EmbeddingBackward0>)
Shape: torch.Size([1, 2, 4])
Embedding Bag Collection Results:
tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<EmbeddingBagBackward0>)
Shape: torch.Size([1, 4])
Mean: tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<MeanBackward1>)
恭喜! 現在您對如何使用嵌入表有了基本的了解 — 這是現代推薦系統的基礎之一! 這些表表示實體及其關係。 例如,給定使用者與他們喜歡的頁面和貼文之間的關係。
TorchRec 功能總覽¶
在以上章節中,我們學習了如何使用嵌入表,這是現代推薦系統的基礎之一!這些表代表實體和關係,例如使用者、頁面、貼文等等。 由於這些實體不斷增加,因此通常會應用雜湊 (hash) 函數,以確保 ID 在特定嵌入表的範圍內。 然而,為了表示大量的實體並減少雜湊碰撞,這些表可能會變得非常龐大 (想想廣告的數量)。 實際上,這些表可能會變得如此龐大,以至於即使有 80G 記憶體也無法容納在單一 GPU 上。
為了訓練具有龐大嵌入表的模型,需要在 GPU 之間對這些表進行分片 (sharding),這會引入一整套新的並行性和最佳化問題和機會。 幸運的是,我們擁有 TorchRec 函式庫,它已經遇到、整合和解決了許多這些問題。 TorchRec 是一個函式庫,提供大規模分散式嵌入的基本元件。
接下來,我們將探索 TorchRec 函式庫的主要功能。 我們將從 torch.nn.Embedding
開始,並將其擴展到自訂 TorchRec 模組,探索使用產生嵌入分片計畫的分散式訓練環境,查看 TorchRec 固有的最佳化,並擴展模型以準備在 C++ 中進行推論。 以下是本節的快速概述
TorchRec 模組和資料類型
分散式訓練、分片和最佳化
推論
讓我們從導入 TorchRec 開始
import torchrec
本節介紹 TorchRec 模組和資料類型,包括 EmbeddingCollection
和 EmbeddingBagCollection
、JaggedTensor
、KeyedJaggedTensor
、KeyedTensor
等實體。
從 EmbeddingBag
到 EmbeddingBagCollection
¶
我們已經探索了 torch.nn.Embedding
和 torch.nn.EmbeddingBag
。 TorchRec 通過建立嵌入集合來擴展這些模組,換句話說,模組可以具有多個嵌入表,使用 EmbeddingCollection
和 EmbeddingBagCollection
。 我們將使用 EmbeddingBagCollection
來表示一組嵌入包 (embedding bags)。
在下面的範例程式碼中,我們建立了一個具有兩個嵌入包的 EmbeddingBagCollection
(EBC),一個代表產品,另一個代表使用者。 每個表,product_table
和 user_table
,都由大小為 4096 的 64 維嵌入表示。
ebc = torchrec.EmbeddingBagCollection(
device="cpu",
tables=[
torchrec.EmbeddingBagConfig(
name="product_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["product"],
pooling=torchrec.PoolingType.SUM,
),
torchrec.EmbeddingBagConfig(
name="user_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["user"],
pooling=torchrec.PoolingType.SUM,
)
]
)
print(ebc.embedding_bags)
ModuleDict(
(product_table): EmbeddingBag(4096, 64, mode='sum')
(user_table): EmbeddingBag(4096, 64, mode='sum')
)
讓我們檢查 EmbeddingBagCollection
的前向 (forward) 方法,以及模組的輸入和輸出
import inspect
# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))
def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
Args:
features (KeyedJaggedTensor): Input KJT
Returns:
KeyedTensor
"""
flat_feature_names: List[str] = []
for names in self._feature_names:
flat_feature_names.extend(names)
inverse_indices = reorder_inverse_indices(
inverse_indices=features.inverse_indices_or_none(),
feature_names=flat_feature_names,
)
pooled_embeddings: List[torch.Tensor] = []
feature_dict = features.to_dict()
for i, embedding_bag in enumerate(self.embedding_bags.values()):
for feature_name in self._feature_names[i]:
f = feature_dict[feature_name]
res = embedding_bag(
input=f.values(),
offsets=f.offsets(),
per_sample_weights=f.weights() if self._is_weighted else None,
).float()
pooled_embeddings.append(res)
return KeyedTensor(
keys=self._embedding_names,
values=process_pooled_embeddings(
pooled_embeddings=pooled_embeddings,
inverse_indices=inverse_indices,
),
length_per_key=self._lengths_per_embedding,
)
TorchRec 輸入/輸出資料類型¶
TorchRec 具有用於其模組輸入和輸出的不同資料類型:JaggedTensor
、KeyedJaggedTensor
和 KeyedTensor
。 現在您可能會問,為什麼要建立新的資料類型來表示稀疏特徵? 為了回答這個問題,我們必須了解如何在程式碼中表示稀疏特徵。
稀疏特徵也稱為 id_list_feature
和 id_score_list_feature
,它們是將用作嵌入表索引的 ID,以檢索該 ID 的嵌入。 舉一個非常簡單的例子,想像一下一個單一的稀疏特徵是使用者互動過的廣告。 輸入本身將是使用者互動過的一組廣告 ID,而檢索到的嵌入將是這些廣告的語義表示。 在程式碼中表示這些特徵的棘手之處在於,在每個輸入範例中,ID 的數量是可變的。 有一天,使用者可能只與一個廣告互動,而第二天他們可能與三個廣告互動。
下面顯示了一個簡單的表示,我們有一個 lengths
張量,表示批次中一個範例中有多少個索引,以及一個包含索引本身的 values
張量。
# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])
# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])
接下來,讓我們看看偏移量 (offsets),以及每個批次中包含的内容
# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)
print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
"Second Batch: ",
id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)
from torchrec import JaggedTensor
# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)
# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())
# Convert to list of values
print("List of Values: ", jt.to_dense())
# ``__str__`` representation
print(jt)
from torchrec import KeyedJaggedTensor
# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!
product_jt = JaggedTensor(
values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))
# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})
# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())
# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())
# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())
# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())
# ``KeyedJaggedTensor`` string representation
print(kjt)
# Q2: What are the offsets for the ``KeyedJaggedTensor``?
# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result
# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())
# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)
# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
Offsets: tensor([1, 3])
First Batch: tensor([5])
Second Batch: tensor([7, 1])
Offsets: tensor([0, 1, 3])
List of Values: [tensor([5]), tensor([7, 1])]
JaggedTensor({
[[5], [7, 1]]
})
Keys: ['product', 'user']
Lengths: tensor([3, 1, 2, 2])
Values: tensor([1, 2, 1, 5, 2, 3, 4, 1])
to_dict: {'product': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f377446ee00>, 'user': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f36958db5e0>}
KeyedJaggedTensor({
"product": [[1, 2, 1], [5]],
"user": [[2, 3], [4, 1]]
})
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])
恭喜! 您現在了解 TorchRec 模組和資料類型。 恭喜您成功完成了這一步。 接下來,我們將學習分散式訓練和分片。
分散式訓練和分片¶
既然我們已經掌握了 TorchRec 模組和資料類型,現在是時候將其提升到一個新的層次了。
請記住,TorchRec 的主要目的是為分散式嵌入提供基本元件。 到目前為止,我們只在單個設備上使用嵌入表。 由於嵌入表一直很小,這才成為可能,但在生產環境中,通常情況並非如此。 嵌入表通常變得非常龐大,以至於一個表無法容納在單個 GPU 上,這就產生了對多個設備和分散式環境的需求。
在本節中,我們將探索如何設定分散式環境,以及實際的生產訓練是如何完成的,以及探索如何使用 TorchRec 對嵌入表進行分片。
本節也只會使用 1 個 GPU,儘管它將以分散式方式處理。 這僅僅是訓練的限制,因為訓練每個 GPU 都有一個進程 (process)。 推論沒有遇到這種要求
在下面的範例程式碼中,我們設定了 PyTorch 分散式環境。
警告
如果您在 Google Colab 中執行此程式碼,您只能呼叫此儲存格 (cell) 一次,再次呼叫將導致錯誤,因為您只能初始化進程組一次。
import os
import torch.distributed as dist
# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"
# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")
print(f"Distributed environment initialized: {dist}")
Distributed environment initialized: <module 'torch.distributed' from '/usr/local/lib/python3.10/dist-packages/torch/distributed/__init__.py'>
分散式嵌入¶
我們已經使用了主要的 TorchRec 模組:EmbeddingBagCollection
。 我們已經檢查了它的工作原理以及資料如何在 TorchRec 中表示。 但是,我們尚未探索 TorchRec 的主要部分之一,即分散式嵌入。
到目前為止,GPU 是 ML 工作負載最受歡迎的選擇,因為它們能夠比 CPU 執行更多數量的浮點運算/秒 (FLOPs)。 然而,GPU 存在快速記憶體 (HBM,類似於 CPU 的 RAM) 稀缺的限制,通常為約 10 多 GB。
RecSys 模型可能包含遠遠超出 1 個 GPU 記憶體限制的嵌入表,因此需要跨多個 GPU 分散嵌入表,也稱為模型並行。 另一方面,資料並行是指在每個 GPU 上複製整個模型,每個 GPU 採用不同的資料批次進行訓練,並在反向傳播 (backwards pass) 上同步梯度。
模型中需要較少計算資源但需要較多記憶體(嵌入層)的部分採用模型並行(model parallel)分佈,而需要較多計算資源但需要較少記憶體(密集層、MLP 等)的部分採用資料並行(data parallel)分佈。
分片 (Sharding)¶
為了分佈嵌入表(embedding table),我們將嵌入表分割成多個部分,並將這些部分放置在不同的裝置上,這也稱為「分片」。
有很多種分片嵌入表的方法。最常見的方法是:
表級分片(Table-Wise):整個表都放置在一個裝置上
列級分片(Column-Wise):嵌入表的列被分片
行級分片(Row-Wise):嵌入表的行被分片
分片模組 (Sharded Modules)¶
雖然這一切看起來需要處理和實作很多東西,但您很幸運。TorchRec 提供了所有用於輕鬆進行分散式訓練和推論的基本元件!事實上,TorchRec 模組有兩個對應的類別,可用於在分散式環境中使用任何 TorchRec 模組
模組分片器 (module sharder):此類別公開了一個
shard
API,用於處理 TorchRec 模組的分片,產生一個分片模組。* 對於EmbeddingBagCollection
,分片器是 EmbeddingBagCollectionSharder分片模組 (sharded module):此類別是 TorchRec 模組的分片變體。它具有與常規 TorchRec 模組相同的輸入/輸出,但經過更多最佳化,並且可以在分散式環境中工作。* 對於
EmbeddingBagCollection
,分片變體是 ShardedEmbeddingBagCollection
每個 TorchRec 模組都有一個未分片和一個分片變體。
未分片版本旨在用於原型設計和實驗。
分片版本旨在用於分散式環境中進行分散式訓練和推論。
TorchRec 模組的分片版本,例如 EmbeddingBagCollection
,將處理模型並行所需的一切,例如 GPU 之間的通訊,以便將嵌入層分佈到正確的 GPU。
復習一下我們的 EmbeddingBagCollection
模組
ebc
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv
# Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()
# ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"
print(f"Process Group: {pg}")
Process Group: <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f370ccd1130>
規劃器 (Planner)¶
在我們展示分片如何工作之前,我們必須了解規劃器,它有助於我們確定最佳的分片配置。
給定多個嵌入表和多個 rank,可能存在許多不同的分片配置。例如,給定 2 個嵌入表和 2 個 GPU,您可以:
在每個 GPU 上放置 1 個表
將兩個表都放置在單個 GPU 上,而在另一個 GPU 上不放置任何表
將某些行和列放置在每個 GPU 上
考慮到所有這些可能性,我們通常希望獲得一種對效能而言最佳的分片配置。
這就是規劃器的作用。規劃器能夠在給定嵌入表的數量和 GPU 的數量的情況下,確定最佳配置。事實證明,手動執行此操作非常困難,工程師必須考慮大量因素才能確保最佳的分片計劃。幸運的是,當使用規劃器時,TorchRec 提供了自動規劃器。
TorchRec 規劃器:
評估硬體的記憶體限制
根據記憶體提取(作為嵌入層查找)來估算計算量
解決特定於資料的因素
考慮其他硬體細節(例如頻寬)以產生最佳分片計劃
為了考慮所有這些變數,TorchRec 規劃器可以接收 各種數量的資料,包括嵌入表、約束、硬體資訊和拓樸結構,以協助為模型產生最佳分片計劃,這些計劃通常在各個堆疊層中提供。
要了解有關分片的更多資訊,請參閱我們的 分片教學。
# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
topology=Topology(
world_size=1,
compute_device="cuda",
)
)
# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)
print(f"Sharding Plan generated: {plan}")
Sharding Plan generated: module:
param | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
product_table | table_wise | fused | [0]
user_table | table_wise | fused | [0]
param | shard offsets | shard sizes | placement
------------- | ------------- | ----------- | -------------
product_table | [0, 0] | [4096, 64] | rank:0/cuda:0
user_table | [0, 0] | [4096, 64] | rank:0/cuda:0
規劃器結果 (Planner Result)¶
如您在上面看到的,在執行規劃器時會產生大量的輸出。我們可以查看許多正在計算的統計資料,以及我們的表最終放置的位置。
執行規劃器的結果是一個靜態計劃,可以重複用於分片!這使得分片對於生產模型可以是靜態的,而不是每次都確定一個新的分片計劃。在下面,我們使用分片計劃來最終產生我們的 ShardedEmbeddingBagCollection
。
# The static plan that was generated
plan
env = ShardingEnv.from_process_group(pg)
# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))
print(f"Sharded EBC Module: {sharded_ebc}")
Sharded EBC Module: ShardedEmbeddingBagCollection(
(lookups):
GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)
(_output_dists):
TwPooledEmbeddingDist()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
使用 LazyAwaitable
進行 GPU 訓練¶
請記住,TorchRec 是一個高度最佳化的分散式嵌入層函式庫。 TorchRec 引入了一個概念,可以為 GPU 上的訓練提供更高的效能,即 LazyAwaitable。您將看到 LazyAwaitable
類型作為各種分片 TorchRec 模組的輸出。 LazyAwaitable
類型所做的只是盡可能延遲計算某些結果,並且它通過充當異步類型來實現這一點。
from typing import List
from torchrec.distributed.types import LazyAwaitable
# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
def __init__(self, size: List[int]) -> None:
super().__init__()
self._size = size
def _wait_impl(self) -> torch.Tensor:
return torch.ones(self._size)
awaitable = ExampleAwaitable([3, 2])
awaitable.wait()
kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)
kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))
print(kt.keys())
print(kt.values().shape)
# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
<torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable object at 0x7f36956370d0>
<class 'torchrec.sparse.jagged_tensor.KeyedTensor'>
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])
分片 TorchRec 模組的結構¶
我們現在已根據我們產生的分片計劃成功地分片了一個 EmbeddingBagCollection
!分片模組具有來自 TorchRec 的常見 API,這些 API 抽象了多個 GPU 之間的分散式通訊/計算。事實上,這些 API 經過高度最佳化,可在訓練和推論中實現高效能。 以下是 TorchRec 提供的三個用於分散式訓練/推論的常見 API
input_dist
:處理將輸入從 GPU 分佈到 GPU。lookups
:使用 FBGEMM TBE 以最佳化的批次方式執行實際的嵌入層查找(稍後會詳細介紹)。output_dist
:處理將輸出從 GPU 分佈到 GPU。
輸入和輸出的分佈通過 NCCL Collectives 完成,尤其是 All-to-Alls,這是所有 GPU 彼此發送和接收資料的地方。 TorchRec 與 PyTorch 分散式介面連接以進行 collectives,並為最終使用者提供清晰的抽象,從而消除了對較低層級細節的關注。
反向傳播會執行所有這些集合運算,但為了分配梯度,會以相反的順序執行。 input_dist
、 lookup
和 output_dist
都取決於分片方案。由於我們以表為單位進行分片,因此這些 API 是由 TwPooledEmbeddingSharding 構建的模組。
sharded_ebc
# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists
# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists
[TwPooledEmbeddingDist(
(_dist): PooledEmbeddingsAllToAll()
)]
優化嵌入查詢¶
在對嵌入表的集合執行查詢時,一個簡單的解決方案是迭代所有 nn.EmbeddingBags
並對每個表進行查詢。這正是標準的、未分片的 EmbeddingBagCollection
所做的事情。然而,雖然這個解決方案很簡單,但它非常慢。
FBGEMM 是一個提供 GPU 運算子(也稱為核心)的函式庫,這些運算子經過高度優化。其中一個運算子稱為 Table Batched Embedding (TBE),它提供兩個主要的優化
表批次處理,允許您使用一個核心呼叫查詢多個嵌入。
優化器融合,允許模組根據規範的 PyTorch 優化器和參數自行更新。
ShardedEmbeddingBagCollection
使用 FBGEMM TBE 作為查詢,而不是傳統的 nn.EmbeddingBags
,以實現優化的嵌入查詢。
sharded_ebc._lookups
[GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)]
DistributedModelParallel
¶
我們現在已經探索了如何分片單個 EmbeddingBagCollection
!我們能夠使用 EmbeddingBagCollectionSharder
並使用未分片的 EmbeddingBagCollection
來產生 ShardedEmbeddingBagCollection
模組。這種工作流程沒有問題,但通常在實作模型並行時,DistributedModelParallel (DMP) 被用作標準介面。當使用 DMP 包裝您的模型(在我們的例子中是 ebc
)時,將會發生以下情況
決定如何分片模型。 DMP 將收集可用的分片器,並提出一個最佳的分片嵌入表的方法(例如,
EmbeddingBagCollection
)實際分片模型。這包括在適當的裝置上為每個嵌入表分配記憶體。
DMP 接受我們剛剛實驗過的一切,例如靜態分片計畫、分片器列表等。但是,它也有一些不錯的預設值,可以無縫地分片 TorchRec 模型。在這個玩具例子中,由於我們有兩個嵌入表和一個 GPU,TorchRec 會將它們都放在單個 GPU 上。
ebc
model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
out = model(kjt)
out.wait()
model
DistributedModelParallel(
(_dmp_wrapped_module): ShardedEmbeddingBagCollection(
(lookups):
GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)
(_input_dists):
TwSparseFeaturesDist(
(_dist): KJTAllToAll()
)
(_output_dists):
TwPooledEmbeddingDist(
(_dist): PooledEmbeddingsAllToAll()
)
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
)
分片最佳實踐¶
目前,我們的配置僅在 1 個 GPU(或 rank)上進行分片,這很簡單:只需將所有表放在 1 個 GPU 記憶體中即可。然而,在實際的生產用例中,嵌入表通常在數百個 GPU 上進行分片,具有不同的分片方法,例如以表為單位、以行為單位和以列為單位。確定正確的分片配置非常重要(以防止記憶體不足的問題),同時保持記憶體和計算的平衡,以實現最佳效能。
加入優化器¶
請記住,TorchRec 模組針對大規模分散式訓練進行了高度優化。一個重要的優化與優化器有關。
TorchRec 模組提供了一個無縫的 API,可以在訓練中融合反向傳播和優化步驟,從而顯著優化效能並減少使用的記憶體,同時可以細粒度地將不同的優化器分配給不同的模型參數。
優化器類別¶
TorchRec 使用 CombinedOptimizer
,其中包含 KeyedOptimizers
的集合。CombinedOptimizer
有效地使處理模型中各種子群的多個優化器變得容易。KeyedOptimizer
擴展了 torch.optim.Optimizer
並且通過參數字典初始化,從而公開參數。 EmbeddingBagCollection
中的每個 TBE
模組都有自己的 KeyedOptimizer
,它組合到一個 CombinedOptimizer
中。
TorchRec 中的融合優化器¶
使用 DistributedModelParallel
,優化器被融合,這意味著優化器更新是在反向傳播中完成的。這是 TorchRec 和 FBGEMM 中的一項優化,其中優化器嵌入梯度不會被具體化,而是直接應用於參數。這帶來了顯著的記憶體節省,因為嵌入梯度通常與參數本身的大小相同。
但是,您可以選擇使優化器 dense
,它不應用此優化,並且讓您檢查嵌入梯度或根據需要對其應用計算。在這種情況下,密集優化器將是您的 具有 optimizer 的規範 PyTorch 模型訓練迴圈。
一旦通過 DistributedModelParallel
創建了優化器,您仍然需要管理未與 TorchRec 嵌入模組相關聯的其他參數的優化器。要查找其他參數,請使用 in_backward_optimizer_filter(model.named_parameters())
。像對待普通的 Torch 優化器一樣,將優化器應用於這些參數,並將此優化器和 model.fused_optimizer
組合到一個 CombinedOptimizer
中,您可以在訓練迴圈中使用它來 zero_grad
並 step
完成訓練。
向 EmbeddingBagCollection
添加優化器¶
我們將通過兩種方式來做到這一點,它們是等效的,但會根據您的偏好為您提供選擇
通過分片器中的
fused_params
傳遞優化器 kwargs。通過
apply_optimizer_in_backward
,它將優化器參數轉換為fused_params
以傳遞給EmbeddingBagCollection
或EmbeddingCollection
中的TBE
。
# Option 1: Passing optimizer kwargs through fused parameters
from torchrec.optim.optimizers import in_backward_optimizer_filter
from fbgemm_gpu.split_embedding_configs import EmbOptimType
# We initialize the sharder with
fused_params = {
"optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
"learning_rate": 0.02,
"eps": 0.002,
}
# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)
# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))
# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")
print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")
from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward
import copy
# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it
# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}
for name, param in ebc_apply_opt.named_parameters():
print(f"{name=}")
apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)
sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[""], env, torch.device("cuda"))
# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))
# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())
# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")
loss.backward()
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")
Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.01
)
Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.02
)
Type of optimizer: <class 'torchrec.optim.keyed.CombinedOptimizer'>
name='embedding_bags.product_table.weight'
name='embedding_bags.user_table.weight'
: EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.5
)
<class 'torchrec.optim.keyed.CombinedOptimizer'>
Non Fused Model Parameters:
dict_keys([])
First Iteration Loss: 255.66006469726562
Second Iteration Loss: 245.43795776367188
推論¶
現在我們能夠訓練分散式嵌入,我們如何才能採用訓練好的模型並對其進行優化以進行推論?推論通常對模型的效能和大小非常敏感。僅在 Python 環境中運行訓練好的模型效率極低。推論環境和訓練環境之間有兩個關鍵區別
量化 (Quantization):推論模型通常會經過量化,模型參數會降低精度,以降低預測的延遲並縮小模型大小。例如,將訓練模型中的 FP32 (4 位元組) 轉換為每個嵌入權重的 INT8 (1 位元組)。由於嵌入表規模龐大,因此這也是必要的,因為我們希望使用盡可能少的設備進行推論,以最大限度地減少延遲。
C++ 環境:推論延遲非常重要,因此為了確保充足的效能,模型通常在 C++ 環境中執行,以及在沒有 Python 執行環境的情況下 (例如在裝置上)。
TorchRec 提供將 TorchRec 模型轉換為可進行推論的原始元件,包含:
用於量化模型的 API,並透過 FBGEMM TBE 自動引入最佳化
用於分散式推論的分片嵌入 (Sharding embeddings)
將模型編譯為 TorchScript (與 C++ 相容)
在本節中,我們將介紹以下完整工作流程:
量化模型
分片量化模型
將分片量化模型編譯為 TorchScript
ebc
class InferenceModule(torch.nn.Module):
def __init__(self, ebc: torchrec.EmbeddingBagCollection):
super().__init__()
self.ebc_ = ebc
def forward(self, kjt: KeyedJaggedTensor):
return self.ebc_(kjt)
module = InferenceModule(ebc)
for name, param in module.named_parameters():
# Here, the parameters should still be FP32, as we are using a standard EBC
# FP32 is default, regularly used for training
print(name, param.shape, param.dtype)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32
ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32
量化¶
如上所示,正常的 EBC 包含具有 FP32 精度 (每個權重 32 位元) 的嵌入表權重。在這裡,我們將使用 TorchRec 推論函式庫將模型的嵌入權重量化為 INT8。
from torch import quantization as quant
from torchrec.modules.embedding_configs import QuantConfig
from torchrec.quant.embedding_modules import (
EmbeddingBagCollection as QuantEmbeddingBagCollection,
)
quant_dtype = torch.int8
qconfig = QuantConfig(
# dtype of the result of the embedding lookup, post activation
# torch.float generally for compatibility with rest of the model
# as rest of the model here usually isn't quantized
activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
# quantized type for embedding weights, aka parameters to actually quantize
weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
)
qconfig_spec = {
# Map of module type to qconfig
torchrec.EmbeddingBagCollection: qconfig,
}
mapping = {
# Map of module type to quantized module type
torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
}
module = InferenceModule(ebc)
# Quantize the module
qebc = quant.quantize_dynamic(
module,
qconfig_spec=qconfig_spec,
mapping=mapping,
inplace=False,
)
print(f"Quantized EBC: {qebc}")
kjt = kjt.to("cpu")
qebc(kjt)
# Once quantized, goes from parameters -> buffers, as no longer trainable
for name, buffer in qebc.named_buffers():
# The shapes of the tables should be the same but the dtype should be int8 now
# post quantization
print(name, buffer.shape, buffer.dtype)
Quantized EBC: InferenceModule(
(ebc_): QuantizedEmbeddingBagCollection(
(_kjt_to_jt_dict): ComputeKJTToJTDict()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8
ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8
分片¶
在這裡,我們執行 TorchRec 量化模型的分片。這是為了確保我們透過 FBGEMM TBE 使用高效能的模組。這裡我們使用一個裝置,以與訓練保持一致 (1 個 TBE)。
from torchrec import distributed as trec_dist
from torchrec.distributed.shard import _shard_modules
sharded_qebc = _shard_modules(
module=qebc,
device=torch.device("cpu"),
env=trec_dist.ShardingEnv.from_local(
1,
0,
),
)
print(f"Sharded Quantized EBC: {sharded_qebc}")
sharded_qebc(kjt)
Sharded Quantized EBC: InferenceModule(
(ebc_): ShardedQuantEmbeddingBagCollection(
(lookups):
InferGroupedPooledEmbeddingsLookup()
(_output_dists): ModuleList()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
(_input_dist_module): ShardedQuantEbcInputDist()
)
)
<torchrec.sparse.jagged_tensor.KeyedTensor object at 0x7f3695878310>
編譯¶
現在我們有了最佳化的 eager TorchRec 推論模型。下一步是確保此模型可以在 C++ 中載入,因為目前它只能在 Python 執行環境中執行。
Meta 建議的編譯方法有兩個:torch.fx tracing (產生模型的中間表示) 並將結果轉換為 TorchScript,其中 TorchScript 與 C++ 相容。
from torchrec.fx import Tracer
tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])
graph = tracer.trace(sharded_qebc)
gm = torch.fx.GraphModule(sharded_qebc, graph)
print("Graph Module Created!")
print(gm.code)
scripted_gm = torch.jit.script(gm)
print("Scripted Graph Module Created!")
print(scripted_gm.code)
Graph Module Created!
torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embeddingbag_flatten_feature_lengths")
torch.fx._symbolic_trace.wrap("torchrec_fx_utils__fx_marker")
torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embedding_kernel__unwrap_kjt")
torch.fx._symbolic_trace.wrap("fbgemm_gpu_split_table_batched_embeddings_ops_inference_inputs_to_device")
torch.fx._symbolic_trace.wrap("torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference")
def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):
flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt); kjt = None
_fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths); _fx_marker = None
split = flatten_feature_lengths.split([2])
getitem = split[0]; split = None
to = getitem.to(device(type='cuda', index=0), non_blocking = True); getitem = None
_fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths); flatten_feature_lengths = _fx_marker_1 = None
_unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to); to = None
getitem_1 = _unwrap_kjt[0]
getitem_2 = _unwrap_kjt[1]
getitem_3 = _unwrap_kjt[2]; _unwrap_kjt = getitem_3 = None
inputs_to_device = fbgemm_gpu_split_table_batched_embeddings_ops_inference_inputs_to_device(getitem_1, getitem_2, None, device(type='cuda', index=0)); getitem_1 = getitem_2 = None
getitem_4 = inputs_to_device[0]
getitem_5 = inputs_to_device[1]
getitem_6 = inputs_to_device[2]; inputs_to_device = None
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_4, getitem_5, 1, _tensor_constant1, getitem_6); _tensor_constant0 = _tensor_constant1 = bounds_check_indices = None
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
_tensor_constant5 = self._tensor_constant5
_tensor_constant6 = self._tensor_constant6
_tensor_constant7 = self._tensor_constant7
_tensor_constant8 = self._tensor_constant8
_tensor_constant9 = self._tensor_constant9
int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_4, offsets = getitem_5, pooling_mode = 0, indice_weights = getitem_6, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1); _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_4 = getitem_5 = getitem_6 = _tensor_constant8 = _tensor_constant9 = None
embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32); int_nbit_split_embedding_codegen_lookup_function = None
to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu')); embeddings_cat_empty_rank_handle_inference = None
keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1); to_1 = None
return keyed_tensor
/usr/local/lib/python3.10/dist-packages/torch/jit/_check.py:178: UserWarning:
The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
Scripted Graph Module Created!
def forward(self,
kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:
_0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths
_1 = __torch__.torchrec.fx.utils._fx_marker
_2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt
_3 = __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_inference.inputs_to_device
_4 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference
flatten_feature_lengths = _0(kjt, )
_fx_marker = _1("KJT_ONE_TO_ALL_FORWARD_BEGIN", flatten_feature_lengths, )
split = (flatten_feature_lengths).split([2], )
getitem = split[0]
to = (getitem).to(torch.device("cuda", 0), True, None, )
_fx_marker_1 = _1("KJT_ONE_TO_ALL_FORWARD_END", flatten_feature_lengths, )
_unwrap_kjt = _2(to, )
getitem_1 = (_unwrap_kjt)[0]
getitem_2 = (_unwrap_kjt)[1]
inputs_to_device = _3(getitem_1, getitem_2, None, torch.device("cuda", 0), )
getitem_4 = (inputs_to_device)[0]
getitem_5 = (inputs_to_device)[1]
getitem_6 = (inputs_to_device)[2]
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_4, getitem_5, 1, _tensor_constant1, getitem_6)
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
_tensor_constant5 = self._tensor_constant5
_tensor_constant6 = self._tensor_constant6
_tensor_constant7 = self._tensor_constant7
_tensor_constant8 = self._tensor_constant8
_tensor_constant9 = self._tensor_constant9
int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_4, getitem_5, 0, getitem_6, 0, _tensor_constant8, _tensor_constant9, 16)
_5 = [int_nbit_split_embedding_codegen_lookup_function]
embeddings_cat_empty_rank_handle_inference = _4(_5, 1, "cuda:0", 6, )
to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device("cpu"))
_6 = ["product", "user"]
_7 = [64, 64]
keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)
_8 = (keyed_tensor).__init__(_6, _7, to_1, 1, None, None, )
return keyed_tensor
結論¶
在本教學課程中,您已從訓練分散式 RecSys 模型到使其準備好進行推論。TorchRec repo 提供了一個完整範例,說明如何將 TorchRec TorchScript 模型載入到 C++ 中以進行推論。
如需更多資訊,請參閱我們的 dlrm 範例,其中包括使用 Deep Learning Recommendation Model for Personalization and Recommendation Systems 中描述的方法在 Criteo 1TB 資料集上進行多節點訓練。
腳本的總運行時間:( 0 分鐘 0.820 秒)