• 文件 >
  • 表格批次嵌入 (TBE) 推論模組
捷徑

表格批次嵌入 (TBE) 推論模組

穩定 API

class fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]], feature_table_map: List[int] | None = None, index_remapping: List[Tensor] | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, weight_lists: List[Tuple[Tensor, Tensor | None]] | None = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, enforce_hbm: bool = False, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, row_alignment: int | None = None, fp8_exponent_bits: int | None = None, fp8_exponent_bias: int | None = None, cache_assoc: int = 32, scale_bias_size_in_bytes: int = 4, cacheline_alignment: bool = True, uvm_host_mapped: bool = False, reverse_qparam: bool = False, feature_names_per_table: List[List[str]] | None = None, indices_dtype: dtype = torch.int32)[原始碼]

nn.EmbeddingBag(sparse=False) 的表格批次推論版本,支援 FP32/FP16/FP8/INT8/INT4/INT2 權重

參數:
  • embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –

    嵌入規格列表。每個規格描述一個實體嵌入表格的規格。每一個都是一個元組,包含嵌入列數、嵌入維度(必須是 4 的倍數)、表格放置位置 (EmbeddingLocation) 以及運算裝置 (ComputeDevice)。

    可用的 EmbeddingLocation 選項為

    1. DEVICE = 將嵌入表格放置於 GPU 全域記憶體 (HBM) 中

    2. MANAGED = 將嵌入放置於統一虛擬記憶體中(可從 GPU 和 CPU 存取)

    3. MANAGED_CACHING = 將嵌入表格放置於統一虛擬記憶體中,並使用 GPU 全域記憶體 (HBM) 作為快取

    4. HOST = 將嵌入表格放置於 CPU 記憶體 (DRAM) 中

    5. MTIA = 將嵌入表格放置於 MTIA 記憶體中

    可用的 ComputeDevice 選項為

    1. CPU = 在 CPU 上執行表格查詢

    2. CUDA = 在 GPU 上執行表格查詢

    3. MTIA = 在 MTIA 上執行表格查詢

  • feature_table_map (Optional[List[int]] = None) – 一個可選的列表,指定特徵到表格的映射。feature_table_map[i] 表示特徵 i 映射到的實體嵌入表格。

  • index_remapping (Optional[List[Tensor]] = None) – 用於剪枝的索引重新映射

  • pooling_mode (PoolingMode = PoolingMode.SUM) –

    池化模式。可用的 PoolingMode 選項為

    1. SUM = 總和池化

    2. MEAN = 平均池化

    3. NONE = 無池化(序列嵌入)

  • device (Optional[Union[str, int, torch.device]] = None) – 目前的裝置,用於放置張量

  • bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –

    輸入檢查模式。可用的 BoundsCheckMode 選項為

    1. NONE = 跳過邊界檢查

    2. FATAL = 當遇到無效的索引/偏移量時拋出錯誤

    3. WARNING = 當遇到無效的索引/偏移量時印出警告訊息並修正(將無效的索引設為零,並調整無效的偏移量使其在邊界內)

    4. IGNORE = 靜默地修正無效的索引/偏移量(將無效的索引設為零,並調整無效的偏移量使其在邊界內)

  • weight_lists (Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None) – [T]

  • pruning_hash_load_factor (float = 0.5) – 剪枝雜湊的載入因子

  • use_array_for_index_remapping (bool = True) – 若為 True,則使用陣列進行索引重新映射。否則,使用雜湊表。

  • output_dtype (SparseType = SparseType.FP16) – 輸出張量的資料類型。

  • cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –

    快取演算法(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。選項為

    1. LRU = 最近最少使用

    2. LFU = 最不常使用

  • cache_load_factor (float = 0.2) – 用於決定快取容量的因子,當使用 EmbeddingLocation.MANAGED_CACHING 時。快取容量為 cache_load_factor * 所有嵌入表格中的總列數

  • cache_sets (int = 0) – 快取集合的數量(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)

  • cache_reserved_memory (float = 0.0) – 在 HBM 中為非快取目的保留的記憶體量(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。

  • enforce_hbm (bool = False) – 若為 True,當使用 EmbeddingLocation.MANAGED_CACHING 時,將所有權重/動量放置於 HBM 中

  • record_cache_metrics (Optional[RecordCacheMetrics] = None) – 若 RecordCacheMetrics.record_cache_miss_counter 為 True,則記錄命中次數、請求次數等;若 RecordCacheMetrics.record_tablewise_cache_miss 為 True,則以表格方式記錄相似的指標

  • gather_uvm_cache_stats (Optional[bool] = False) – 若為 True,當 EmbeddingLocation 設定為 MANAGED_CACHING 時,收集快取統計資訊

  • row_alignment (Optional[int] = None) – 列對齊

  • fp8_exponent_bits (Optional[int] = None) – 使用 FP8 時的指數位元

  • fp8_exponent_bias (Optional[int] = None) – 使用 FP8 時的指數偏差

  • cache_assoc (int = 32) – 快取的路數

  • scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES) – 縮放和偏差的大小,以位元組為單位

  • cacheline_alignment (bool = True) – 若為 True,將每個表格對齊 128 位元組快取行邊界

  • uvm_host_mapped (bool = False) – 若為 True,則使用 malloc + cudaHostRegister 分配每個 UVM 張量。否則使用 cudaMallocManaged

  • reverse_qparam (bool = False) – 若為 True,則在每列末尾載入 qparams。否則,在每列開頭載入 qparams

  • feature_names_per_table (Optional[List[List[str]]] = None) – 一個可選的列表,指定每個表格的特徵名稱。feature_names_per_table[t] 表示表格 t 的特徵名稱。

  • indices_dtype (torch.dtype = torch.int32) – 預期傳遞給 forward() 呼叫的索引張量的 dtype。此資訊將用於建構 remap_indices 陣列/雜湊。選項為 torch.int32torch.int64

assign_embedding_weights(q_weight_list: List[Tuple[Tensor, Tensor | None]]) None[原始碼]

使用輸入的權重和 scale_shifts 列表中的值,指派 self.split_embedding_weights()。

fill_random_weights() None[原始碼]

逐表格使用隨機權重填滿緩衝區

forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None) Tensor[source]

定義每次呼叫時執行的計算。

應由所有子類別覆寫。

Note

雖然前向傳遞的配方需要在這個函數中定義,但應該在之後呼叫 Module 實例,而不是這個函數,因為前者會處理已註冊的 hooks,而後者會靜默地忽略它們。

recompute_module_buffers() None[source]

計算位於 meta 裝置上且未在 reset_weights_placements_and_offsets() 中實現的模組緩衝區。目前這些緩衝區為 weights_tysrows_per_tableD_offsetsbounds_check_warning。目前不計算與剪枝相關或與 uvm 相關的緩衝區。

split_embedding_weights(split_scale_shifts: bool = True) List[Tuple[Tensor, Tensor | None]][source]

傳回依表格分割的權重列表

split_embedding_weights_with_scale_bias(split_scale_bias_mode: int = 1) List[Tuple[Tensor, Tensor | None, Tensor | None]][source]

傳回依表格 `split_scale_bias_mode` 分割的權重列表

0:傳回單列;1:傳回權重 + scale_bias;2:傳回權重、scale、bias。

其他 API

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得您的問題解答

檢視資源