捷徑

模組

標準 TorchRec 模組代表嵌入表的集合

  • EmbeddingBagCollectiontorch.nn.EmbeddingBag 的集合

  • EmbeddingCollectiontorch.nn.Embedding 的集合

這些模組透過標準化的設定類別建構

  • EmbeddingBagConfig 用於 EmbeddingBagCollection

  • EmbeddingConfig 用於 EmbeddingCollection

class torchrec.modules.embedding_configs.EmbeddingBagConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False, pooling: ~torchrec.modules.embedding_configs.PoolingType = PoolingType.SUM)

基於:BaseEmbeddingConfig

EmbeddingBagConfig 是一個資料類別,代表單一嵌入表,其輸出預期會被合併。

參數:

pooling (PoolingType) – 合併類型。

class torchrec.modules.embedding_configs.EmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False)

基於:BaseEmbeddingConfig

EmbeddingConfig 是一個資料類別,代表單一嵌入表。

class torchrec.modules.embedding_configs.BaseEmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False)

嵌入配置的基本類別。

參數:
  • num_embeddings (int) – 嵌入的數量。

  • embedding_dim (int) – 嵌入維度。

  • name (str) – 嵌入表格的名稱。

  • data_type (DataType) – 嵌入表格的資料類型。

  • feature_names (List[str]) – 特徵名稱的列表。

  • weight_init_max (Optional[float]) – 權重初始化的最大值。

  • weight_init_min (Optional[float]) – 權重初始化的最小值。

  • num_embeddings_post_pruning (Optional[int]) – 推論時剪枝後的嵌入數量。如果為 None,則不應用剪枝。

  • init_fn (Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]]) – 嵌入權重的初始化函數。

  • need_pos (bool) – 表格是否為位置加權。

class torchrec.modules.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool = False, device: Optional[device] = None)

EmbeddingBagCollection 代表池化嵌入 ( EmbeddingBags ) 的集合。

注意

EmbeddingBagCollection 是一個未分片的模組,並且未針對效能進行最佳化。對於效能敏感的場景,請考慮使用分片版本 ShardedEmbeddingBagCollection。

它可以被呼叫,傳入的引數代表稀疏資料,其形式為 KeyedJaggedTensor,數值的形狀為 (F, B, L[f][i]),其中

  • F: 特徵 (鍵) 的數量

  • B: 批次大小

  • L[f][i]: 稀疏特徵的長度 (對於每個特徵 f 和批次索引 i 而言,可能不同,也就是不規則的)

並輸出一個 KeyedTensor,數值的形狀為 (B, D),其中

  • B: 批次大小

  • D: 所有嵌入表格的嵌入維度總和,也就是 sum([config.embedding_dim for config in tables])

假設引數是一個 KeyedJaggedTensor J,具有 F 個特徵、批次大小 BL[f][i] 個稀疏長度,使得 J[f][i] 是特徵 f 和批次索引 i 的 bag,則輸出 KeyedTensor KT 定義如下: KT[i] = torch.cat([emb[f](J[f][i]) for f in J.keys()]),其中 emb[f] 是對應於特徵 fEmbeddingBag

請注意, J[f][i] 是一個可變長度的整數值列表 (一個 bag),而 emb[f](J[f][i]) 是通過使用 EmbeddingBag emb[f] 的模式 (預設為平均值) 減少 J[f][i] 中每個數值的嵌入而產生的池化嵌入。

參數:
  • tables (List[EmbeddingBagConfig]) – 嵌入表格的列表。

  • is_weighted (bool) – 輸入的 KeyedJaggedTensor 是否為加權。

  • device (Optional[torch.device]) – 預設的計算裝置。

範例

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        i = 0     i = 1    i = 2  <-- batch indices
# "f1"   [0,1]     None      [2]
# "f2"   [3]       [4]     [5,6,7]
#  ^
# features

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1,                  2,    # feature 'f1'
                            3,      4,    5, 6, 7]),  # feature 'f2'
                    #    i = 1    i = 2    i = 3   <--- batch indices
    offsets=torch.tensor([
            0, 2, 2,       # 'f1' bags are values[0:2], values[2:2], and values[2:3]
            3, 4, 5, 8]),  # 'f2' bags are values[3:4], values[4:5], and values[5:8]
)

pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
tensor([
    #  f1 pooled embeddings              f2 pooled embeddings
    #     from bags (dim. 3)                from bags (dim. 4)
    [-0.8899, -0.1342, -1.9060,  -0.0905, -0.2814, -0.9369, -0.7783],  # i = 0
    [ 0.0000,  0.0000,  0.0000,   0.1598,  0.0695,  1.3265, -0.1011],  # i = 1
    [-0.4256, -1.1846, -2.1648,  -1.0893,  0.3590, -1.9784, -0.7681]],  # i = 2
    grad_fn=<CatBackward0>)
print(pooled_embeddings.keys())
['f1', 'f2']
print(pooled_embeddings.offset_per_key())
tensor([0, 3, 7])  # embeddings have dimensions 3 and 4, so embeddings are at [0, 3) and [3, 7).
property device: device

返回: torch.device:計算裝置。

embedding_bag_configs() List[EmbeddingBagConfig]
返回:

嵌入 bag 配置。

返回類型:

List[EmbeddingBagConfig]

forward(features: KeyedJaggedTensor) KeyedTensor

執行 EmbeddingBagCollection 的前向傳遞。此方法接收 KeyedJaggedTensor 並返回 KeyedTensor,這是每個特徵的 embeddings pooling 結果。

參數:

features (KeyedJaggedTensor) – 輸入 KJT

返回:

KeyedTensor

is_weighted() bool
返回:

EmbeddingBagCollection 是否加權。

返回類型:

bool

reset_parameters() None

重置 EmbeddingBagCollection 的參數。參數值基於每個 EmbeddingBagConfig 的 init_fn(如果存在)進行初始化。

class torchrec.modules.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: Optional[device] = None, need_indices: bool = False)

EmbeddingCollection 表示非 pooling embeddings 的集合。

注意

EmbeddingCollection 是一個未分片的模組,並且沒有針對效能進行優化。對於效能敏感的場景,請考慮使用分片版本 ShardedEmbeddingCollection。

它可以被呼叫,傳入的引數代表稀疏資料,其形式為 KeyedJaggedTensor,數值的形狀為 (F, B, L[f][i]),其中

  • F: 特徵 (鍵) 的數量

  • B: 批次大小

  • L[f][i]: 稀疏特徵的長度 (對於每個特徵 f 和批次索引 i 而言,可能不同,也就是不規則的)

並輸出一個類型為 Dict[Feature, JaggedTensor]result,其中 result[f] 是一個 JaggedTensor,其形狀為 (EB[f], D[f]),其中

  • EB[f]: 特徵 f 的「擴展批次大小」,等於其 bag 值的長度總和,即 sum([len(J[f][i]) for i in range(B)])

  • D[f]: 是特徵 f 的 embedding 維度。

參數:
  • tables (List[EmbeddingConfig]) – embedding 表的列表。

  • device (Optional[torch.device]) – 預設的計算裝置。

  • need_indices (bool) – 是否需要將 indices 傳遞到最終的查找字典。

範例

e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

#     0       1        2  <-- batch
# 0   [0,1] None    [2]
# 1   [3]    [4]    [5,6,7]
# ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1,                  2,    # feature 'f1'
                            3,      4,    5, 6, 7]),  # feature 'f2'
                    #    i = 1    i = 2    i = 3   <--- batch indices
    offsets=torch.tensor([
            0, 2, 2,       # 'f1' bags are values[0:2], values[2:2], and values[2:3]
            3, 4, 5, 8]),  # 'f2' bags are values[3:4], values[4:5], and values[5:8]
)

feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
tensor([
    # embedding for value 3 in f2 bag values[3:4]:
    [-0.2050,  0.5478,  0.6054],

    # embedding for value 4 in f2 bag values[4:5]:
    [ 0.7352,  0.3210, -3.0399],

    # embedding for values 5, 6, 7 in f2 bag values[5:8]:
    [ 0.1279, -0.1756, -0.4130],
    [ 0.7519, -0.4341, -0.0499],
    [ 0.9329, -1.0697, -0.8095],

], grad_fn=<EmbeddingBackward>)
property device: device

返回: torch.device:計算裝置。

embedding_configs() List[EmbeddingConfig]
返回:

embedding 設定。

返回類型:

List[EmbeddingConfig]

embedding_dim() int
返回:

embedding 維度。

返回類型:

int

embedding_names_by_table() List[List[str]]
返回:

按表排列的 embedding 名稱。

返回類型:

List[List[str]]

forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor]

執行 EmbeddingBagCollection 的前向傳遞。此方法接收 KeyedJaggedTensor 並返回 Dict[str, JaggedTensor],這是每個特徵的個別 embeddings 的結果。

參數:

features (KeyedJaggedTensor) – 形式為 [F X B X L] 的 KJT。

返回:

Dict[str, JaggedTensor]

need_indices() bool
返回:

判斷 EmbeddingCollection 是否需要索引。

返回類型:

bool

reset_parameters() None

重置 EmbeddingCollection 的參數。 參數值會根據每個 EmbeddingConfig 的 init_fn (如果存在) 進行初始化。

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得適合初學者和進階開發者的深入教學課程

查看教學

資源

尋找開發資源並獲得您的問題解答

查看資源