捷徑

實驗性運算子

Attention 運算子

std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk(const at::Tensor &XQ, const at::Tensor &cache_K, const at::Tensor &cache_V, const at::Tensor &seq_positions, const double qk_scale, const int64_t num_split_ks, const int64_t kv_cache_quant_num_groups, const bool use_tensor_cores, const int64_t cache_logical_dtype_int)

解碼群組查詢注意力 Split-K (BF16/INT4 KV)。

解碼群組查詢注意力 (GQA) 的 CUDA 實作,支援 BF16 和 INT4 KV 快取以及 BF16 輸入查詢。目前僅支援最大上下文長度為 16384、固定磁頭尺寸為 128,且僅有一個 KV 快取磁頭。它支援任意數量的查詢磁頭。

參數:
  • XQ – 輸入查詢;形狀 = (B, 1, H_Q, D),其中 B = 批次大小,H_Q = 查詢磁頭數量,D = 磁頭尺寸(固定為 128)

  • cache_K – K 快取;形狀 = (B, MAX_T, H_KV, D),其中 MAX_T = 最大上下文長度(固定為 16384),而 H_KV = KV 快取磁頭數量(固定為 1)

  • cache_V – V 快取;形狀 = (B, MAX_T, H_KV, D)

  • seq_positions – 序列位置(包含每個 token 的實際長度);形狀 = (B)

  • qk_scale – 在 QK^T 之後套用的縮放比例

  • num_split_ks – Split-K 的數量(控制上下文長度維度 (MAX_T) 中的平行處理量)

  • kv_cache_quant_num_groups – 每個 KV token 的群組式 INT4 和 FP8 量化的群組數量(每個群組對量化使用相同的縮放比例和偏差)。FP8 目前僅支援單一群組。

  • use_tensor_cores – 是否使用 tensor core wmma 指令以實現快速實作

  • cache_logical_dtype_int – 指定 kv_cache 的量化資料類型:{BF16:0 , FP8:1, INT4:2}

回傳值:

包含組合後的 split-K 輸出、非組合 split-K 輸出和 split-K metadata 的 tuple (包含最大 QK^T 和 softmax(QK^T) 磁頭總和)

文件資料

存取 PyTorch 的全面開發者文件資料

查看文件

教學文件

取得針對初學者和進階開發者的深入教學文件

查看教學

資源

尋找開發資源並獲得問題解答

查看資源