快捷方式

設定 TorchRec

在本節中,我們將:

  • 了解使用 TorchRec 的需求

  • 設定整合 TorchRec 的環境

  • 執行基本的 TorchRec 程式碼

系統需求

TorchRec 通常僅在 AWS Linux 上進行測試,應可在類似環境中運作。 以下展示了目前測試的相容性矩陣

Python 版本

3.9, 3.10, 3.11, 3.12

運算平台

CPU、CUDA 11.8、CUDA 12.1、CUDA 12.4

除了這些需求之外,TorchRec 的核心相依性是 PyTorch 和 FBGEMM。 如果您的系統通常與這兩個函式庫相容,那麼它應該足以支援 TorchRec。

版本相容性

TorchRec 和 FBGEMM 具有匹配的版本號碼,這些版本號碼會在發布時一起測試

  • TorchRec 1.0 與 FBGEMM 1.0 相容

  • TorchRec 0.8 與 FBGEMM 0.8 相容

  • TorchRec 0.8 可能與 FBGEMM 0.7 不相容

此外,TorchRec 和 FBGEMM 僅在新的 PyTorch 發布時才會發布。 因此,特定版本的 TorchRec 和 FBGEMM 應對應於特定的 PyTorch 版本

  • TorchRec 1.0 與 PyTorch 2.5 相容

  • TorchRec 0.8 與 PyTorch 2.4 相容

  • TorchRec 0.8 可能與 PyTorch 2.3 不相容

安裝

以下我們展示了 CUDA 12.1 的安裝範例。 對於 CPU、CUDA 11.8 或 CUDA 12.4,請將 cu121 替換為 cpucu118cu124

pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121
pip install torchmetrics==1.0.3
pip install torchrec --index-url https://download.pytorch.org/whl/cu121
pip install torch
pip install fbgemm-gpu
pip install torchrec
pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu121
pip install torchmetrics==1.0.3
pip install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121

您也可以從原始碼建置 TorchRec,以使用 TorchRec 的最新變更進行開發。 若要從原始碼建置,請查看此參考

執行簡單的 TorchRec 範例

現在我們已經正確設定 TorchRec,讓我們執行一些 TorchRec 程式碼! 在下面,我們將使用 TorchRec 資料類型執行一個簡單的前向傳遞:KeyedJaggedTensorEmbeddingBagCollection

import torch

import torchrec
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

ebc = torchrec.EmbeddingBagCollection(
    device="cpu",
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=16,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=torchrec.PoolingType.SUM,
        ),
        torchrec.EmbeddingBagConfig(
            name="user_table",
            embedding_dim=16,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=torchrec.PoolingType.SUM,
        )
    ]
)

product_jt = JaggedTensor(
    values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))

# Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

print("Call EmbeddingBagCollection Forward: ", ebc(kjt))

將上述程式碼儲存到名為 torchrec_example.py 的檔案。 然後,您應該能夠從您的終端機使用以下方式執行它

python torchrec_example.py

您應該會看到輸出 KeyedTensor 具有產生的嵌入。 恭喜! 您已正確安裝並執行了您的第一個 TorchRec 程式!

文件

存取 PyTorch 的完整開發人員文件

查看文件

教學

取得初學者和進階開發人員的深入教學課程

查看教學

資源

尋找開發資源並獲得問題解答

查看資源