attention_kernels.cu

/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef ENABLE_FP8_E5M2
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif

#include <algorithm>

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

namespace aphrodite {

// Utility function for attention softmax.
template<int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
  #pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < NUM_WARPS) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
  #pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
  return APHRODITE_SHFL_SYNC(sum, 0);
}

// TODO: Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE,
  int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads,                 // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int context_len = context_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread group
  // fetch or compute 16 bytes at a time.
  // For example, if the size of a thread group is 4 and the data type is half,
  // then the vector size is 16 / (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in the group
  // has the 0th, 4th, 8th, ... vectors of the query, and the second thread has the
  // 1st, 5th, 9th, ... vectors of the query, and so on.
  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
  #pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads(); // TODO: possible speedup if this is replaced with a memory fence right before we use q_vecs

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE: We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(cache_t);
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
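
  // The key cache layout is [num_blocks, num_kv_heads, head_size/x, block_size, x],
  // so head element h of a token at offset t within a block lives at
  //   (h / x) * BLOCK_SIZE * x + t * x + (h % x)
  // relative to the (block, kv_head) base. The offset1/offset2 arithmetic in the loop
  // below follows this layout: k_ptr already includes t * x, offset1 selects the
  // head_size/x row, and offset2 the position inside the x-wide group.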
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
    // NOTE: The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by large numbers
    // (e.g., kv_block_stride).
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in the group
    // has the 0th, 4th, 8th, ... vectors of the key, and the second thread has the
    // 1st, 5th, 9th, ... vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

      #pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                       + kv_head_idx * kv_head_stride
                                       + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;
        if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          // Vector conversion from Quant_vec to K_vec.
          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#else
          assert(false);
#endif
        } else {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE: It is required to zero out the masked logits.
        const bool mask = token_idx >= context_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
  #pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO: Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
  #pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
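
  // Subtracting qk_max before exponentiation is the usual max-shift trick:
  //   softmax_i = exp(qk_i - qk_max) / sum_j exp(qk_j - qk_max),
  // which avoids overflow in __expf without changing the result.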
  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions
                                       + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                   + head_idx * max_num_partitions
                                   + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
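
  // In the value phase each warp again owns whole KV blocks. Within a warp,
  // lane % NUM_V_VECS_PER_ROW picks a V_VEC_SIZE-wide group of tokens (columns) and
  // lane / NUM_V_VECS_PER_ROW picks the head dimension (row), stepping by
  // NUM_ROWS_PER_ITER, so each lane accumulates NUM_ROWS_PER_THREAD partial
  // output elements in accs[].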
  // NOTE: We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
  #pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
    // NOTE: The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by large numbers
    // (e.g., kv_block_stride).
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                   + kv_head_idx * kv_head_stride;
    #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;
        if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#else
          assert(false);
#endif
        } else {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        }
        if (block_idx == num_context_blocks - 1) {
          // NOTE: When v_vec contains the tokens that are out of the context,
          // we should explicitly zero out the values since they may contain NaNs.
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
          #pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
  #pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
    #pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE: A barrier is required because the shared memory space for logits
  // is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
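  // Warps are folded in halves: in each round the upper half writes its partial
  // rows to out_smem and the lower half adds them in, so at most NUM_WARPS / 2
  // rows of HEAD_SIZE floats are live at once. This is what the launcher's
  // outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float) accounts for.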
  #pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
      #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
      #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                            + head_idx * max_num_partitions * HEAD_SIZE
                            + partition_idx * HEAD_SIZE;
    #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}

// Grid: (num_heads, num_seqs, 1).
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE>
__global__ void paged_attention_v1_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads,                 // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
    /* exp_sums */ nullptr, /* max_logits */ nullptr,
    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads,                 // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
    q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
  scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
  const float* __restrict__ exp_sums,   // [num_seqs, num_heads, max_num_partitions]
  const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
  const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
  const int* __restrict__ context_lens, // [num_seqs]
  const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                          + head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                           + head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
  #pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
  #pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
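
  // Each partition's tmp_out was already normalized by its own exp_sum, so weighting
  // it by exp_sum_j * exp(max_logit_j - max_logit) / global_exp_sum reconstructs the
  // softmax over the full context (the usual log-sum-exp merge).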
  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                        + head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
  #pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

} // namespace aphrodite

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
      IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
    IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
    out_ptr, \
    query_ptr, \
    key_cache_ptr, \
    value_cache_ptr, \
    num_kv_heads, \
    scale, \
    block_tables_ptr, \
    context_lens_ptr, \
    max_num_blocks_per_seq, \
    alibi_slopes_ptr, \
    q_stride, \
    kv_block_stride, \
    kv_head_stride);

// TODO: Tune NUM_THREADS.
template<
  typename T,
  typename CACHE_T,
  int BLOCK_SIZE,
  bool IS_FP8_E5M2_KV_CACHE,
  int NUM_THREADS = 128>
void paged_attention_v1_launcher(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);
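
  // The dynamic shared memory is reused: first as the logits buffer (one float per
  // padded context token), then as the cross-warp output buffer (NUM_WARPS / 2 rows
  // of head_size floats), so only the larger of the two needs to be allocated.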

  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE: To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
    out, \
    query, \
    key_cache, \
    value_cache, \
    num_kv_heads, \
    scale, \
    block_tables, \
    context_lens, \
    max_context_len, \
    alibi_slopes);

// NOTE: To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
  switch (block_size) { \
    case 8: \
      CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
      break; \
    case 16: \
      CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
      break; \
    case 32: \
      CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break; \
  }

void paged_attention_v1(
  torch::Tensor& out,          // [num_seqs, num_heads, head_size]
  torch::Tensor& query,        // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
  int num_kv_heads,            // [num_heads]
  float scale,
  torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens, // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
    IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
    <<<grid, block, shared_mem_size, stream>>>( \
      exp_sums_ptr, \
      max_logits_ptr, \
      tmp_out_ptr, \
      query_ptr, \
      key_cache_ptr, \
      value_cache_ptr, \
      num_kv_heads, \
      scale, \
      block_tables_ptr, \
      context_lens_ptr, \
      max_num_blocks_per_seq, \
      alibi_slopes_ptr, \
      q_stride, \
      kv_block_stride, \
      kv_head_stride); \
  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
    <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
      out_ptr, \
      exp_sums_ptr, \
      max_logits_ptr, \
      tmp_out_ptr, \
      context_lens_ptr, \
      max_num_partitions);

template<
  typename T,
  typename CACHE_T,
  int BLOCK_SIZE,
  bool IS_FP8_E5M2_KV_CACHE,
  int NUM_THREADS = 128,
  int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
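  // The reduce kernel keeps two float arrays of length max_num_partitions in shared
  // memory (the per-partition max logits and the rescaled exp sums), hence the factor of 2.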

  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE: To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
    out, \
    exp_sums, \
    max_logits, \
    tmp_out, \
    query, \
    key_cache, \
    value_cache, \
    num_kv_heads, \
    scale, \
    block_tables, \
    context_lens, \
    max_context_len, \
    alibi_slopes);

// NOTE: To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
  switch (block_size) { \
    case 8: \
      CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
      break; \
    case 16: \
      CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
      break; \
    case 32: \
      CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break; \
  }

void paged_attention_v2(
  torch::Tensor& out,          // [num_seqs, num_heads, head_size]
  torch::Tensor& exp_sums,     // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& max_logits,   // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& tmp_out,      // [num_seqs, num_heads, max_num_partitions, head_size]
  torch::Tensor& query,        // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
  int num_kv_heads,            // [num_heads]
  float scale,
  torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens, // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP