// attention_kernels.cu
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/kvcache/quant_utils.cuh"

#include <algorithm>
  29. #ifndef USE_ROCM
  30. #define WARP_SIZE 32
  31. #else
  32. #define WARP_SIZE warpSize
  33. #endif
  34. #define MAX(a, b) ((a) > (b) ? (a) : (b))
  35. #define MIN(a, b) ((a) < (b) ? (a) : (b))
  36. #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
  37. namespace aphrodite {
  38. // Utility function for attention softmax.
  39. template<int NUM_WARPS>
  40. inline __device__ float block_sum(float* red_smem, float sum) {
  41. // Decompose the thread index into warp / lane.
  42. int warp = threadIdx.x / WARP_SIZE;
  43. int lane = threadIdx.x % WARP_SIZE;
  44. // Compute the sum per warp.
  45. #pragma unroll
  46. for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
  47. sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
  48. }
  49. // Warp leaders store the data to shared memory.
  50. if (lane == 0) {
  51. red_smem[warp] = sum;
  52. }
  53. // Make sure the data is in shared memory.
  54. __syncthreads();
  55. // The warps compute the final sums.
  56. if (lane < NUM_WARPS) {
  57. sum = red_smem[lane];
  58. }
  59. // Parallel reduction inside the warp.
  60. #pragma unroll
  61. for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
  62. sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
  63. }
  64. // Broadcast to other threads.
  65. return APHRODITE_SHFL_SYNC(sum, 0);
  66. }
  67. // TODO: Merge the last two dimensions of the grid.
  68. // Grid: (num_heads, num_seqs, max_num_partitions).
  69. template<
  70. typename scalar_t,
  71. typename cache_t,
  72. int HEAD_SIZE,
  73. int BLOCK_SIZE,
  74. int NUM_THREADS,
  75. bool ENABLE_FP8_KV_CACHE,
  76. int PARTITION_SIZE = 0> // Zero means no partitioning.
  77. __device__ void paged_attention_kernel(
  78. float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
  79. float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
  80. scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
  81. const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
  82. const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  83. const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
  84. const int num_kv_heads, // [num_heads]
  85. const float scale,
  86. const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
  87. const int* __restrict__ context_lens, // [num_seqs]
  88. const int max_num_blocks_per_seq,
  89. const float* __restrict__ alibi_slopes, // [num_heads]
  90. const int q_stride,
  91. const int kv_block_stride,
  92. const int kv_head_stride) {
  93. const int seq_idx = blockIdx.y;
  94. const int partition_idx = blockIdx.z;
  95. const int max_num_partitions = gridDim.z;
  96. constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  97. const int context_len = context_lens[seq_idx];
  98. if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
  99. // No work to do. Terminate the thread block.
  100. return;
  101. }
  102. const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  103. const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
  104. // [start_block_idx, end_block_idx) is the range of blocks to process.
  105. const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  106. const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
  107. const int num_blocks = end_block_idx - start_block_idx;
  108. // [start_token_idx, end_token_idx) is the range of tokens to process.
  109. const int start_token_idx = start_block_idx * BLOCK_SIZE;
  110. const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
  111. const int num_tokens = end_token_idx - start_token_idx;
  112. constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  113. constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
  114. assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  115. constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  116. constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  117. const int thread_idx = threadIdx.x;
  118. const int warp_idx = thread_idx / WARP_SIZE;
  119. const int lane = thread_idx % WARP_SIZE;
  120. const int head_idx = blockIdx.x;
  121. const int num_heads = gridDim.x;
  122. const int num_queries_per_kv = num_heads / num_kv_heads;
  123. const int kv_head_idx = head_idx / num_queries_per_kv;
  124. const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
  125. // A vector type to store a part of a key or a query.
  126. // The vector size is configured in such a way that the threads in a thread group
  127. // fetch or compute 16 bytes at a time.
  128. // For example, if the size of a thread group is 4 and the data type is half,
  129. // then the vector size is 16 / (4 * sizeof(half)) == 2.
  130. constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  131. using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  132. using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  133. using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
  134. constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  135. constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
  136. const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  137. const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
  138. // Load the query to registers.
  139. // Each thread in a thread group has a different part of the query.
  140. // For example, if the the thread group size is 4, then the first thread in the group
  141. // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
  142. // th vectors of the query, and so on.
  143. // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
  144. const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  145. __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
  146. #pragma unroll
  147. for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
  148. const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
  149. q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  150. }
  151. __syncthreads(); // TODO: possible speedup if this is replaced with a memory wall right before we use q_vecs
  152. // Memory planning.
  153. extern __shared__ char shared_mem[];
  154. // NOTE: We use FP32 for the softmax logits for better accuracy.
  155. float* logits = reinterpret_cast<float*>(shared_mem);
  156. // Workspace for reduction.
  157. __shared__ float red_smem[2 * NUM_WARPS];
  158. // x == THREAD_GROUP_SIZE * VEC_SIZE
  159. // Each thread group fetches x elements from the key at a time.
  160. constexpr int x = 16 / sizeof(cache_t);
  161. float qk_max = -FLT_MAX;
  162. // Iterate over the key blocks.
  163. // Each warp fetches a block of keys for each iteration.
  164. // Each thread group in a warp fetches a key from the block, and computes
  165. // dot product with the query.
  166. const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  167. for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
  168. // NOTE: The block number is stored in int32. However, we cast it to int64
  169. // because int32 can lead to overflow when this variable is multiplied by large numbers
  170. // (e.g., kv_block_stride).
  171. const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
  172. // Load a key to registers.
  173. // Each thread in a thread group has a different part of the key.
  174. // For example, if the the thread group size is 4, then the first thread in the group
  175. // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
  176. // vectors of the key, and so on.
  177. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
  178. const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
  179. const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
  180. K_vec k_vecs[NUM_VECS_PER_THREAD];
  181. #pragma unroll
  182. for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
  183. const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
  184. + kv_head_idx * kv_head_stride
  185. + physical_block_offset * x;
  186. const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
  187. const int offset1 = (vec_idx * VEC_SIZE) / x;
  188. const int offset2 = (vec_idx * VEC_SIZE) % x;
  189. if constexpr (ENABLE_FP8_KV_CACHE) {
  190. Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
  191. // Vector conversion from Quant_vec to K_vec.
  192. k_vecs[j] = vec_conversion<K_vec, Quant_vec>(k_vec_quant);
  193. } else {
  194. k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
  195. }
  196. }
  197. // Compute dot product.
  198. // This includes a reduction across the threads in the same thread group.
  199. float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
  200. // Add the ALiBi bias if slopes are given.
  201. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
  202. if (thread_group_offset == 0) {
  203. // Store the partial reductions to shared memory.
  204. // NOTE: It is required to zero out the masked logits.
  205. const bool mask = token_idx >= context_len;
  206. logits[token_idx - start_token_idx] = mask ? 0.f : qk;
  207. // Update the max value.
  208. qk_max = mask ? qk_max : fmaxf(qk_max, qk);
  209. }
  210. }
  211. }
  212. // Perform reduction across the threads in the same warp to get the
  213. // max qk value for each "warp" (not across the thread block yet).
  214. // The 0-th thread of each thread group already has its max qk value.
  215. #pragma unroll
  216. for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
  217. qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
  218. }
  219. if (lane == 0) {
  220. red_smem[warp_idx] = qk_max;
  221. }
  222. __syncthreads();
  223. // TODO: Refactor this part.
  224. // Get the max qk value for the sequence.
  225. qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
  226. #pragma unroll
  227. for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
  228. qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
  229. }
  230. // Broadcast the max qk value to all threads.
  231. qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
  232. // Get the sum of the exp values.
  233. float exp_sum = 0.f;
  234. for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
  235. float val = __expf(logits[i] - qk_max);
  236. logits[i] = val;
  237. exp_sum += val;
  238. }
  239. exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
  240. // Compute softmax.
  241. const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  242. for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
  243. logits[i] *= inv_sum;
  244. }
  245. __syncthreads();
  246. // If partitioning is enabled, store the max logit and exp_sum.
  247. if (USE_PARTITIONING && thread_idx == 0) {
  248. float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
  249. + head_idx * max_num_partitions
  250. + partition_idx;
  251. *max_logits_ptr = qk_max;
  252. float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
  253. + head_idx * max_num_partitions
  254. + partition_idx;
  255. *exp_sums_ptr = exp_sum;
  256. }
  257. // Each thread will fetch 16 bytes from the value cache at a time.
  258. constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  259. using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  260. using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  261. using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
  262. using Float_L_vec = typename FloatVec<L_vec>::Type;
  263. constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  264. constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  265. constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
  266. // NOTE: We use FP32 for the accumulator for better accuracy.
  267. float accs[NUM_ROWS_PER_THREAD];
  268. #pragma unroll
  269. for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  270. accs[i] = 0.f;
  271. }
  272. scalar_t zero_value;
  273. zero(zero_value);
  274. for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
  275. // NOTE: The block number is stored in int32. However, we cast it to int64
  276. // because int32 can lead to overflow when this variable is multiplied by large numbers
  277. // (e.g., kv_block_stride).
  278. const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
  279. const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
  280. const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
  281. L_vec logits_vec;
  282. from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
  283. const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
  284. + kv_head_idx * kv_head_stride;
  285. #pragma unroll
  286. for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  287. const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
  288. if (row_idx < HEAD_SIZE) {
  289. const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
  290. V_vec v_vec;
  291. if constexpr (ENABLE_FP8_KV_CACHE) {
  292. V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
  293. // Vector conversion from V_quant_vec to V_vec.
  294. v_vec = vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
  295. } else {
  296. v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
  297. }
  298. if (block_idx == num_context_blocks - 1) {
  299. // NOTE: When v_vec contains the tokens that are out of the context,
  300. // we should explicitly zero out the values since they may contain NaNs.
  301. scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
  302. #pragma unroll
  303. for (int j = 0; j < V_VEC_SIZE; j++) {
  304. v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
  305. }
  306. }
  307. accs[i] += dot(logits_vec, v_vec);
  308. }
  309. }
  310. }
  311. // Perform reduction within each warp.
  312. #pragma unroll
  313. for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  314. float acc = accs[i];
  315. #pragma unroll
  316. for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
  317. acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
  318. }
  319. accs[i] = acc;
  320. }
  321. // NOTE: A barrier is required because the shared memory space for logits
  322. // is reused for the output.
  323. __syncthreads();
  324. // Perform reduction across warps.
  325. float* out_smem = reinterpret_cast<float*>(shared_mem);
  326. #pragma unroll
  327. for (int i = NUM_WARPS; i > 1; i /= 2) {
  328. int mid = i / 2;
  329. // Upper warps write to shared memory.
  330. if (warp_idx >= mid && warp_idx < i) {
  331. float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
  332. #pragma unroll
  333. for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  334. const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
  335. if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
  336. dst[row_idx] = accs[i];
  337. }
  338. }
  339. }
  340. __syncthreads();
  341. // Lower warps update the output.
  342. if (warp_idx < mid) {
  343. const float* src = &out_smem[warp_idx * HEAD_SIZE];
  344. #pragma unroll
  345. for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  346. const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
  347. if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
  348. accs[i] += src[row_idx];
  349. }
  350. }
  351. }
  352. __syncthreads();
  353. }
  354. // Write the final output.
  355. if (warp_idx == 0) {
  356. scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
  357. + head_idx * max_num_partitions * HEAD_SIZE
  358. + partition_idx * HEAD_SIZE;
  359. #pragma unroll
  360. for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  361. const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
  362. if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
  363. from_float(*(out_ptr + row_idx), accs[i]);
  364. }
  365. }
  366. }
  367. }
  368. // Grid: (num_heads, num_seqs, 1).
  369. template<
  370. typename scalar_t,
  371. typename cache_t,
  372. int HEAD_SIZE,
  373. int BLOCK_SIZE,
  374. int NUM_THREADS,
  375. bool ENABLE_FP8_KV_CACHE>
  376. __global__ void paged_attention_v1_kernel(
  377. scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
  378. const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
  379. const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  380. const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
  381. const int num_kv_heads, // [num_heads]
  382. const float scale,
  383. const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
  384. const int* __restrict__ context_lens, // [num_seqs]
  385. const int max_num_blocks_per_seq,
  386. const float* __restrict__ alibi_slopes, // [num_heads]
  387. const int q_stride,
  388. const int kv_block_stride,
  389. const int kv_head_stride) {
  390. paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, ENABLE_FP8_KV_CACHE>(
  391. /* exp_sums */ nullptr, /* max_logits */ nullptr,
  392. out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
  393. max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
  394. }
  395. // Grid: (num_heads, num_seqs, max_num_partitions).
  396. template<
  397. typename scalar_t,
  398. typename cache_t,
  399. int HEAD_SIZE,
  400. int BLOCK_SIZE,
  401. int NUM_THREADS,
  402. bool ENABLE_FP8_KV_CACHE,
  403. int PARTITION_SIZE>
  404. __global__ void paged_attention_v2_kernel(
  405. float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
  406. float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
  407. scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
  408. const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
  409. const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  410. const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
  411. const int num_kv_heads, // [num_heads]
  412. const float scale,
  413. const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
  414. const int* __restrict__ context_lens, // [num_seqs]
  415. const int max_num_blocks_per_seq,
  416. const float* __restrict__ alibi_slopes, // [num_heads]
  417. const int q_stride,
  418. const int kv_block_stride,
  419. const int kv_head_stride) {
  420. paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, ENABLE_FP8_KV_CACHE, PARTITION_SIZE>(
  421. exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
  422. block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
  423. q_stride, kv_block_stride, kv_head_stride);
  424. }
  425. // Grid: (num_heads, num_seqs).
  426. template<
  427. typename scalar_t,
  428. int HEAD_SIZE,
  429. int NUM_THREADS,
  430. int PARTITION_SIZE>
  431. __global__ void paged_attention_v2_reduce_kernel(
  432. scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
  433. const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
  434. const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
  435. const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
  436. const int* __restrict__ context_lens, // [num_seqs]
  437. const int max_num_partitions) {
  438. const int num_heads = gridDim.x;
  439. const int head_idx = blockIdx.x;
  440. const int seq_idx = blockIdx.y;
  441. const int context_len = context_lens[seq_idx];
  442. const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  443. if (num_partitions == 1) {
  444. // No need to reduce. Only copy tmp_out to out.
  445. scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
  446. const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
  447. + head_idx * max_num_partitions * HEAD_SIZE;
  448. for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
  449. out_ptr[i] = tmp_out_ptr[i];
  450. }
  451. // Terminate the thread block.
  452. return;
  453. }
  454. constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  455. const int warp_idx = threadIdx.x / WARP_SIZE;
  456. const int lane = threadIdx.x % WARP_SIZE;
  457. // Size: 2 * num_partitions.
  458. extern __shared__ char shared_mem[];
  459. // Workspace for reduction.
  460. __shared__ float red_smem[2 * NUM_WARPS];
  461. // Load max logits to shared memory.
  462. float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  463. const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
  464. + head_idx * max_num_partitions;
  465. float max_logit = -FLT_MAX;
  466. for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
  467. const float l = max_logits_ptr[i];
  468. shared_max_logits[i] = l;
  469. max_logit = fmaxf(max_logit, l);
  470. }
  471. __syncthreads();
  472. // Get the global max logit.
  473. // Reduce within the warp.
  474. #pragma unroll
  475. for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
  476. max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
  477. }
  478. if (lane == 0) {
  479. red_smem[warp_idx] = max_logit;
  480. }
  481. __syncthreads();
  482. // Reduce across warps.
  483. max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
  484. #pragma unroll
  485. for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
  486. max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
  487. }
  488. // Broadcast the max value to all threads.
  489. max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
  490. // Load rescaled exp sums to shared memory.
  491. float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  492. const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
  493. + head_idx * max_num_partitions;
  494. float global_exp_sum = 0.0f;
  495. for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
  496. float l = shared_max_logits[i];
  497. float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
  498. global_exp_sum += rescaled_exp_sum;
  499. shared_exp_sums[i] = rescaled_exp_sum;
  500. }
  501. __syncthreads();
  502. global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  503. const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
  504. // Aggregate tmp_out to out.
  505. const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
  506. + head_idx * max_num_partitions * HEAD_SIZE;
  507. scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
  508. #pragma unroll
  509. for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
  510. float acc = 0.0f;
  511. for (int j = 0; j < num_partitions; ++j) {
  512. acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
  513. }
  514. from_float(out_ptr[i], acc);
  515. }
  516. }
  517. } // namespace aphrodite
  518. #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
  519. APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
  520. ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
  521. ENABLE_FP8_KV_CACHE>), shared_mem_size); \
  522. aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
  523. ENABLE_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
  524. out_ptr, \
  525. query_ptr, \
  526. key_cache_ptr, \
  527. value_cache_ptr, \
  528. num_kv_heads, \
  529. scale, \
  530. block_tables_ptr, \
  531. context_lens_ptr, \
  532. max_num_blocks_per_seq, \
  533. alibi_slopes_ptr, \
  534. q_stride, \
  535. kv_block_stride, \
  536. kv_head_stride);
  537. // TODO: Tune NUM_THREADS.
  538. template<
  539. typename T,
  540. typename CACHE_T,
  541. int BLOCK_SIZE,
  542. bool ENABLE_FP8_KV_CACHE,
  543. int NUM_THREADS = 128>
  544. void paged_attention_v1_launcher(
  545. torch::Tensor& out,
  546. torch::Tensor& query,
  547. torch::Tensor& key_cache,
  548. torch::Tensor& value_cache,
  549. int num_kv_heads,
  550. float scale,
  551. torch::Tensor& block_tables,
  552. torch::Tensor& context_lens,
  553. int max_context_len,
  554. const c10::optional<torch::Tensor>& alibi_slopes) {
  555. int num_seqs = query.size(0);
  556. int num_heads = query.size(1);
  557. int head_size = query.size(2);
  558. int max_num_blocks_per_seq = block_tables.size(1);
  559. int q_stride = query.stride(0);
  560. int kv_block_stride = key_cache.stride(0);
  561. int kv_head_stride = key_cache.stride(1);
  562. int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  563. assert(head_size % thread_group_size == 0);
  564. // NOTE: alibi_slopes is optional.
  565. const float* alibi_slopes_ptr = alibi_slopes ?
  566. reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
  567. : nullptr;
  568. T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  569. T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  570. CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  571. CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  572. int* block_tables_ptr = block_tables.data_ptr<int>();
  573. int* context_lens_ptr = context_lens.data_ptr<int>();
  574. constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  575. int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  576. int logits_size = padded_max_context_len * sizeof(float);
  577. int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  578. // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
  579. // Keep that in sync with the logic here!
  580. int shared_mem_size = std::max(logits_size, outputs_size);
  581. dim3 grid(num_heads, num_seqs, 1);
  582. dim3 block(NUM_THREADS);
  583. const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  584. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  585. switch (head_size) {
  586. // NOTE: To reduce the compilation time, we only compile for the
  587. // head sizes that we use in the model. However, we can easily extend this
  588. // to support any head size which is a multiple of 16.
  589. case 64:
  590. LAUNCH_PAGED_ATTENTION_V1(64);
  591. break;
  592. case 80:
  593. LAUNCH_PAGED_ATTENTION_V1(80);
  594. break;
  595. case 96:
  596. LAUNCH_PAGED_ATTENTION_V1(96);
  597. break;
  598. case 112:
  599. LAUNCH_PAGED_ATTENTION_V1(112);
  600. break;
  601. case 128:
  602. LAUNCH_PAGED_ATTENTION_V1(128);
  603. break;
  604. case 256:
  605. LAUNCH_PAGED_ATTENTION_V1(256);
  606. break;
  607. default:
  608. TORCH_CHECK(false, "Unsupported head size: ", head_size);
  609. break;
  610. }
  611. }
  // Calls one instantiation of paged_attention_v1_launcher, chosen by the query
  // element type (T), cache element type (CACHE_T), KV-cache block size and
  // FP8-cache flag. It forwards variables from the enclosing scope, so it must
  // expand where out, query, key_cache, value_cache, num_kv_heads, scale,
  // block_tables, context_lens, max_context_len and alibi_slopes are visible
  // (i.e. inside paged_attention_v1).
  #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, ENABLE_FP8_KV_CACHE)       \
    paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, ENABLE_FP8_KV_CACHE>( \
        out, query, key_cache, value_cache, num_kv_heads, scale,             \
        block_tables, context_lens, max_context_len, alibi_slopes);
  // Converts the runtime `block_size` (taken from the enclosing scope) into the
  // compile-time BLOCK_SIZE template argument expected by CALL_V1_LAUNCHER.
  // Only 8, 16 and 32 are instantiated, to keep compile times down. Block sizes
  // 1, 2, 4, 64, 128 and 256 are deliberately left out, and any other value
  // fails through TORCH_CHECK.
  #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, ENABLE_FP8_KV_CACHE)            \
    switch (block_size) {                                                         \
      case 8:  CALL_V1_LAUNCHER(T, CACHE_T, 8, ENABLE_FP8_KV_CACHE);  break;      \
      case 16: CALL_V1_LAUNCHER(T, CACHE_T, 16, ENABLE_FP8_KV_CACHE); break;      \
      case 32: CALL_V1_LAUNCHER(T, CACHE_T, 32, ENABLE_FP8_KV_CACHE); break;      \
      default:                                                                    \
        TORCH_CHECK(false, "Unsupported block size: ", block_size);               \
        break;                                                                    \
    }
  // Entry point for single-pass ("v1") paged attention.
  //
  // Turns runtime arguments into compile-time template arguments:
  //   * the element type of `query` (float32, float16 or bfloat16) selects T;
  //     float16 data reaches the kernels as raw uint16_t;
  //   * `enable_fp8_kv_cache` selects the cache element type: uint8_t when
  //     true, otherwise the same type as T;
  //   * `block_size` selects BLOCK_SIZE (8, 16 or 32).
  // Unsupported dtypes or block sizes are rejected through TORCH_CHECK.
  void paged_attention_v1(
    torch::Tensor& out,           // [num_seqs, num_heads, head_size]
    torch::Tensor& query,         // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,     // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
    int num_kv_heads,             // number of key/value heads (a scalar, not a tensor)
    float scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& context_lens,  // [num_seqs]
    int block_size,
    int max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const bool enable_fp8_kv_cache) {
    switch (query.scalar_type()) {
      case at::ScalarType::Float:
        if (enable_fp8_kv_cache) {
          CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
        } else {
          CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
        }
        break;
      case at::ScalarType::Half:
        if (enable_fp8_kv_cache) {
          CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
        } else {
          CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
        }
        break;
      case at::ScalarType::BFloat16:
        if (enable_fp8_kv_cache) {
          CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
        } else {
          CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
        }
        break;
      default:
        TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  }
  // Enqueues both phases of paged attention v2 for a compile-time HEAD_SIZE:
  // first paged_attention_v2_kernel on `grid`, then
  // paged_attention_v2_reduce_kernel on `reduce_grid`. Both use `stream`, so
  // the reduce runs after the first kernel. This macro expands inside
  // paged_attention_v2_launcher and reads its template parameters (T, CACHE_T,
  // BLOCK_SIZE, NUM_THREADS, ENABLE_FP8_KV_CACHE, PARTITION_SIZE) and locals
  // (grid, reduce_grid, block, shared_mem_size, reduce_shared_mem_size,
  // stream, the *_ptr variables, strides, max_num_blocks_per_seq,
  // max_num_partitions, num_kv_heads, scale).
  #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                  \
    aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,     \
                                         NUM_THREADS, ENABLE_FP8_KV_CACHE,      \
                                         PARTITION_SIZE>                        \
        <<<grid, block, shared_mem_size, stream>>>(                             \
            exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr,               \
            key_cache_ptr, value_cache_ptr, num_kv_heads, scale,                \
            block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,         \
            alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride);       \
    aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,      \
                                                PARTITION_SIZE>                 \
        <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(               \
            out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,                 \
            context_lens_ptr, max_num_partitions);
  // Host-side launcher for two-phase ("v2") paged attention.
  //
  // Phase 1: paged_attention_v2_kernel runs one block of NUM_THREADS threads
  //   per (head, sequence, partition), i.e. grid
  //   (num_heads, num_seqs, ceil(max_context_len / PARTITION_SIZE)). It is
  //   passed exp_sums, max_logits and tmp_out, presumably as per-partition
  //   outputs.
  // Phase 2: paged_attention_v2_reduce_kernel runs one block per
  //   (head, sequence). It reads the phase-1 buffers and writes `out`.
  // Both kernels are enqueued on the current PyTorch CUDA stream for query's
  // device, so phase 2 is ordered after phase 1.
  //
  // Template parameters:
  //   T                   element type of query / out / tmp_out seen by the kernels.
  //   CACHE_T             element type of key_cache / value_cache.
  //   BLOCK_SIZE          tokens per KV-cache block.
  //   ENABLE_FP8_KV_CACHE forwarded to paged_attention_v2_kernel.
  //   NUM_THREADS         threads per block for both kernels.
  //   PARTITION_SIZE      context tokens handled per phase-1 partition.
  //
  // Expected buffers:
  //   exp_sums, max_logits: float32, [num_seqs, num_heads, max_num_partitions]
  //   tmp_out:              [num_seqs, num_heads, max_num_partitions, head_size]
  //
  // An empty batch (num_seqs == 0 or num_heads == 0) is a no-op.
  // Throws c10::Error (TORCH_CHECK) if:
  //   * the head size is unsupported,
  //   * head_size is not divisible by the thread-group size, or
  //   * a kernel launch fails.
  template<
    typename T,
    typename CACHE_T,
    int BLOCK_SIZE,
    bool ENABLE_FP8_KV_CACHE,
    int NUM_THREADS = 128,
    int PARTITION_SIZE = 512>
  void paged_attention_v2_launcher(
    torch::Tensor& out,
    torch::Tensor& exp_sums,
    torch::Tensor& max_logits,
    torch::Tensor& tmp_out,
    torch::Tensor& query,
    torch::Tensor& key_cache,
    torch::Tensor& value_cache,
    int num_kv_heads,
    float scale,
    torch::Tensor& block_tables,
    torch::Tensor& context_lens,
    int max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes) {
    int num_seqs = query.size(0);
    int num_heads = query.size(1);
    int head_size = query.size(2);

    // Nothing to compute. Launching anyway would use a zero-sized grid
    // dimension, which CUDA rejects as an invalid configuration.
    if (num_seqs == 0 || num_heads == 0) {
      return;
    }

    int max_num_blocks_per_seq = block_tables.size(1);
    int q_stride = query.stride(0);
    int kv_block_stride = key_cache.stride(0);
    int kv_head_stride = key_cache.stride(1);

    // TORCH_CHECK (unlike assert) stays active in release builds, and it
    // raises a catchable c10::Error instead of aborting the process.
    int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
    TORCH_CHECK(head_size % thread_group_size == 0,
                "head_size (", head_size,
                ") must be divisible by the thread group size (",
                thread_group_size, ")");

    // NOTE: alibi_slopes is optional.
    const float* alibi_slopes_ptr = alibi_slopes ?
      reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
      : nullptr;

    T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
    float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
    float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
    T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
    T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
    CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
    CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
    int* block_tables_ptr = block_tables.data_ptr<int>();
    int* context_lens_ptr = context_lens.data_ptr<int>();

    constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
    int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
    int logits_size = PARTITION_SIZE * sizeof(float);
    int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

    // Phase 1 (paged_attention_v2_kernel): one block per (head, seq, partition).
    dim3 grid(num_heads, num_seqs, max_num_partitions);
    // Dynamic shared memory is sized to the larger of one partition's logits
    // and the cross-warp output scratch (presumably the kernel reuses the same
    // buffer for both).
    int shared_mem_size = std::max(logits_size, outputs_size);
    // Phase 2 (paged_attention_v2_reduce_kernel): one block per (head, seq),
    // with two floats of shared memory per partition.
    dim3 reduce_grid(num_heads, num_seqs);
    int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
    dim3 block(NUM_THREADS);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    switch (head_size) {
      // NOTE: To reduce the compilation time, we only compile for the
      // head sizes that we use in the model. However, we can easily extend this
      // to support any head size which is a multiple of 16.
      case 64:
        LAUNCH_PAGED_ATTENTION_V2(64);
        break;
      case 80:
        LAUNCH_PAGED_ATTENTION_V2(80);
        break;
      case 96:
        LAUNCH_PAGED_ATTENTION_V2(96);
        break;
      case 112:
        LAUNCH_PAGED_ATTENTION_V2(112);
        break;
      case 128:
        LAUNCH_PAGED_ATTENTION_V2(128);
        break;
      case 256:
        LAUNCH_PAGED_ATTENTION_V2(256);
        break;
      default:
        TORCH_CHECK(false, "Unsupported head size: ", head_size);
        break;
    }

    // Kernel launches do not return errors. Checking here reports a bad
    // launch at its source (e.g. a zero-sized grid from max_context_len <= 0,
    // or an oversized shared-memory request). Otherwise a later, unrelated
    // CUDA call would pick up the error.
    cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "paged_attention_v2 kernel launch failed: ",
                cudaGetErrorString(launch_err));
  }
  // Calls one instantiation of paged_attention_v2_launcher, chosen by the query
  // element type (T), cache element type (CACHE_T), KV-cache block size and
  // FP8-cache flag. NUM_THREADS and PARTITION_SIZE use the launcher's
  // defaults. It forwards variables from the enclosing scope (out, exp_sums,
  // max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale,
  // block_tables, context_lens, max_context_len, alibi_slopes), so it must
  // expand inside paged_attention_v2.
  #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, ENABLE_FP8_KV_CACHE)       \
    paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, ENABLE_FP8_KV_CACHE>( \
        out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,   \
        num_kv_heads, scale, block_tables, context_lens, max_context_len,    \
        alibi_slopes);
  // Converts the runtime `block_size` (taken from the enclosing scope) into the
  // compile-time BLOCK_SIZE template argument expected by CALL_V2_LAUNCHER.
  // Only 8, 16 and 32 are instantiated, to keep compile times down. Block sizes
  // 1, 2, 4, 64, 128 and 256 are deliberately left out, and any other value
  // fails through TORCH_CHECK.
  #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, ENABLE_FP8_KV_CACHE)            \
    switch (block_size) {                                                         \
      case 8:  CALL_V2_LAUNCHER(T, CACHE_T, 8, ENABLE_FP8_KV_CACHE);  break;      \
      case 16: CALL_V2_LAUNCHER(T, CACHE_T, 16, ENABLE_FP8_KV_CACHE); break;      \
      case 32: CALL_V2_LAUNCHER(T, CACHE_T, 32, ENABLE_FP8_KV_CACHE); break;      \
      default:                                                                    \
        TORCH_CHECK(false, "Unsupported block size: ", block_size);               \
        break;                                                                    \
    }
  // Entry point for two-phase ("v2", partitioned) paged attention.
  //
  // Turns runtime arguments into compile-time template arguments:
  //   * the element type of `query` (float32, float16 or bfloat16) selects T;
  //     float16 data reaches the kernels as raw uint16_t;
  //   * `enable_fp8_kv_cache` selects the cache element type: uint8_t when
  //     true, otherwise the same type as T;
  //   * `block_size` selects BLOCK_SIZE (8, 16 or 32).
  // exp_sums, max_logits and tmp_out are caller-provided scratch buffers
  // passed to the launcher. Unsupported dtypes or block sizes are rejected
  // through TORCH_CHECK.
  void paged_attention_v2(
    torch::Tensor& out,           // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,      // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& tmp_out,       // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,         // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,     // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
    int num_kv_heads,             // number of key/value heads (a scalar, not a tensor)
    float scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& context_lens,  // [num_seqs]
    int block_size,
    int max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const bool enable_fp8_kv_cache) {
    switch (query.scalar_type()) {
      case at::ScalarType::Float:
        if (enable_fp8_kv_cache) {
          CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
        } else {
          CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
        }
        break;
      case at::ScalarType::Half:
        if (enable_fp8_kv_cache) {
          CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
        } else {
          CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
        }
        break;
      case at::ScalarType::BFloat16:
        if (enable_fp8_kv_cache) {
          CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
        } else {
          CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
        }
        break;
      default:
        TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  }
  856. #undef WARP_SIZE
  857. #undef MAX
  858. #undef MIN
  859. #undef DIVIDE_ROUND_UP