attention_kernels.cu
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"
#ifdef ENABLE_FP8_E5M2
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif

#include <algorithm>

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

enum kv_cache_dtype {
  AUTO,
#ifdef ENABLE_FP8_E5M2
  FP8_E5M2,
#endif
  INT8
};
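// AUTO keeps the KV cache in the same data type as the query activations,
// while INT8 (and FP8_E5M2 when ENABLE_FP8_E5M2 is defined) stores the cache
// in a quantized 8-bit format that is dequantized on load; see the
// k_scale/k_zp/v_scale/v_zp parameters used below.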
namespace aphrodite {

// Utility function for attention softmax.
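// block_sum reduces a per-thread float across the whole thread block in two
// stages: a butterfly shuffle inside each warp, then a second shuffle over the
// per-warp partial sums staged in shared memory. As a concrete illustration
// (numbers not fixed by this function), with NUM_THREADS = 128 and
// WARP_SIZE = 32 there are NUM_WARPS = 4 partial sums, so red_smem only needs
// NUM_WARPS floats, and the block-wide total is finally broadcast to every
// thread via APHRODITE_SHFL_SYNC.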
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < NUM_WARPS) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
  return APHRODITE_SHFL_SYNC(sum, 0);
}
// TODO: Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, kv_cache_dtype KV_CACHE_DTYPE,
          int PARTITION_SIZE = 0>  // Zero means no partitioning.
__device__ void paged_attention_kernel(
    float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
    scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads,               // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float k_scale = 1.0f, const float k_zp = 0.0f,
    const float v_scale = 1.0f, const float v_zp = 0.0f) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int context_len = context_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int num_blocks_per_partition =
      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx =
      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx =
      MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx =
      MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
  const int num_tokens = end_token_idx - start_token_idx;
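  // Worked example of the partitioning above (illustrative numbers only):
  // with the v2 default PARTITION_SIZE = 512 and BLOCK_SIZE = 16, each
  // partition covers 512 / 16 = 32 cache blocks, so a context of 1300 tokens
  // spans DIVIDE_ROUND_UP(1300, 16) = 82 blocks and
  // DIVIDE_ROUND_UP(1300, 512) = 3 partitions; the last partition processes
  // blocks [64, 82) and tokens [1024, 1300).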
  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int NUM_THREAD_GROUPS =
      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE
                                        // divides NUM_THREADS.
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
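  // For example (illustrative only): with WARP_SIZE = 32, BLOCK_SIZE = 16 and
  // NUM_THREADS = 128, THREAD_GROUP_SIZE = 2, NUM_THREAD_GROUPS = 64 and
  // NUM_TOKENS_PER_THREAD_GROUP = 1, i.e. each pair of threads cooperates on
  // one token of a key block per iteration.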
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope =
      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread
  // group fetch or compute 16 bytes at a time. For example, if the size of a
  // thread group is 4 and the data type is half, then the vector size is
  // 16 / (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
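  // Continuing the example above (illustrative only): for half data,
  // THREAD_GROUP_SIZE = 2 gives VEC_SIZE = 16 / (2 * 2) = 4, so with
  // HEAD_SIZE = 128 each thread owns NUM_ELEMS_PER_THREAD = 64 query/key
  // elements packed into NUM_VECS_PER_THREAD = 16 vectors.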
  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in
  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
  // has 1, 5, 9, ... th vectors of the query, and so on.
  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
       i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] =
        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads();  // TODO: possible speedup if this is replaced with a
                    // memory fence right before we use q_vecs

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE: We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(cache_t);
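  // For instance (illustrative only): with a half-precision cache,
  // x = 16 / sizeof(uint16_t) = 8, and with an int8 cache x = 16, matching
  // the [head_size/x, block_size, x] layout of k_cache documented above.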
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE: The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by
    // large numbers (e.g., kv_block_stride).
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in
    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
    // has 1, 5, 9, ... th vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset =
          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const cache_t* k_ptr =
            k_cache + physical_block_number * kv_block_stride +
            kv_head_idx * kv_head_stride + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;
        if constexpr (KV_CACHE_DTYPE == INT8) {
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          using Dequant_vec = typename FloatVec<Quant_vec>::Type;
          Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
          k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
#ifdef ENABLE_FP8_E5M2
        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          // Vector conversion from Quant_vec to K_vec.
          k_vecs[j] =
              fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#endif
        } else {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
                             q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk +=
          (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE: It is required to zero out the masked logits.
        const bool mask = token_idx >= context_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO: Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
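  // From here on the usual numerically stable softmax is applied:
  // softmax(x_i) = exp(x_i - qk_max) / sum_j exp(x_j - qk_max).
  // Subtracting the block-wide maximum keeps __expf from overflowing
  // without changing the result.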
  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits +
                            seq_idx * num_heads * max_num_partitions +
                            head_idx * max_num_partitions + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
                          head_idx * max_num_partitions + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD =
      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
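  // Illustrative numbers only: with half data and BLOCK_SIZE = 16,
  // V_VEC_SIZE = MIN(16 / 2, 16) = 8, so NUM_V_VECS_PER_ROW = 2,
  // NUM_ROWS_PER_ITER = 32 / 2 = 16 head-size rows per warp iteration, and
  // HEAD_SIZE = 128 gives NUM_ROWS_PER_THREAD = 8 accumulators per thread.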
  // NOTE: We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE: The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by
    // large numbers (e.g., kv_block_stride).
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
                                                           start_token_idx));

    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                           kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;
        if constexpr (KV_CACHE_DTYPE == INT8) {
          // dequant and conversion
          V_quant_vec v_vec_quant =
              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
          V_dequant_vec v_vec_dequant =
              int8::dequant(v_vec_quant, v_scale, v_zp);
          v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
#ifdef ENABLE_FP8_E5M2
        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
          V_quant_vec v_quant_vec =
              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(
              v_quant_vec);
#endif
        } else {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        }
        if (block_idx == num_context_blocks - 1) {
          // NOTE: When v_vec contains the tokens that are out of the context,
          // we should explicitly zero out the values since they may contain
          // NaNs.
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] =
                token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE: A barrier is required because the shared memory space for logits
  // is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr =
        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, kv_cache_dtype KV_CACHE_DTYPE>
__global__ void paged_attention_v1_kernel(
    scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads,               // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float k_scale, const float k_zp, const float v_scale,
    const float v_zp) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_CACHE_DTYPE>(
      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
      v_cache, num_kv_heads, scale, block_tables, context_lens,
      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
      kv_head_stride, k_scale, k_zp, v_scale, v_zp);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, kv_cache_dtype KV_CACHE_DTYPE, int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
    float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
    scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads,               // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float k_scale, const float k_zp, const float v_scale,
    const float v_zp) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_CACHE_DTYPE, PARTITION_SIZE>(
      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
      block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
      q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
}

// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr =
        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr =
        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits +
                                seq_idx * num_heads * max_num_partitions +
                                head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums =
      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums +
                              seq_idx * num_heads * max_num_partitions +
                              head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
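  // Each partition stored sum_i exp(logit_i - max_logit_j) together with its
  // own max_logit_j. Multiplying by expf(max_logit_j - max_logit) rebases
  // every partial sum onto the global maximum, so global_exp_sum becomes the
  // softmax denominator over the full context.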
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr =
      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr =
      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
             inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

}  // namespace aphrodite

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                 \
  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                 \
      ((void*)aphrodite::paged_attention_v1_kernel<                          \
          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>),  \
      shared_mem_size);                                                      \
  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,    \
                                       NUM_THREADS, KV_CACHE_DTYPE>          \
      <<<grid, block, shared_mem_size, stream>>>(                            \
          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,  \
          scale, block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,       \
          k_scale, k_zp, v_scale, v_zp);

// TODO: Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
          kv_cache_dtype KV_CACHE_DTYPE, int NUM_THREADS = 128>
void paged_attention_v1_launcher(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& context_lens,
    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const float k_scale, const float k_zp, const float v_scale,
    const float v_zp) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_context_len =
      DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in
  // aphrodite.worker.worker._check_if_can_support_max_seq_len
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);
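  // The kernel reuses one dynamic shared-memory region, first for the softmax
  // logits (one float per padded context token) and later for the cross-warp
  // output reduction, which is why a single allocation of
  // max(logits_size, outputs_size) bytes is sufficient.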
  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE: To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)             \
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(       \
      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
      context_lens, max_context_len, alibi_slopes, k_scale, k_zp, v_scale,   \
      v_zp);

// NOTE: To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)   \
  switch (block_size) {                                           \
    case 8:                                                       \
      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);            \
      break;                                                      \
    case 16:                                                      \
      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);           \
      break;                                                      \
    case 32:                                                      \
      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);           \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

void paged_attention_v1(
    torch::Tensor& out,          // [num_seqs, num_heads, head_size]
    torch::Tensor& query,        // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
    int num_kv_heads,            // [num_heads]
    float scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& context_lens,  // [num_seqs]
    int block_size, int max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, const float k_scale = 1.0f,
    const float k_zp = 0.0f, const float v_scale = 1.0f,
    const float v_zp = 0.0f) {
  if (kv_cache_dtype == "auto") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
#ifdef ENABLE_FP8_E5M2
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
#endif
  } else if (kv_cache_dtype == "int8") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
                                       NUM_THREADS, KV_CACHE_DTYPE,            \
                                       PARTITION_SIZE>                         \
      <<<grid, block, shared_mem_size, stream>>>(                              \
          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
          value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
          context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr,          \
          q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale,   \
          v_zp);                                                               \
  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,       \
                                              PARTITION_SIZE>                  \
      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                \
          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,                  \
          context_lens_ptr, max_num_partitions);

template <typename T, typename CACHE_T, int BLOCK_SIZE,
          kv_cache_dtype KV_CACHE_DTYPE, int NUM_THREADS = 128,
          int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& context_lens,
    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const float k_scale, const float k_zp, const float v_scale,
    const float v_zp) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
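  // The reduce kernel stages two float arrays of length max_num_partitions in
  // dynamic shared memory (shared_max_logits followed by shared_exp_sums),
  // hence the 2 * max_num_partitions * sizeof(float) request above.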
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE: To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)         \
  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(   \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
      num_kv_heads, scale, block_tables, context_lens, max_context_len,  \
      alibi_slopes, k_scale, k_zp, v_scale, v_zp);

// NOTE: To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)   \
  switch (block_size) {                                           \
    case 8:                                                       \
      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);            \
      break;                                                      \
    case 16:                                                      \
      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);           \
      break;                                                      \
    case 32:                                                      \
      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);           \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

void paged_attention_v2(
    torch::Tensor& out,          // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,     // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,   // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& tmp_out,      // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,        // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
    int num_kv_heads,            // [num_heads]
    float scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& context_lens,  // [num_seqs]
    int block_size, int max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, const float k_scale = 1.0f,
    const float k_zp = 0.0f, const float v_scale = 1.0f,
    const float v_zp = 0.0f) {
  if (kv_cache_dtype == "auto") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
#ifdef ENABLE_FP8_E5M2
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
#endif
  } else if (kv_cache_dtype == "int8") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
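
// A minimal host-side sketch of how paged_attention_v1 might be driven from
// C++ (illustrative only; the sizes num_seqs/num_heads/head_size/block_size
// and the cache tensors are placeholders, not values defined in this file):
//
//   auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
//   torch::Tensor query = torch::randn({num_seqs, num_heads, head_size}, opts);
//   torch::Tensor out = torch::empty_like(query);
//   // key_cache:   [num_blocks, num_kv_heads, head_size/x, block_size, x]
//   // value_cache: [num_blocks, num_kv_heads, head_size, block_size]
//   paged_attention_v1(out, query, key_cache, value_cache, num_kv_heads,
//                      /*scale=*/1.0f / sqrtf(float(head_size)), block_tables,
//                      context_lens, block_size, max_context_len,
//                      /*alibi_slopes=*/c10::nullopt,
//                      /*kv_cache_dtype=*/"auto");
//
// With kv_cache_dtype == "auto" the cache tensors share the query dtype; the
// "int8" and (if compiled) "fp8_e5m2" paths take 8-bit cache tensors plus the
// k_scale/k_zp/v_scale/v_zp parameters declared above.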