表格批次嵌入 (TBE) 推論模組¶
穩定 API¶
- class fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]], feature_table_map: List[int] | None = None, index_remapping: List[Tensor] | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, weight_lists: List[Tuple[Tensor, Tensor | None]] | None = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, enforce_hbm: bool = False, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, row_alignment: int | None = None, fp8_exponent_bits: int | None = None, fp8_exponent_bias: int | None = None, cache_assoc: int = 32, scale_bias_size_in_bytes: int = 4, cacheline_alignment: bool = True, uvm_host_mapped: bool = False, reverse_qparam: bool = False, feature_names_per_table: List[List[str]] | None = None, indices_dtype: dtype = torch.int32)[原始碼]¶
nn.EmbeddingBag(sparse=False) 的表格批次推論版本,支援 FP32/FP16/FP8/INT8/INT4/INT2 權重
- 參數:
embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –
嵌入規格列表。每個規格描述一個實體嵌入表格的規格。每一個都是一個元組,包含嵌入列數、嵌入維度(必須是 4 的倍數)、表格放置位置 (EmbeddingLocation) 以及運算裝置 (ComputeDevice)。
可用的 EmbeddingLocation 選項為
DEVICE = 將嵌入表格放置於 GPU 全域記憶體 (HBM) 中
MANAGED = 將嵌入放置於統一虛擬記憶體中(可從 GPU 和 CPU 存取)
MANAGED_CACHING = 將嵌入表格放置於統一虛擬記憶體中,並使用 GPU 全域記憶體 (HBM) 作為快取
HOST = 將嵌入表格放置於 CPU 記憶體 (DRAM) 中
MTIA = 將嵌入表格放置於 MTIA 記憶體中
可用的 ComputeDevice 選項為
CPU = 在 CPU 上執行表格查詢
CUDA = 在 GPU 上執行表格查詢
MTIA = 在 MTIA 上執行表格查詢
feature_table_map (Optional[List[int]] = None) – 一個可選的列表,指定特徵到表格的映射。feature_table_map[i] 表示特徵 i 映射到的實體嵌入表格。
index_remapping (Optional[List[Tensor]] = None) – 用於剪枝的索引重新映射
pooling_mode (PoolingMode = PoolingMode.SUM) –
池化模式。可用的 PoolingMode 選項為
SUM = 總和池化
MEAN = 平均池化
NONE = 無池化(序列嵌入)
device (Optional[Union[str, int, torch.device]] = None) – 目前的裝置,用於放置張量
bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –
輸入檢查模式。可用的 BoundsCheckMode 選項為
NONE = 跳過邊界檢查
FATAL = 當遇到無效的索引/偏移量時拋出錯誤
WARNING = 當遇到無效的索引/偏移量時印出警告訊息並修正(將無效的索引設為零,並調整無效的偏移量使其在邊界內)
IGNORE = 靜默地修正無效的索引/偏移量(將無效的索引設為零,並調整無效的偏移量使其在邊界內)
weight_lists (Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None) – [T]
pruning_hash_load_factor (float = 0.5) – 剪枝雜湊的載入因子
use_array_for_index_remapping (bool = True) – 若為 True,則使用陣列進行索引重新映射。否則,使用雜湊表。
output_dtype (SparseType = SparseType.FP16) – 輸出張量的資料類型。
cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –
快取演算法(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。選項為
LRU = 最近最少使用
LFU = 最不常使用
cache_load_factor (float = 0.2) – 用於決定快取容量的因子,當使用 EmbeddingLocation.MANAGED_CACHING 時。快取容量為 cache_load_factor * 所有嵌入表格中的總列數
cache_sets (int = 0) – 快取集合的數量(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)
cache_reserved_memory (float = 0.0) – 在 HBM 中為非快取目的保留的記憶體量(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。
enforce_hbm (bool = False) – 若為 True,當使用 EmbeddingLocation.MANAGED_CACHING 時,將所有權重/動量放置於 HBM 中
record_cache_metrics (Optional[RecordCacheMetrics] = None) – 若 RecordCacheMetrics.record_cache_miss_counter 為 True,則記錄命中次數、請求次數等;若 RecordCacheMetrics.record_tablewise_cache_miss 為 True,則以表格方式記錄相似的指標
gather_uvm_cache_stats (Optional[bool] = False) – 若為 True,當 EmbeddingLocation 設定為 MANAGED_CACHING 時,收集快取統計資訊
row_alignment (Optional[int] = None) – 列對齊
fp8_exponent_bits (Optional[int] = None) – 使用 FP8 時的指數位元
fp8_exponent_bias (Optional[int] = None) – 使用 FP8 時的指數偏差
cache_assoc (int = 32) – 快取的路數
scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES) – 縮放和偏差的大小,以位元組為單位
cacheline_alignment (bool = True) – 若為 True,將每個表格對齊 128 位元組快取行邊界
uvm_host_mapped (bool = False) – 若為 True,則使用 malloc + cudaHostRegister 分配每個 UVM 張量。否則使用 cudaMallocManaged
reverse_qparam (bool = False) – 若為 True,則在每列末尾載入 qparams。否則,在每列開頭載入 qparams。
feature_names_per_table (Optional[List[List[str]]] = None) – 一個可選的列表,指定每個表格的特徵名稱。feature_names_per_table[t] 表示表格 t 的特徵名稱。
indices_dtype (torch.dtype = torch.int32) – 預期傳遞給 forward() 呼叫的索引張量的 dtype。此資訊將用於建構 remap_indices 陣列/雜湊。選項為 torch.int32 和 torch.int64。
- assign_embedding_weights(q_weight_list: List[Tuple[Tensor, Tensor | None]]) None [原始碼]¶
使用輸入的權重和 scale_shifts 列表中的值,指派 self.split_embedding_weights()。
- forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None) Tensor [source]¶
定義每次呼叫時執行的計算。
應由所有子類別覆寫。
Note
雖然前向傳遞的配方需要在這個函數中定義,但應該在之後呼叫
Module
實例,而不是這個函數,因為前者會處理已註冊的 hooks,而後者會靜默地忽略它們。
- recompute_module_buffers() None [source]¶
計算位於 meta 裝置上且未在 reset_weights_placements_and_offsets() 中實現的模組緩衝區。目前這些緩衝區為 weights_tys、rows_per_table、D_offsets 和 bounds_check_warning。目前不計算與剪枝相關或與 uvm 相關的緩衝區。