SSD 嵌入運算子¶
CUDA 運算子¶
-
enum RocksdbWriteMode¶
rocksdb 寫入模式
在 SSD 卸載中,每個訓練迭代有 3 次寫入 FWD_ROCKSDB_READ:快取查找將未快取的資料從 rocksdb 移動到 fwd 路徑上的 L2 快取
FWD_L1_EVICTION:L1 快取逐出將在 fwd 路徑上將資料逐出到 L2 快取
BWD_L1_CNFLCT_MISS_WRITE_BACK:L1 衝突未命中將在 bwd 路徑上插入到 L2 以進行嵌入更新
一旦 L2 快取已滿,以上所有 L2 快取填充都可能觸發 rocksdb 寫入
此外,我們將在 L2 清空時執行 ssd io
值
-
enumerator FWD_ROCKSDB_READ¶
-
enumerator FWD_L1_EVICTION¶
-
enumerator BWD_L1_CNFLCT_MISS_WRITE_BACK¶
-
enumerator FLUSH¶
-
enumerator FWD_ROCKSDB_READ¶
-
inline size_t hash_shard(int64_t id, size_t num_shards)¶
用於 SSD L2 快取和 rocksdb 分片演算法的雜湊函數
- 參數:
id – 分片鍵
num_shards – 分片範圍
- 回傳:
分片 id 範圍從 [0, num_shards)
-
void cuda_callback_func(cudaStream_t stream, cudaError_t status, void *functor)¶
用於
cudaStreamAddCallback
的回呼函數用於
cudaStreamAddCallback
的通用回呼函數,即cudaStreamCallback_t callback
。此函數將functor
轉換為 void 函數,調用它,然後刪除它(刪除發生在另一個執行緒中)- 參數:
stream –
cudaStreamAddCallback
在其上運作的 CUDA 串流status – CUDA 狀態
functor – 將被呼叫的 functor
- 回傳:
無
-
Tensor masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count, const bool use_pipeline, const int64_t preferred_sms)¶
類似於
torch.Tensor.index_put
但忽略indices < 0
masked_index_put_cuda
僅支援 2D 輸入values
。它使用indices
中 >= 0 的列索引,將values
中的count
列放入self
中。# Equivalent PyTorch Python code indices = indices[:count] filter_ = indices >= 0 indices_ = indices[filter_] self[indices_] = values[filter_.nonzero().flatten()]
- 參數:
self – 2D 輸出張量(被索引的張量)
indices – 1D 索引張量
values – 2D 輸入張量
count – 包含要處理的
indices
長度的張量use_pipeline – 指示此核心將與其他核心重疊的標誌。如果為 true,則使用一部分 SM 以減少資源競爭
preferred_sms – 當 use_pipeline=true 時,核心偏好使用的 SM 數量。當 use_pipeline=false 時,此值將被忽略。
- 回傳:
self
張量
-
Tensor masked_index_select_cuda(Tensor self, Tensor indices, Tensor values, Tensor count, const bool use_pipeline, const int64_t preferred_sms)¶
類似於
torch.index_select
但忽略indices < 0
masked_index_select_cuda
僅支援 2D 輸入values
。它將indices
中指定的count
列(其中indices
>= 0)從values
放入self
中# Equivalent PyTorch Python code indices = indices[:count] filter_ = indices >= 0 indices_ = indices[filter_] self[filter_.nonzero().flatten()] = values[indices_]
- 參數:
self – 2D 輸出張量
indices – 1D 索引張量
values – 2D 輸入張量(被索引的張量)
count – 包含要處理的
indices
長度的張量use_pipeline – 指示此核心將與其他核心重疊的標誌。如果為 true,則使用一部分 SM 以減少資源競爭
preferred_sms – 當 use_pipeline=true 時,核心偏好使用的 SM 數量。當 use_pipeline=false 時,此值將被忽略。
- 回傳:
self
張量
-
std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(const Tensor &lxu_cache_locations, const Tensor &assigned_cache_slots, const Tensor &linear_index_inverse_indices, const Tensor &unique_indices_count_cumsum, const Tensor &cache_set_inverse_indices, const Tensor &lxu_cache_weights, const Tensor &inserted_ssd_weights, const Tensor &unique_indices_length, const Tensor &cache_set_sorted_unique_indices)¶
為 SSD TBE 資料產生記憶體位址。
從 SSD 檢索的資料可以儲存在暫存記憶體 (HBM) 或 LXU 快取(也是 HBM)中。
lxu_cache_locations
用於指定資料的位置。如果位置為 -1,則關聯索引的資料位於暫存記憶體中;否則,它位於快取中。為了使 TBE 核心能夠方便地存取資料,此運算子會為每個索引產生第一個位元組的記憶體位址。當存取資料時,TBE 核心僅需要將位址轉換為指標。此外,此運算子還產生後向傳播逐出索引的列表,這些索引基本上是資料位於暫存記憶體中的索引。
- 參數:
lxu_cache_locations – 包含為完整索引列表儲存資料的快取槽的張量。-1 是一個 sentinel 值,表示資料不在快取中。
assigned_cache_slots – 包含唯一索引列表的快取槽的張量。-1 表示資料不在快取中
linear_index_inverse_indices – 包含線性索引在排序前的原始位置的張量
unique_indices_count_cumsum – 包含唯一索引計數的互斥前綴和結果的張量
cache_set_inverse_indices – 包含快取集合在排序前的原始位置的張量
lxu_cache_weights – LXU 快取張量
inserted_ssd_weights – 暫存記憶體張量
unique_indices_length – 包含唯一索引數量的張量(GPU 張量)
cache_set_sorted_unique_indices – 包含排序後的唯一快取集合的相關唯一索引的張量
- 回傳:
張量組成的元組(SSD 列位址張量和後向傳播逐出索引張量)
-
void ssd_update_row_addrs_cuda(const Tensor &ssd_row_addrs_curr, const Tensor &inserted_ssd_weights_curr_next_map, const Tensor &lxu_cache_locations_curr, const Tensor &linear_index_inverse_indices_curr, const Tensor &unique_indices_count_cumsum_curr, const Tensor &cache_set_inverse_indices_curr, const Tensor &lxu_cache_weights, const Tensor &inserted_ssd_weights_next, const Tensor &unique_indices_length_curr)¶
更新 SSD TBE 資料的記憶體位址。
當啟用管線預取時,當前迭代的暫存記憶體中的資料可以在預取步驟期間移動到 L1 或下一次迭代的暫存記憶體。此運算子會更新重新定位到正確位置的資料的記憶體位址。
- 參數:
ssd_row_addrs_curr – 包含當前迭代的列位址的張量
inserted_ssd_weights_curr_next_map – 包含當前迭代中每個索引的位置與下一次迭代的暫存記憶體中的位置之間的映射的張量(-1 = 資料尚未移動)。inserted_ssd_weights_curr_next_map[i] 是位置
lxu_cache_locations_curr – 包含為完整索引列表儲存資料的快取槽的張量,用於當前迭代。-1 是一個 sentinel 值,表示資料不在快取中。
linear_index_inverse_indices_curr – 包含線性索引在排序前的原始位置的張量,用於當前迭代
unique_indices_count_cumsum_curr – 包含唯一索引計數的互斥前綴和結果的張量,用於當前迭代
cache_set_inverse_indices_curr – 包含快取集合在排序前的原始位置的張量,用於當前迭代
lxu_cache_weights – LXU 快取張量
inserted_ssd_weights_next – 下一次迭代的暫存記憶體張量
unique_indices_length_curr – 包含唯一索引數量的張量(GPU 張量),用於當前迭代
- 回傳:
無
-
void compact_indices_cuda(std::vector<Tensor> compact_indices, Tensor compact_count, std::vector<Tensor> indices, Tensor masks, Tensor count)¶
壓縮給定的索引列表。
此運算子根據給定的遮罩(包含 0 或 1 的張量)壓縮給定的索引列表。運算子會移除遮罩為 0 的索引。它僅對
count
數量的元素進行運算(而不是整個張量)。範例
indices = [[0, 3, -1, 3, -1, -1, 7], [0, 2, 2, 3, -1, 9, 7]] masks = [1, 1, 0, 1, 0, 0, 1] count = 5 # x represents an arbitrary value compact_indices = [[0, 3, 3, x, x, x, x], [0, 2, 3, x, x, x, x]] compact_count = 3
- 參數:
compact_indices – 壓縮索引列表(輸出索引)。
compact_count – 包含壓縮後元素數量的張量
indices – 要壓縮的輸入索引列表
masks – 包含 0 或 1 以指示是否移除/保留元素的張量。0 = 移除對應的索引。1 = 保留對應的索引。@count count 包含要壓縮的元素數量的張量
-
class CacheLibCache¶
- #include <cachelib_cache.h>
用於 Cachlib 互動的 Cachelib 包裝類別。
它用於維護所有快取相關操作,包括初始化、插入、查找和逐出。它對於逐出邏輯是有狀態的,呼叫者必須明確獲取和重置與逐出相關的狀態。Cachelib 相關的優化將在此類別內部捕獲,例如獲取和延遲 markUseful 以提高 get 效能
注意
此類別僅處理單個 Cachelib 讀取/更新。並行處理在呼叫者端完成
-
class EmbeddingParameterServer : public EmbeddingKVDB¶
- #include <ps_table_batched_embeddings.h>
Training Parameter Service (TPS) 用戶端的 EmbeddingKVDB 實作。
-
class CacheContext¶
- #include <kv_db_table_batched_embeddings.h>
它保存 l2 快取查找結果。
num_misses 是 l2 快取查找中的未命中數 cached_addr_list 已預先分配了查找數量,以實現更好的並行性,並且無效位置(快取未命中)將保持為 sentinel 值
-
struct QueueItem¶
- #include <kv_db_table_batched_embeddings.h>
用於背景 L2/rocksdb 更新的佇列項目
indices/weights/count 是對應的 set() 參數
read_handles 是 cachelib 抽象化的索引/嵌入對元資料,稍後將用於更新 cachelib LRU 佇列,因為我們將其與 EmbeddingKVDB::get_cache() 分開
mode 用於監控 rocksdb 寫入,請查看 RocksdbWriteMode 以獲得詳細說明
-
class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB>¶
- #include <kv_db_table_batched_embeddings.h>
用於與不同快取層和儲存層互動的類別,公開呼叫在 cuda 串流上執行。
目前 TBE 使用它將 Key(嵌入索引)Value(嵌入)卸載到 DRAM、SSD 或遠端儲存,以提供更好的可擴展性,而不會耗盡 HBM 資源
-
class EmbeddingRocksDB : public EmbeddingKVDB¶
- #include <ssd_table_batched_embeddings.h>
EmbeddingKVDB 在 RocksDB 中的實作。