1
0

attention_kernels.cu 44 KB


  1. /*
  2. * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  3. * Copyright (c) 2023, The PygmalionAI team.
  4. * Copyright (c) 2023, The vLLM team.
  5. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the "License");
  8. * you may not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * http://www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an "AS IS" BASIS,
  15. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. */
  19. #include <torch/extension.h>
  20. #include <ATen/cuda/CUDAContext.h>
  21. #include <c10/cuda/CUDAGuard.h>
  22. #include "attention_dtypes.h"
  23. #include "attention_utils.cuh"
  24. #if defined(ENABLE_FP8_E5M2)
  25. #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
  26. #elif defined(ENABLE_FP8_E4M3)
  27. #include "../quantization/fp8/amd_detail/quant_utils.cuh"
  28. #endif
  29. #include <algorithm>
  30. #ifdef USE_ROCM
  31. #include <hip/hip_bf16.h>
  32. typedef __hip_bfloat16 __nv_bfloat16;
  33. #endif
  // Number of threads per warp. Fixed at 32 unless USE_ROCM is defined, in which
  // case the `warpSize` identifier is used instead.
  #ifndef USE_ROCM
  #define WARP_SIZE 32
  #else
  #define WARP_SIZE warpSize
  #endif
  // Function-like helpers used for both compile-time and run-time arithmetic.
  // NOTE: these are plain macros, so an argument may be evaluated more than once;
  // do not pass expressions with side effects.
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
  // Integer division of a by b, rounded up (for non-negative a and positive b).
  #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
  42. namespace aphrodite {
  43. // Utility function for attention softmax.
  // Block-wide sum used by the attention softmax.
  //
  // Every thread contributes `sum`; the return value is the total over the
  // block. `red_smem` is shared-memory scratch with at least NUM_WARPS floats.
  // The function contains a __syncthreads(), so all threads of the block must
  // call it together.
  template<int NUM_WARPS>
  inline __device__ float block_sum(float* red_smem, float sum) {
    // Position of this thread: which warp, and which lane inside that warp.
    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;

    // Stage 1: butterfly reduction so that every lane holds its warp's total.
    float total = sum;
  #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
      total += APHRODITE_SHFL_XOR_SYNC(total, offset);
    }

    // Stage 2: one lane per warp publishes the warp total.
    if (lane_id == 0) {
      red_smem[warp_id] = total;
    }

    // All warp totals must be visible before they are read back.
    __syncthreads();

    // Stage 3: the first NUM_WARPS lanes of each warp pick up one warp total
    // each and combine them with a second butterfly reduction.
    if (lane_id < NUM_WARPS) {
      total = red_smem[lane_id];
    }
  #pragma unroll
    for (int offset = NUM_WARPS / 2; offset > 0; offset /= 2) {
      total += APHRODITE_SHFL_XOR_SYNC(total, offset);
    }

    // Lane 0 holds the block total; hand it to every lane of the warp.
    return APHRODITE_SHFL_SYNC(total, 0);
  }
  // TODO: Merge the last two dimensions of the grid.
  // Grid: (num_heads, num_seqs, max_num_partitions).
  //
  // Shared device-side implementation behind paged_attention_v1_kernel and
  // paged_attention_v2_kernel. One thread block handles one (head, sequence,
  // partition) triple: blockIdx.x is the head, blockIdx.y the sequence and
  // blockIdx.z the partition. Only threadIdx.x is used for thread indexing.
  //
  // Phases:
  //   1. Stage the query vector for this (sequence, head) in shared memory.
  //   2. For every key in the assigned KV blocks, compute scale * (q . k)
  //      (+ ALiBi bias) into the shared `logits` buffer and track the maximum.
  //   3. Block-wide max and sum reductions turn the logits into softmax weights.
  //   4. Accumulate the weighted sum of the values in FP32, reduce across
  //      warps through shared memory and write the result to `out`.
  //
  // Template parameters:
  //   scalar_t        - element type of q / out.
  //   cache_t         - element type stored in k_cache / v_cache.
  //   HEAD_SIZE       - number of elements per head.
  //   BLOCK_SIZE      - number of tokens per KV-cache block.
  //   NUM_THREADS     - threads per block assumed by the index math.
  //   IS_FP8_KV_CACHE - when true, cache entries are converted to scalar_t
  //                     through the FP8 conversion helpers before use.
  //   PARTITION_SIZE  - tokens per partition; 0 disables partitioning.
  //
  // Shared memory: the dynamic `shared_mem` buffer first holds the float logits
  // and is later reused as scratch for the cross-warp output reduction.
  //
  // With partitioning enabled, thread 0 additionally stores this partition's
  // max logit and exp-sum into `max_logits` / `exp_sums`.
  template<
    typename scalar_t,
    typename cache_t,
    int HEAD_SIZE,
    int BLOCK_SIZE,
    int NUM_THREADS,
    bool IS_FP8_KV_CACHE,
    int PARTITION_SIZE = 0> // Zero means no partitioning.
  __device__ void paged_attention_kernel(
    float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
    scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads, // scalar: number of KV heads (query heads are grouped onto them)
    const float scale,
    const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens, // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes, // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    const float kv_scale) {
    const int seq_idx = blockIdx.y;
    const int partition_idx = blockIdx.z;
    const int max_num_partitions = gridDim.z;
    constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
    const int context_len = context_lens[seq_idx];
    // The condition below depends only on blockIdx.z and context_len, so it is
    // uniform across the block: either every thread returns or none does.
    if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
      // No work to do. Terminate the thread block.
      return;
    }

    const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
    const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

    // [start_block_idx, end_block_idx) is the range of blocks to process.
    const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
    const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
    const int num_blocks = end_block_idx - start_block_idx;

    // [start_token_idx, end_token_idx) is the range of tokens to process.
    const int start_token_idx = start_block_idx * BLOCK_SIZE;
    const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
    const int num_tokens = end_token_idx - start_token_idx;

    // A "thread group" is the set of threads that cooperate on one key.
    constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
    constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
    // NOTE(review): both operands are compile-time constants, so this runtime
    // assert could be expressed as a static_assert.
    assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
    // Number of keys of a block handled by each thread group (1 while
    // BLOCK_SIZE <= WARP_SIZE).
    constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
    constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
    const int thread_idx = threadIdx.x;
    const int warp_idx = thread_idx / WARP_SIZE;
    const int lane = thread_idx % WARP_SIZE;

    const int head_idx = blockIdx.x;
    const int num_heads = gridDim.x;
    // Consecutive query heads share one KV head.
    const int num_queries_per_kv = num_heads / num_kv_heads;
    const int kv_head_idx = head_idx / num_queries_per_kv;
    // A null alibi_slopes pointer disables the ALiBi bias.
    const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

    // A vector type to store a part of a key or a query.
    // The vector size is configured in such a way that the threads in a thread group
    // fetch or compute 16 bytes at a time.
    // For example, if the size of a thread group is 4 and the data type is half,
    // then the vector size is 16 / (4 * sizeof(half)) == 2.
    constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
    using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
    using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  #if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
    using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
  #endif

    constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
    constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

    const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
    const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

    // Load the query to shared memory.
    // Each thread in a thread group uses a different part of the query.
    // For example, if the thread group size is 4, then the first thread in the group
    // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
    // th vectors of the query, and so on.
    // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
    const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
    // Row r of q_vecs holds the query vectors used by the thread whose offset
    // inside its thread group is r.
    __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
  #pragma unroll
    for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
      const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
      q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
    }
    __syncthreads(); // TODO: possible speedup if this is replaced with a memory wall right before we use q_vecs

    // Memory planning.
    extern __shared__ char shared_mem[];
    // NOTE: We use FP32 for the softmax logits for better accuracy.
    float* logits = reinterpret_cast<float*>(shared_mem);
    // Workspace for reduction: the first NUM_WARPS floats are used for the max
    // reduction, the second NUM_WARPS floats are handed to block_sum.
    __shared__ float red_smem[2 * NUM_WARPS];

    // x is the number of cache elements in 16 bytes; it is the innermost
    // dimension of the key cache layout.
    // NOTE(review): x == THREAD_GROUP_SIZE * VEC_SIZE only when cache_t and
    // scalar_t have the same size; it does not hold for an FP8 cache.
    // Each thread group fetches x elements from the key at a time.
    constexpr int x = 16 / sizeof(cache_t);
    float qk_max = -FLT_MAX;

    // Iterate over the key blocks.
    // Each warp fetches a block of keys for each iteration.
    // Each thread group in a warp fetches a key from the block, and computes
    // dot product with the query.
    const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
      // NOTE: The block number is stored in int32. However, we cast it to int64
      // because int32 can lead to overflow when this variable is multiplied by large numbers
      // (e.g., kv_block_stride).
      const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);

      // Load a key to registers.
      // Each thread in a thread group has a different part of the key.
      // For example, if the thread group size is 4, then the first thread in the group
      // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
      // vectors of the key, and so on.
      for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
        // Position of the key inside its cache block, and its index in the sequence.
        const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
        const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
        K_vec k_vecs[NUM_VECS_PER_THREAD];

  #pragma unroll
        for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
          const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                         + kv_head_idx * kv_head_stride
                                         + physical_block_offset * x;
          const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
          // Split the element index into (chunk of x elements, position in chunk)
          // to match the [head_size/x, block_size, x] key layout.
          const int offset1 = (vec_idx * VEC_SIZE) / x;
          const int offset2 = (vec_idx * VEC_SIZE) % x;
          if constexpr (IS_FP8_KV_CACHE) {
  #if defined(ENABLE_FP8_E5M2)
            Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
            // Vector conversion from Quant_vec to K_vec.
            k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
  #elif defined(ENABLE_FP8_E4M3)
            Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
            // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
            // cache vec to k vec in higher precision (FP16, BFloat16, etc.)
            k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
  #else
            // FP8 KV cache requested but no FP8 support was compiled in.
            assert(false);
  #endif
          } else {
            k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          }
        }

        // Compute dot product.
        // This includes a reduction across the threads in the same thread group.
        float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
        // Add the ALiBi bias if slopes are given.
        qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

        // Only the first thread of each thread group records the result.
        if (thread_group_offset == 0) {
          // Store the partial reductions to shared memory.
          // NOTE: It is required to zero out the masked logits.
          const bool mask = token_idx >= context_len;
          logits[token_idx - start_token_idx] = mask ? 0.f : qk;
          // Update the max value.
          qk_max = mask ? qk_max : fmaxf(qk_max, qk);
        }
      }
    }

    // Perform reduction across the threads in the same warp to get the
    // max qk value for each "warp" (not across the thread block yet).
    // The 0-th thread of each thread group already has its max qk value.
  #pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
      qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
    }
    if (lane == 0) {
      red_smem[warp_idx] = qk_max;
    }
    // This barrier also makes all logits written above visible to the block.
    __syncthreads();

    // TODO: Refactor this part.
    // Get the max qk value for the sequence.
    qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
  #pragma unroll
    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
      qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
    }
    // Broadcast the max qk value to all threads.
    qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);

    // Get the sum of the exp values. Logits are replaced in place by
    // exp(logit - max).
    float exp_sum = 0.f;
    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      float val = __expf(logits[i] - qk_max);
      logits[i] = val;
      exp_sum += val;
    }
    exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

    // Compute softmax. The small epsilon guards against division by zero.
    const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      logits[i] *= inv_sum;
    }
    __syncthreads();

    // If partitioning is enabled, store the max logit and exp_sum.
    if (USE_PARTITIONING && thread_idx == 0) {
      float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                         + head_idx * max_num_partitions
                                         + partition_idx;
      *max_logits_ptr = qk_max;
      float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                     + head_idx * max_num_partitions
                                     + partition_idx;
      *exp_sums_ptr = exp_sum;
    }

    // Each thread will fetch 16 bytes from the value cache at a time.
    constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
    using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
    using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  #if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
    using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
  #endif
    using Float_L_vec = typename FloatVec<L_vec>::Type;

    // A "row" is one head dimension of the value block (BLOCK_SIZE tokens).
    constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
    constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
    constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

    // NOTE: We use FP32 for the accumulator for better accuracy.
    float accs[NUM_ROWS_PER_THREAD];
  #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      accs[i] = 0.f;
    }

    // Zero of type scalar_t, used to blank out-of-context value entries.
    scalar_t zero_value;
    zero(zero_value);
    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
      // NOTE: The block number is stored in int32. However, we cast it to int64
      // because int32 can lead to overflow when this variable is multiplied by large numbers
      // (e.g., kv_block_stride).
      const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
      // First token (within the block) of the V_VEC_SIZE tokens this lane handles.
      const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      // Softmax weights for those tokens, converted from float to scalar_t.
      L_vec logits_vec;
      from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                     + kv_head_idx * kv_head_stride;
  #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE) {
          const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
          V_vec v_vec;
          if constexpr (IS_FP8_KV_CACHE) {
  #if defined(ENABLE_FP8_E5M2)
            V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
            // Vector conversion from V_quant_vec to V_vec.
            v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
  #elif defined(ENABLE_FP8_E4M3)
            V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
            // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
            // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
            v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
  #else
            // FP8 KV cache requested but no FP8 support was compiled in.
            assert(false);
  #endif
          } else {
            v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
          }
          if (block_idx == num_context_blocks - 1) {
            // NOTE: When v_vec contains the tokens that are out of the context,
            // we should explicitly zero out the values since they may contain NaNs.
            // See https://github.com/aphrodite-project/aphrodite/issues/641#issuecomment-1682544472
            scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
  #pragma unroll
            for (int j = 0; j < V_VEC_SIZE; j++) {
              v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
            }
          }
          accs[i] += dot(logits_vec, v_vec);
        }
      }
    }

    // Perform reduction within each warp: combine the lanes that worked on the
    // same row (different token ranges).
  #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      float acc = accs[i];
  #pragma unroll
      for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
        acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
      }
      accs[i] = acc;
    }

    // NOTE: A barrier is required because the shared memory space for logits
    // is reused for the output.
    __syncthreads();

    // Perform reduction across warps: in each round the upper half of the
    // remaining warps hands its partial result to the lower half.
    float* out_smem = reinterpret_cast<float*>(shared_mem);
  #pragma unroll
    for (int i = NUM_WARPS; i > 1; i /= 2) {
      int mid = i / 2;
      // Upper warps write to shared memory.
      if (warp_idx >= mid && warp_idx < i) {
        float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
  #pragma unroll
        // NOTE: this inner `i` shadows the outer loop variable.
        for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
          const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
          if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
            dst[row_idx] = accs[i];
          }
        }
      }
      __syncthreads();

      // Lower warps update the output.
      if (warp_idx < mid) {
        const float* src = &out_smem[warp_idx * HEAD_SIZE];
  #pragma unroll
        for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
          const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
          if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
            accs[i] += src[row_idx];
          }
        }
      }
      __syncthreads();
    }

    // Write the final output. Warp 0 holds the fully reduced accumulators; one
    // lane per row converts the float result to scalar_t and stores it.
    if (warp_idx == 0) {
      scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                              + head_idx * max_num_partitions * HEAD_SIZE
                              + partition_idx * HEAD_SIZE;
  #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          from_float(*(out_ptr + row_idx), accs[i]);
        }
      }
    }
  }
  // Single-pass paged attention entry point.
  // Grid: (num_heads, num_seqs, 1).
  //
  // Forwards to paged_attention_kernel with partitioning disabled, so the
  // per-partition statistics buffers are not needed and the result goes
  // straight to `out`.
  template<
    typename scalar_t,
    typename cache_t,
    int HEAD_SIZE,
    int BLOCK_SIZE,
    int NUM_THREADS,
    bool IS_FP8_KV_CACHE>
  __global__ void paged_attention_v1_kernel(
    scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads,                 // number of KV heads
    const float scale,
    const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,   // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes, // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    const float kv_scale) {
    // Without partitioning the shared kernel never touches these buffers.
    float* const unused_exp_sums = nullptr;
    float* const unused_max_logits = nullptr;

    paged_attention_kernel<
        scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE,
        /* PARTITION_SIZE */ 0>(
      unused_exp_sums,
      unused_max_logits,
      out,
      q,
      k_cache,
      v_cache,
      num_kv_heads,
      scale,
      block_tables,
      context_lens,
      max_num_blocks_per_seq,
      alibi_slopes,
      q_stride,
      kv_block_stride,
      kv_head_stride,
      kv_scale);
  }
  // Partitioned paged attention entry point (first pass of v2).
  // Grid: (num_heads, num_seqs, max_num_partitions).
  //
  // Each block processes one partition of PARTITION_SIZE tokens and writes its
  // partial output to `tmp_out`, together with the partition's max logit and
  // exp-sum, which paged_attention_v2_reduce_kernel reads to combine partitions.
  template<
    typename scalar_t,
    typename cache_t,
    int HEAD_SIZE,
    int BLOCK_SIZE,
    int NUM_THREADS,
    bool IS_FP8_KV_CACHE,
    int PARTITION_SIZE>
  __global__ void paged_attention_v2_kernel(
    float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
    scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads,                 // number of KV heads
    const float scale,
    const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,   // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes, // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    const float kv_scale) {
    paged_attention_kernel<
        scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE,
        PARTITION_SIZE>(
      // Per-partition softmax statistics and partial outputs.
      exp_sums,
      max_logits,
      tmp_out,
      // Inputs.
      q,
      k_cache,
      v_cache,
      num_kv_heads,
      scale,
      block_tables,
      context_lens,
      max_num_blocks_per_seq,
      alibi_slopes,
      // Strides and FP8 scale.
      q_stride,
      kv_block_stride,
      kv_head_stride,
      kv_scale);
  }
  // Second pass of v2: merges the per-partition partial outputs into `out`.
  // Grid: (num_heads, num_seqs).
  //
  // Dynamic shared memory holds two float arrays of num_partitions entries
  // each: the per-partition max logits followed by the rescaled exp sums.
  template<
    typename scalar_t,
    int HEAD_SIZE,
    int NUM_THREADS,
    int PARTITION_SIZE>
  __global__ void paged_attention_v2_reduce_kernel(
    scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
    const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
    const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
    const int* __restrict__ context_lens,   // [num_seqs]
    const int max_num_partitions) {
    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int num_heads = gridDim.x;
    const int context_len = context_lens[seq_idx];
    const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);

    // Rows of `out` and `tmp_out` that belong to this (sequence, head).
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                          + head_idx * max_num_partitions * HEAD_SIZE;

    // Fast path: a single partition already holds the final result, so it is
    // copied through unchanged and the block exits.
    if (num_partitions == 1) {
      for (int d = threadIdx.x; d < HEAD_SIZE; d += blockDim.x) {
        out_ptr[d] = tmp_out_ptr[d];
      }
      return;
    }

    constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
    const int lane = threadIdx.x % WARP_SIZE;
    const int warp_idx = threadIdx.x / WARP_SIZE;

    // Size: 2 * num_partitions.
    extern __shared__ char shared_mem[];
    // Scratch for the block-wide reductions: first half for the max, second
    // half for the sum.
    __shared__ float red_smem[2 * NUM_WARPS];

    // Step 1: stage the per-partition max logits in shared memory while each
    // thread tracks the largest one it has seen.
    float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
    const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                             + head_idx * max_num_partitions;
    float max_logit = -FLT_MAX;
    for (int p = threadIdx.x; p < num_partitions; p += blockDim.x) {
      const float partition_max = max_logits_ptr[p];
      shared_max_logits[p] = partition_max;
      max_logit = fmaxf(max_logit, partition_max);
    }
    __syncthreads();

    // Step 2: block-wide max. First inside each warp ...
  #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
      max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, offset));
    }
    if (lane == 0) {
      red_smem[warp_idx] = max_logit;
    }
    __syncthreads();
    // ... then across the per-warp results.
    max_logit = -FLT_MAX;
    if (lane < NUM_WARPS) {
      max_logit = red_smem[lane];
    }
  #pragma unroll
    for (int offset = NUM_WARPS / 2; offset > 0; offset /= 2) {
      max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, offset));
    }
    // Every thread receives lane 0's value.
    max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);

    // Step 3: rescale each partition's exp sum to the global max and stage it
    // right behind the max logits in shared memory.
    float* shared_exp_sums = shared_max_logits + num_partitions;
    const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                         + head_idx * max_num_partitions;
    float global_exp_sum = 0.0f;
    for (int p = threadIdx.x; p < num_partitions; p += blockDim.x) {
      const float partition_max = shared_max_logits[p];
      const float rescaled_exp_sum = exp_sums_ptr[p] * expf(partition_max - max_logit);
      shared_exp_sums[p] = rescaled_exp_sum;
      global_exp_sum += rescaled_exp_sum;
    }
    __syncthreads();
    global_exp_sum = block_sum<NUM_WARPS>(red_smem + NUM_WARPS, global_exp_sum);
    // The small epsilon guards against division by zero.
    const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

    // Step 4: weighted sum of the partial outputs, one head element per thread
    // per iteration.
  #pragma unroll
    for (int d = threadIdx.x; d < HEAD_SIZE; d += NUM_THREADS) {
      float acc = 0.0f;
      for (int p = 0; p < num_partitions; ++p) {
        acc += to_float(tmp_out_ptr[p * HEAD_SIZE + d]) * shared_exp_sums[p] * inv_global_exp_sum;
      }
      from_float(out_ptr[d], acc);
    }
  }
  548. } // namespace aphrodite
  // Launches aphrodite::paged_attention_v1_kernel for one compile-time HEAD_SIZE.
  //
  // The macro is not self-contained: it expects the following names to be in
  // scope at the expansion site (paged_attention_v1_launcher provides them):
  //   template parameters: T, CACHE_T, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE
  //   launch configuration: grid, block, shared_mem_size, stream
  //   kernel arguments:     out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,
  //                         num_kv_heads, scale, block_tables_ptr,
  //                         context_lens_ptr, max_num_blocks_per_seq,
  //                         alibi_slopes_ptr, q_stride, kv_block_stride,
  //                         kv_head_stride, kv_scale
  //
  // Before the launch it passes the kernel and shared_mem_size to
  // APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize (defined
  // elsewhere; presumably raises the kernel's dynamic shared-memory limit).
  // NOTE(review): the launch result is not checked here.
  // The expansion consists of two statements and already ends with ';'.
  #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
    APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
      ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
        IS_FP8_KV_CACHE>), shared_mem_size); \
    aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
    IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
      out_ptr, \
      query_ptr, \
      key_cache_ptr, \
      value_cache_ptr, \
      num_kv_heads, \
      scale, \
      block_tables_ptr, \
      context_lens_ptr, \
      max_num_blocks_per_seq, \
      alibi_slopes_ptr, \
      q_stride, \
      kv_block_stride, \
      kv_head_stride, \
      kv_scale);
  // TODO: Tune NUM_THREADS.
  // Host-side launcher for paged_attention_v1_kernel.
  //
  // Extracts sizes, strides and raw pointers from the tensors, sizes the dynamic
  // shared memory, and launches the kernel instantiation that matches the
  // runtime head size on the current CUDA stream of the query's device.
  //
  // Template parameters:
  //   T               - element type of query / out.
  //   CACHE_T         - element type stored in the key / value caches.
  //   BLOCK_SIZE      - tokens per KV-cache block.
  //   IS_FP8_KV_CACHE - whether the caches hold FP8 data.
  //   NUM_THREADS     - threads per block.
  //
  // Arguments:
  //   out, query        - indexed as [num_seqs, num_heads, head_size].
  //   key_cache         - strides 0 and 1 are used as block / head strides for
  //                       both caches.
  //   block_tables      - int tensor, [num_seqs, max_num_blocks_per_seq].
  //   context_lens      - int tensor, [num_seqs].
  //   max_context_len   - used to size the shared-memory logits buffer.
  //   alibi_slopes      - optional; when absent a null pointer is passed on.
  //
  // Raises (via TORCH_CHECK) if head_size is not divisible by the thread group
  // size or is not one of the compiled head sizes.
  //
  // NOTE: the local variable names below are referenced by
  // LAUNCH_PAGED_ATTENTION_V1 and must not be renamed independently of it.
  template<
    typename T,
    typename CACHE_T,
    int BLOCK_SIZE,
    bool IS_FP8_KV_CACHE,
    int NUM_THREADS = 128>
  void paged_attention_v1_launcher(
    torch::Tensor& out,
    torch::Tensor& query,
    torch::Tensor& key_cache,
    torch::Tensor& value_cache,
    int num_kv_heads,
    float scale,
    torch::Tensor& block_tables,
    torch::Tensor& context_lens,
    int max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    float kv_scale) {
    int num_seqs = query.size(0);
    int num_heads = query.size(1);
    int head_size = query.size(2);
    int max_num_blocks_per_seq = block_tables.size(1);
    int q_stride = query.stride(0);
    int kv_block_stride = key_cache.stride(0);
    int kv_head_stride = key_cache.stride(1);

    // The kernel splits each head evenly across the threads of a thread group.
    // Use TORCH_CHECK rather than assert so the check is not compiled out of
    // release (NDEBUG) builds.
    int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
    TORCH_CHECK(head_size % thread_group_size == 0,
                "head_size (", head_size, ") must be divisible by the thread group size (",
                thread_group_size, ")");

    // NOTE: alibi_slopes is optional.
    const float* alibi_slopes_ptr = alibi_slopes ?
      reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
      : nullptr;

    T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
    T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
    CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
    CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
    int* block_tables_ptr = block_tables.data_ptr<int>();
    int* context_lens_ptr = context_lens.data_ptr<int>();

    // Dynamic shared memory is used first for the logits (one float per token,
    // padded to a whole number of blocks) and later for the cross-warp output
    // reduction, so it must be large enough for whichever is bigger.
    constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
    int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
    int logits_size = padded_max_context_len * sizeof(float);
    int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
    // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
    // Keep that in sync with the logic here!
    int shared_mem_size = std::max(logits_size, outputs_size);

    // One thread block per (head, sequence); no partitioning in v1.
    dim3 grid(num_heads, num_seqs, 1);
    dim3 block(NUM_THREADS);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    switch (head_size) {
      // NOTE: To reduce the compilation time, we only compile for the
      // head sizes that we use in the model. However, we can easily extend this
      // to support any head size which is a multiple of 16.
      case 64:
        LAUNCH_PAGED_ATTENTION_V1(64);
        break;
      case 80:
        LAUNCH_PAGED_ATTENTION_V1(80);
        break;
      case 96:
        LAUNCH_PAGED_ATTENTION_V1(96);
        break;
      case 112:
        LAUNCH_PAGED_ATTENTION_V1(112);
        break;
      case 128:
        LAUNCH_PAGED_ATTENTION_V1(128);
        break;
      case 256:
        LAUNCH_PAGED_ATTENTION_V1(256);
        break;
      default:
        TORCH_CHECK(false, "Unsupported head size: ", head_size);
        break;
    }
  }
  // Calls paged_attention_v1_launcher with the given template arguments
  // (NUM_THREADS keeps its default). The call arguments are taken by name from
  // the expansion site, so out, query, key_cache, value_cache, num_kv_heads,
  // scale, block_tables, context_lens, max_context_len, alibi_slopes and
  // kv_scale must all be in scope there. The expansion already ends with ';'.
  #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
    paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
      out, \
      query, \
      key_cache, \
      value_cache, \
      num_kv_heads, \
      scale, \
      block_tables, \
      context_lens, \
      max_context_len, \
      alibi_slopes, \
      kv_scale);

  // Maps the runtime `block_size` (must be in scope at the expansion site) to
  // the matching compile-time BLOCK_SIZE instantiation; any other value fails
  // with TORCH_CHECK.
  // NOTE: To reduce the compilation time, we omitted block sizes
  // 1, 2, 4, 64, 128, 256.
  #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
    switch (block_size) { \
      case 8: \
        CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
        break; \
      case 16: \
        CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
        break; \
      case 32: \
        CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
        break; \
      default: \
        TORCH_CHECK(false, "Unsupported block size: ", block_size); \
        break; \
    }
  // Public entry point for single-pass paged attention.
  //
  // Selects the launcher instantiation from three runtime properties:
  //   - kv_cache_dtype: "auto" (cache stored in the query's element type) or
  //     "fp8" (cache stored as uint8_t and converted inside the kernel);
  //   - the query dtype: Float, Half or BFloat16;
  //   - block_size: 8, 16 or 32.
  // Any other combination fails with TORCH_CHECK.
  void paged_attention_v1(
    torch::Tensor& out,             // [num_seqs, num_heads, head_size]
    torch::Tensor& query,           // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
    int num_kv_heads,               // number of KV heads
    float scale,
    torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& context_lens,    // [num_seqs]
    int block_size,
    int max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype,
    float kv_scale) {
    // Reject unknown cache formats before looking at the query dtype.
    const bool cache_is_fp8 = (kv_cache_dtype == "fp8");
    const bool cache_is_auto = (kv_cache_dtype == "auto");
    TORCH_CHECK(cache_is_auto || cache_is_fp8,
                "Unsupported data type of kv cache: ", kv_cache_dtype);

    const auto query_dtype = query.dtype();
    if (query_dtype == at::ScalarType::Float) {
      if (cache_is_fp8) {
        CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
      } else {
        CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
      }
    } else if (query_dtype == at::ScalarType::Half) {
      if (cache_is_fp8) {
        CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
      } else {
        CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
      }
    } else if (query_dtype == at::ScalarType::BFloat16) {
      if (cache_is_fp8) {
        CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
      } else {
        CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
      }
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query_dtype);
    }
  }
  713. #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
  714. aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
  715. IS_FP8_KV_CACHE, PARTITION_SIZE> \
  716. <<<grid, block, shared_mem_size, stream>>>( \
  717. exp_sums_ptr, \
  718. max_logits_ptr, \
  719. tmp_out_ptr, \
  720. query_ptr, \
  721. key_cache_ptr, \
  722. value_cache_ptr, \
  723. num_kv_heads, \
  724. scale, \
  725. block_tables_ptr, \
  726. context_lens_ptr, \
  727. max_num_blocks_per_seq, \
  728. alibi_slopes_ptr, \
  729. q_stride, \
  730. kv_block_stride, \
  731. kv_head_stride, \
  732. kv_scale); \
  733. aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
  734. <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
  735. out_ptr, \
  736. exp_sums_ptr, \
  737. max_logits_ptr, \
  738. tmp_out_ptr, \
  739. context_lens_ptr, \
  740. max_num_partitions);
  741. template<
  742. typename T,
  743. typename CACHE_T,
  744. int BLOCK_SIZE,
  745. bool IS_FP8_KV_CACHE,
  746. int NUM_THREADS = 128,
  747. int PARTITION_SIZE = 512>
  748. void paged_attention_v2_launcher(
  749. torch::Tensor& out,
  750. torch::Tensor& exp_sums,
  751. torch::Tensor& max_logits,
  752. torch::Tensor& tmp_out,
  753. torch::Tensor& query,
  754. torch::Tensor& key_cache,
  755. torch::Tensor& value_cache,
  756. int num_kv_heads,
  757. float scale,
  758. torch::Tensor& block_tables,
  759. torch::Tensor& context_lens,
  760. int max_context_len,
  761. const c10::optional<torch::Tensor>& alibi_slopes,
  762. float kv_scale) {
  763. int num_seqs = query.size(0);
  764. int num_heads = query.size(1);
  765. int head_size = query.size(2);
  766. int max_num_blocks_per_seq = block_tables.size(1);
  767. int q_stride = query.stride(0);
  768. int kv_block_stride = key_cache.stride(0);
  769. int kv_head_stride = key_cache.stride(1);
  770. int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  771. assert(head_size % thread_group_size == 0);
  772. // NOTE: alibi_slopes is optional.
  773. const float* alibi_slopes_ptr = alibi_slopes ?
  774. reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
  775. : nullptr;
  776. T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  777. float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  778. float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  779. T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  780. T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  781. CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  782. CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  783. int* block_tables_ptr = block_tables.data_ptr<int>();
  784. int* context_lens_ptr = context_lens.data_ptr<int>();
  785. constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  786. int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  787. int logits_size = PARTITION_SIZE * sizeof(float);
  788. int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  789. // For paged attention v2 kernel.
  790. dim3 grid(num_heads, num_seqs, max_num_partitions);
  791. int shared_mem_size = std::max(logits_size, outputs_size);
  792. // For paged attention v2 reduce kernel.
  793. dim3 reduce_grid(num_heads, num_seqs);
  794. int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
  795. dim3 block(NUM_THREADS);
  796. const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  797. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  798. switch (head_size) {
  799. // NOTE: To reduce the compilation time, we only compile for the
  800. // head sizes that we use in the model. However, we can easily extend this
  801. // to support any head size which is a multiple of 16.
  802. case 64:
  803. LAUNCH_PAGED_ATTENTION_V2(64);
  804. break;
  805. case 80:
  806. LAUNCH_PAGED_ATTENTION_V2(80);
  807. break;
  808. case 96:
  809. LAUNCH_PAGED_ATTENTION_V2(96);
  810. break;
  811. case 112:
  812. LAUNCH_PAGED_ATTENTION_V2(112);
  813. break;
  814. case 128:
  815. LAUNCH_PAGED_ATTENTION_V2(128);
  816. break;
  817. case 256:
  818. LAUNCH_PAGED_ATTENTION_V2(256);
  819. break;
  820. default:
  821. TORCH_CHECK(false, "Unsupported head size: ", head_size);
  822. break;
  823. }
  824. }
  825. #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
  826. paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
  827. out, \
  828. exp_sums, \
  829. max_logits, \
  830. tmp_out, \
  831. query, \
  832. key_cache, \
  833. value_cache, \
  834. num_kv_heads, \
  835. scale, \
  836. block_tables, \
  837. context_lens, \
  838. max_context_len, \
  839. alibi_slopes, \
  840. kv_scale);
  841. // NOTE: To reduce the compilation time, we omitted block sizes
  842. // 1, 2, 4, 64, 128, 256.
  843. #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
  844. switch (block_size) { \
  845. case 8: \
  846. CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
  847. break; \
  848. case 16: \
  849. CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
  850. break; \
  851. case 32: \
  852. CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
  853. break; \
  854. default: \
  855. TORCH_CHECK(false, "Unsupported block size: ", block_size); \
  856. break; \
  857. }
  858. void paged_attention_v2(
  859. torch::Tensor& out, // [num_seqs, num_heads, head_size]
  860. torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
  861. torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
  862. torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
  863. torch::Tensor& query, // [num_seqs, num_heads, head_size]
  864. torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
  865. torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
  866. int num_kv_heads, // [num_heads]
  867. float scale,
  868. torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
  869. torch::Tensor& context_lens, // [num_seqs]
  870. int block_size,
  871. int max_context_len,
  872. const c10::optional<torch::Tensor>& alibi_slopes,
  873. const std::string& kv_cache_dtype,
  874. float kv_scale) {
  875. if (kv_cache_dtype == "auto") {
  876. if (query.dtype() == at::ScalarType::Float) {
  877. CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
  878. } else if (query.dtype() == at::ScalarType::Half) {
  879. CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
  880. } else if (query.dtype() == at::ScalarType::BFloat16) {
  881. CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
  882. } else {
  883. TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  884. }
  885. } else if (kv_cache_dtype == "fp8") {
  886. if (query.dtype() == at::ScalarType::Float) {
  887. CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
  888. } else if (query.dtype() == at::ScalarType::Half) {
  889. CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
  890. } else if (query.dtype() == at::ScalarType::BFloat16) {
  891. CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
  892. } else {
  893. TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  894. }
  895. } else {
  896. TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  897. }
  898. }
  899. #undef WARP_SIZE
  900. #undef MAX
  901. #undef MIN
  902. #undef DIVIDE_ROUND_UP