1
0

attention_kernels.cu 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032
  1. /*
  2. * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  3. * Copyright (c) 2023, The PygmalionAI team.
  4. * Copyright (c) 2023, The vLLM team.
  5. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the "License");
  8. * you may not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * http://www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an "AS IS" BASIS,
  15. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. */
  19. #ifdef USE_ROCM
  20. #include <hip/hip_runtime.h>
  21. #endif
  22. #include <torch/extension.h>
  23. #include <ATen/cuda/CUDAContext.h>
  24. #include <c10/cuda/CUDAGuard.h>
  25. #include "attention_dtypes.h"
  26. #include "attention_utils.cuh"
  27. #include "../quantization/int8_kvcache/quant_utils.cuh"
  28. #ifdef ENABLE_FP8_E5M2
  29. #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
  30. #endif
  31. #include <algorithm>
  // Number of lanes per warp: a literal 32 on NVIDIA builds; on ROCm builds it
  // is taken from HIP's `warpSize`.
  #ifdef USE_ROCM
  #define WARP_SIZE warpSize
  #else
  #define WARP_SIZE 32
  #endif

  // Small helpers that stay usable in constexpr / template-argument contexts.
  // Every argument is parenthesized, but arguments may be evaluated twice, so
  // do not pass expressions with side effects.
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
  // Integer ceiling division, e.g. DIVIDE_ROUND_UP(17, 16) == 2.
  #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
  // Storage format of the paged KV cache. Used as a non-type template
  // parameter so the key/value load path is selected at compile time:
  //   AUTO     - elements are reinterpreted directly as the query's scalar_t.
  //   FP8_E5M2 - 8-bit float elements converted with
  //              fp8_e5m2_unscaled::vec_conversion; this enumerator exists
  //              only when ENABLE_FP8_E5M2 is defined.
  //   INT8     - int8 elements dequantized with a scale and a zero point.
  enum kv_cache_dtype {
    AUTO,
  #ifdef ENABLE_FP8_E5M2
    FP8_E5M2,
  #endif
    INT8
  };
  46. namespace aphrodite {
  // Block-wide sum reduction used by the attention softmax.
  //
  // Every thread of the block must call this with its partial `sum`. The
  // function contains a __syncthreads(), so it must not be reached from
  // divergent control flow. Stage 1 reduces each warp with XOR shuffles and
  // lane 0 of every warp publishes its total to `red_smem[warp]`; stage 2 has
  // every warp re-reduce the NUM_WARPS totals, so all threads return the same
  // block-wide sum.
  //
  // Requirements:
  //   * NUM_WARPS equals the number of warps in the block, is a positive power
  //     of two (checked at compile time) and is at most WARP_SIZE (stage 2
  //     uses one lane per warp total).
  //   * `red_smem` points to at least NUM_WARPS floats of shared memory.
  //   * There is no trailing barrier: reusing the same `red_smem` for another
  //     reduction requires a __syncthreads() in between.
  template<int NUM_WARPS>
  inline __device__ float block_sum(float* red_smem, float sum) {
    // The XOR butterfly in stage 2 only folds every warp total into lane 0
    // when NUM_WARPS is a power of two; otherwise some totals are dropped.
    static_assert(NUM_WARPS > 0 && (NUM_WARPS & (NUM_WARPS - 1)) == 0,
                  "block_sum requires NUM_WARPS to be a positive power of two");

    // Decompose the thread index into warp / lane.
    const int warp = threadIdx.x / WARP_SIZE;
    const int lane = threadIdx.x % WARP_SIZE;

    // Stage 1: compute the sum per warp.
  #pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
    }
    // Warp leaders store the warp totals to shared memory.
    if (lane == 0) {
      red_smem[warp] = sum;
    }
    // Make the warp totals visible to the whole block.
    __syncthreads();

    // Stage 2: lanes [0, NUM_WARPS) load the warp totals; the other lanes hold
    // the additive identity (they never feed lane 0 in the butterfly anyway).
    sum = lane < NUM_WARPS ? red_smem[lane] : 0.f;
  #pragma unroll
    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
      sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
    }
    // Broadcast lane 0's block total to every lane.
    return APHRODITE_SHFL_SYNC(sum, 0);
  }
  // TODO: Merge the last two dimensions of the grid.
  // Grid: (num_heads, num_seqs, max_num_partitions).
  //
  // Single-query paged attention for one (head, sequence, partition) triple
  // per thread block: blockIdx.x / .y / .z select the head, the sequence and
  // the partition, and gridDim.x is taken as num_heads.
  //
  // Phases:
  //   1. Load this head's query into shared memory (`q_vecs`).
  //   2. For every key token of the partition compute scale * <q, k> (plus the
  //      ALiBi bias when slopes are given) and store it as an FP32 logit in
  //      dynamic shared memory.
  //   3. Softmax over those logits (block-wide max and exp-sum reductions).
  //   4. Accumulate softmax(qk) * V in FP32 registers, reduce across warps
  //      through shared memory and write one HEAD_SIZE row of `out`.
  // With partitioning (PARTITION_SIZE > 0) thread 0 also records the
  // partition's max logit and exp-sum so a later pass can merge partitions.
  //
  // Launch requirements (from how this body indexes memory):
  //   * blockDim.x == NUM_THREADS, and NUM_THREADS / WARP_SIZE is a power of
  //     two (checked at compile time).
  //   * Dynamic shared memory holds at least
  //     max(num_blocks * BLOCK_SIZE, (NUM_WARPS / 2) * HEAD_SIZE) floats, where
  //     num_blocks is the number of KV blocks this thread block processes.
  //   * When partitioning, PARTITION_SIZE is presumably a multiple of
  //     BLOCK_SIZE (num_blocks_per_partition uses integer division) - TODO
  //     confirm at the v2 launcher.
  template<
    typename scalar_t,
    typename cache_t,
    int HEAD_SIZE,
    int BLOCK_SIZE,
    int NUM_THREADS,
    kv_cache_dtype KV_CACHE_DTYPE,
    int PARTITION_SIZE = 0> // Zero means no partitioning.
  __device__ void paged_attention_kernel(
    float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
    scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads,                 // scalar: number of KV heads
    const float scale,
    const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,   // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes, // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    const float k_scale = 1.0f,
    const float k_zp = 0.0f,
    const float v_scale = 1.0f,
    const float v_zp = 0.0f) {
    const int seq_idx = blockIdx.y;
    const int partition_idx = blockIdx.z;
    const int max_num_partitions = gridDim.z;
    constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
    const int context_len = context_lens[seq_idx];
    if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
      // No work to do. Terminate the thread block. The condition is uniform
      // across the block, so no barrier below is reached by only some threads.
      return;
    }

    const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
    const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

    // [start_block_idx, end_block_idx) is the range of blocks to process.
    const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
    const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
    const int num_blocks = end_block_idx - start_block_idx;

    // [start_token_idx, end_token_idx) is the range of tokens to process.
    const int start_token_idx = start_block_idx * BLOCK_SIZE;
    const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
    const int num_tokens = end_token_idx - start_token_idx;

    // A thread group cooperatively computes the dot product for one key token.
    constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
    // All operands are compile-time constants, so check this at compile time
    // rather than with a device-side assert.
    static_assert(NUM_THREADS % THREAD_GROUP_SIZE == 0,
                  "THREAD_GROUP_SIZE must divide NUM_THREADS");
    constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE;
    constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
    constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
    // The cross-warp max butterfly and the halving output reduction below both
    // drop partial results unless NUM_WARPS is a power of two.
    static_assert(NUM_WARPS > 0 && (NUM_WARPS & (NUM_WARPS - 1)) == 0,
                  "NUM_THREADS / WARP_SIZE must be a positive power of two");

    const int thread_idx = threadIdx.x;
    const int warp_idx = thread_idx / WARP_SIZE;
    const int lane = thread_idx % WARP_SIZE;

    const int head_idx = blockIdx.x;
    const int num_heads = gridDim.x;
    // Consecutive groups of num_queries_per_kv query heads share one KV head.
    const int num_queries_per_kv = num_heads / num_kv_heads;
    const int kv_head_idx = head_idx / num_queries_per_kv;
    const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

    // A vector type to store a part of a key or a query.
    // The vector size is configured in such a way that the threads in a thread group
    // fetch or compute 16 bytes at a time.
    // For example, if the size of a thread group is 4 and the data type is half,
    // then the vector size is 16 / (4 * sizeof(half)) == 2.
    constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
    using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
    using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
    using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;

    constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
    constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

    const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
    const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

    // Load the query into shared memory (q_vecs).
    // Each thread in a thread group has a different part of the query.
    // For example, if the the thread group size is 4, then the first thread in the group
    // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
    // th vectors of the query, and so on.
    // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
    const scalar_t* q_ptr = q + static_cast<int64_t>(seq_idx) * q_stride + head_idx * HEAD_SIZE;
    __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
  #pragma unroll
    for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
      const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
      q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
    }
    __syncthreads(); // TODO: possible speedup if this is replaced with a memory wall right before we use q_vecs

    // Memory planning.
    extern __shared__ char shared_mem[];
    // NOTE: We use FP32 for the softmax logits for better accuracy.
    float* logits = reinterpret_cast<float*>(shared_mem);
    // Workspace for reduction: [0, NUM_WARPS) holds per-warp max logits,
    // [NUM_WARPS, 2 * NUM_WARPS) is handed to block_sum for the exp-sum.
    __shared__ float red_smem[2 * NUM_WARPS];

    // Number of cache_t elements in 16 bytes: the innermost dimension of the
    // key cache layout [..., head_size/x, block_size, x]. It equals
    // THREAD_GROUP_SIZE * VEC_SIZE only when cache_t == scalar_t and
    // THREAD_GROUP_SIZE * sizeof(scalar_t) divides 16.
    constexpr int x = 16 / sizeof(cache_t);
    float qk_max = -FLT_MAX;

    // Iterate over the key blocks.
    // Each warp fetches a block of keys for each iteration.
    // Each thread group in a warp fetches a key from the block, and computes
    // dot product with the query.
    const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
      // NOTE: The block number is stored in int32. However, we cast it to int64
      // because int32 can lead to overflow when this variable is multiplied by large numbers
      // (e.g., kv_block_stride).
      const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);

      // Load a key to registers.
      // Each thread in a thread group has a different part of the key.
      // For example, if the the thread group size is 4, then the first thread in the group
      // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
      // vectors of the key, and so on.
      for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
        const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
        const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
        K_vec k_vecs[NUM_VECS_PER_THREAD];

  #pragma unroll
        for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
          const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                         + kv_head_idx * kv_head_stride
                                         + physical_block_offset * x;
          const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
          const int offset1 = (vec_idx * VEC_SIZE) / x;
          const int offset2 = (vec_idx * VEC_SIZE) % x;
          if constexpr (KV_CACHE_DTYPE == INT8) {
            Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
            using Dequant_vec = typename FloatVec<Quant_vec>::Type;
            Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
            k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
  #ifdef ENABLE_FP8_E5M2
          } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
            Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
            // Vector conversion from Quant_vec to K_vec.
            k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
  #endif
          } else {
            k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          }
        }

        // Compute dot product.
        // This includes a reduction across the threads in the same thread group.
        float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
        // Add the ALiBi bias if slopes are given.
        qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

        if (thread_group_offset == 0) {
          // Store the partial reductions to shared memory.
          // NOTE: Masked slots (token_idx >= context_len) must hold 0: the
          // softmax loops below only touch the first num_tokens entries, while
          // the value loop reads whole V_VEC_SIZE-wide slices of `logits`.
          const bool mask = token_idx >= context_len;
          logits[token_idx - start_token_idx] = mask ? 0.f : qk;
          // Update the max value.
          qk_max = mask ? qk_max : fmaxf(qk_max, qk);
        }
      }
    }

    // Perform reduction across the threads in the same warp to get the
    // max qk value for each "warp" (not across the thread block yet).
    // The 0-th thread of each thread group already has its max qk value.
  #pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
      qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
    }
    if (lane == 0) {
      red_smem[warp_idx] = qk_max;
    }
    __syncthreads();

    // TODO: Refactor this part.
    // Get the max qk value for the sequence.
    qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
  #pragma unroll
    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
      qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
    }
    // Broadcast the max qk value to all threads.
    qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);

    // Get the sum of the exp values.
    float exp_sum = 0.f;
    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      float val = __expf(logits[i] - qk_max);
      logits[i] = val;
      exp_sum += val;
    }
    exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

    // Compute softmax.
    const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      logits[i] *= inv_sum;
    }
    __syncthreads();

    // If partitioning is enabled, store the max logit and exp_sum.
    if (USE_PARTITIONING && thread_idx == 0) {
      // 64-bit offset of (seq_idx, head_idx, partition_idx): the int32 product
      // seq_idx * num_heads * max_num_partitions can overflow.
      const int64_t stats_offset =
          (static_cast<int64_t>(seq_idx) * num_heads + head_idx) * max_num_partitions + partition_idx;
      max_logits[stats_offset] = qk_max;
      exp_sums[stats_offset] = exp_sum;
    }

    // Each thread will fetch 16 bytes from the value cache at a time.
    constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
    using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
    using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
    using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
    using Float_L_vec = typename FloatVec<L_vec>::Type;

    constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
    constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
    constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

    // NOTE: We use FP32 for the accumulator for better accuracy.
    float accs[NUM_ROWS_PER_THREAD];
  #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      accs[i] = 0.f;
    }

    scalar_t zero_value;
    zero(zero_value);
    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
      // NOTE: The block number is stored in int32. However, we cast it to int64
      // because int32 can lead to overflow when this variable is multiplied by large numbers
      // (e.g., kv_block_stride).
      const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
      const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      L_vec logits_vec;
      from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                     + kv_head_idx * kv_head_stride;
  #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE) {
          const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
          V_vec v_vec;
          if constexpr (KV_CACHE_DTYPE == INT8) {
            // dequant and conversion
            V_quant_vec v_vec_quant = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
            using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
            V_dequant_vec v_vec_dequant = int8::dequant(v_vec_quant, v_scale, v_zp);
            v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
  #ifdef ENABLE_FP8_E5M2
          } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
            V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
            // Vector conversion from V_quant_vec to V_vec.
            v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
  #endif
          } else {
            v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
          }
          if (block_idx == num_context_blocks - 1) {
            // NOTE: When v_vec contains the tokens that are out of the context,
            // we should explicitly zero out the values since they may contain NaNs.
            scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
  #pragma unroll
            for (int j = 0; j < V_VEC_SIZE; j++) {
              v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
            }
          }
          accs[i] += dot(logits_vec, v_vec);
        }
      }
    }

    // Perform reduction within each warp.
  #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      float acc = accs[i];
  #pragma unroll
      for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
        acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
      }
      accs[i] = acc;
    }

    // NOTE: A barrier is required because the shared memory space for logits
    // is reused for the output.
    __syncthreads();

    // Perform reduction across warps: each round the upper half of the active
    // warps hands its partial rows to the lower half through shared memory.
    float* out_smem = reinterpret_cast<float*>(shared_mem);
  #pragma unroll
    for (int i = NUM_WARPS; i > 1; i /= 2) {
      int mid = i / 2;
      // Upper warps write to shared memory.
      if (warp_idx >= mid && warp_idx < i) {
        float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
  #pragma unroll
        for (int r = 0; r < NUM_ROWS_PER_THREAD; r++) {
          const int row_idx = lane / NUM_V_VECS_PER_ROW + r * NUM_ROWS_PER_ITER;
          if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
            dst[row_idx] = accs[r];
          }
        }
      }
      __syncthreads();

      // Lower warps update the output.
      if (warp_idx < mid) {
        const float* src = &out_smem[warp_idx * HEAD_SIZE];
  #pragma unroll
        for (int r = 0; r < NUM_ROWS_PER_THREAD; r++) {
          const int row_idx = lane / NUM_V_VECS_PER_ROW + r * NUM_ROWS_PER_ITER;
          if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
            accs[r] += src[row_idx];
          }
        }
      }
      __syncthreads();
    }

    // Write the final output.
    if (warp_idx == 0) {
      // 64-bit offset of row (seq_idx, head_idx, partition_idx) in
      // [num_seqs, num_heads, max_num_partitions, HEAD_SIZE]; the int32 product
      // of all four extents can overflow.
      scalar_t* out_ptr =
          out + ((static_cast<int64_t>(seq_idx) * num_heads + head_idx) * max_num_partitions + partition_idx)
                    * HEAD_SIZE;
  #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          from_float(*(out_ptr + row_idx), accs[i]);
        }
      }
    }
  }
  // Grid: (num_heads, num_seqs, 1).
  //
  // Non-partitioned entry point. It instantiates paged_attention_kernel with
  // the default PARTITION_SIZE of 0, so each thread block covers the whole
  // context of one (head, sequence) pair and writes straight into `out`.
  // No per-partition statistics exist in this mode, so the exp-sum and
  // max-logit buffers are passed as null pointers.
  template<
    typename scalar_t,
    typename cache_t,
    int HEAD_SIZE,
    int BLOCK_SIZE,
    int NUM_THREADS,
    kv_cache_dtype KV_CACHE_DTYPE>
  __global__ void paged_attention_v1_kernel(
    scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads,                 // number of KV heads (scalar)
    const float scale,
    const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,   // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes, // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    const float k_scale,
    const float k_zp,
    const float v_scale,
    const float v_zp) {
    // Only read by the device body when partitioning is enabled.
    float* const unused_exp_sums = nullptr;
    float* const unused_max_logits = nullptr;
    paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>(
      unused_exp_sums,
      unused_max_logits,
      out,
      q,
      k_cache,
      v_cache,
      num_kv_heads,
      scale,
      block_tables,
      context_lens,
      max_num_blocks_per_seq,
      alibi_slopes,
      q_stride,
      kv_block_stride,
      kv_head_stride,
      k_scale,
      k_zp,
      v_scale,
      v_zp);
  }
  // Grid: (num_heads, num_seqs, max_num_partitions).
  //
  // Partitioned first pass. Each thread block handles at most PARTITION_SIZE
  // tokens of one (head, sequence) pair. It writes its partial output row to
  // `tmp_out` and records the partition's max logit and exp-sum in
  // `max_logits` / `exp_sums`. A partition that starts at or beyond the
  // sequence's context length exits without writing anything.
  // paged_attention_v2_reduce_kernel merges the partitions afterwards.
  template<
    typename scalar_t,
    typename cache_t,
    int HEAD_SIZE,
    int BLOCK_SIZE,
    int NUM_THREADS,
    kv_cache_dtype KV_CACHE_DTYPE,
    int PARTITION_SIZE>
  __global__ void paged_attention_v2_kernel(
    float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
    scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads,                 // number of KV heads (scalar)
    const float scale,
    const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,   // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes, // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    const float k_scale,
    const float k_zp,
    const float v_scale,
    const float v_zp) {
    // Everything is forwarded unchanged; a positive PARTITION_SIZE puts the
    // shared device body into partitioned mode.
    paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                           KV_CACHE_DTYPE, PARTITION_SIZE>(
      exp_sums, max_logits, tmp_out,
      q, k_cache, v_cache,
      num_kv_heads, scale,
      block_tables, context_lens, max_num_blocks_per_seq,
      alibi_slopes,
      q_stride, kv_block_stride, kv_head_stride,
      k_scale, k_zp, v_scale, v_zp);
  }
  // Grid: (num_heads, num_seqs).
  //
  // Second pass of partitioned attention. It merges the per-partition results
  // of paged_attention_v2_kernel for one (head, sequence) pair.
  // Partition p produced an output row tmp_out[p] that is already normalized
  // by its own exp-sum s_p, together with its max logit m_p. With
  // M = max_p m_p the merged row is
  //   out = sum_p tmp_out[p] * s_p * exp(m_p - M) / sum_p (s_p * exp(m_p - M)).
  // A single partition is simply copied.
  //
  // Launch requirements (from how this body indexes memory):
  //   * blockDim.x == NUM_THREADS, and NUM_THREADS / WARP_SIZE is a power of
  //     two (checked at compile time).
  //   * Dynamic shared memory holds 2 * num_partitions floats.
  template<
    typename scalar_t,
    int HEAD_SIZE,
    int NUM_THREADS,
    int PARTITION_SIZE>
  __global__ void paged_attention_v2_reduce_kernel(
    scalar_t* __restrict__ out,              // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,      // [num_seqs, num_heads, max_num_partitions]
    const float* __restrict__ max_logits,    // [num_seqs, num_heads, max_num_partitions]
    const scalar_t* __restrict__ tmp_out,    // [num_seqs, num_heads, max_num_partitions, head_size]
    const int* __restrict__ context_lens,    // [num_seqs]
    const int max_num_partitions) {
    constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
    // The warp-level max butterfly below only combines every warp's value
    // when NUM_WARPS is a power of two.
    static_assert(NUM_WARPS > 0 && (NUM_WARPS & (NUM_WARPS - 1)) == 0,
                  "NUM_THREADS / WARP_SIZE must be a positive power of two");

    const int num_heads = gridDim.x;
    const int head_idx = blockIdx.x;
    const int seq_idx = blockIdx.y;
    const int context_len = context_lens[seq_idx];
    const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);

    // 64-bit offsets for this (seq, head) pair: the int32 product
    // seq_idx * num_heads * max_num_partitions * HEAD_SIZE can overflow.
    const int64_t seq_head_idx = static_cast<int64_t>(seq_idx) * num_heads + head_idx;
    const int64_t stats_offset = seq_head_idx * max_num_partitions;
    scalar_t* out_ptr = out + seq_head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr = tmp_out + stats_offset * HEAD_SIZE;

    if (num_partitions == 1) {
      // No need to reduce. Only copy tmp_out to out.
      for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
        out_ptr[i] = tmp_out_ptr[i];
      }
      // Terminate the thread block.
      return;
    }

    const int warp_idx = threadIdx.x / WARP_SIZE;
    const int lane = threadIdx.x % WARP_SIZE;

    // Size: 2 * num_partitions.
    extern __shared__ char shared_mem[];
    // Workspace for reduction: [0, NUM_WARPS) for the max logit,
    // [NUM_WARPS, 2 * NUM_WARPS) handed to block_sum.
    __shared__ float red_smem[2 * NUM_WARPS];

    // Load max logits to shared memory.
    float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
    const float* max_logits_ptr = max_logits + stats_offset;
    float max_logit = -FLT_MAX;
    for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
      const float l = max_logits_ptr[i];
      shared_max_logits[i] = l;
      max_logit = fmaxf(max_logit, l);
    }
    __syncthreads();

    // Get the global max logit.
    // Reduce within the warp.
  #pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
    }
    if (lane == 0) {
      red_smem[warp_idx] = max_logit;
    }
    __syncthreads();
    // Reduce across warps.
    max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
  #pragma unroll
    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
      max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
    }
    // Broadcast the max value to all threads.
    max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);

    // Load rescaled exp sums to shared memory.
    float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
    const float* exp_sums_ptr = exp_sums + stats_offset;
    float global_exp_sum = 0.0f;
    for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
      float l = shared_max_logits[i];
      float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
      global_exp_sum += rescaled_exp_sum;
      shared_exp_sums[i] = rescaled_exp_sum;
    }
    __syncthreads();
    global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
    const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

    // Aggregate tmp_out to out.
  #pragma unroll
    for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
      float acc = 0.0f;
      for (int j = 0; j < num_partitions; ++j) {
        acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
      }
      from_float(out_ptr[i], acc);
    }
  }
  553. } // namespace aphrodite
  // Launches paged_attention_v1_kernel for a compile-time HEAD_SIZE.
  // Meant to be expanded inside paged_attention_v1_launcher: it relies on that
  // function's template parameters (T, CACHE_T, BLOCK_SIZE, NUM_THREADS,
  // KV_CACHE_DTYPE) and locals (grid, block, shared_mem_size, stream and the
  // pointer / stride / scale arguments).
  // It first calls APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
  // (defined elsewhere; by its name it raises the kernel's dynamic
  // shared-memory limit to shared_mem_size) and then launches the kernel on
  // `stream`. The two statements are wrapped in do { ... } while (0) so the
  // macro acts as one statement (safe in an unbraced if/else) and still takes
  // a trailing semicolon.
  #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                    \
    do {                                                                         \
      APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                \
        ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE,     \
          BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>), shared_mem_size);           \
      aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,   \
        NUM_THREADS, KV_CACHE_DTYPE><<<grid, block, shared_mem_size, stream>>>( \
        out_ptr,                                                                 \
        query_ptr,                                                               \
        key_cache_ptr,                                                           \
        value_cache_ptr,                                                         \
        num_kv_heads,                                                            \
        scale,                                                                   \
        block_tables_ptr,                                                        \
        context_lens_ptr,                                                        \
        max_num_blocks_per_seq,                                                  \
        alibi_slopes_ptr,                                                        \
        q_stride,                                                                \
        kv_block_stride,                                                         \
        kv_head_stride,                                                          \
        k_scale,                                                                 \
        k_zp,                                                                    \
        v_scale,                                                                 \
        v_zp);                                                                   \
    } while (0)
  577. // TODO: Tune NUM_THREADS.
  578. template<
  579. typename T,
  580. typename CACHE_T,
  581. int BLOCK_SIZE,
  582. kv_cache_dtype KV_CACHE_DTYPE,
  583. int NUM_THREADS = 128>
  584. void paged_attention_v1_launcher(
  585. torch::Tensor& out,
  586. torch::Tensor& query,
  587. torch::Tensor& key_cache,
  588. torch::Tensor& value_cache,
  589. int num_kv_heads,
  590. float scale,
  591. torch::Tensor& block_tables,
  592. torch::Tensor& context_lens,
  593. int max_context_len,
  594. const c10::optional<torch::Tensor>& alibi_slopes,
  595. const float k_scale,
  596. const float k_zp,
  597. const float v_scale,
  598. const float v_zp) {
  599. int num_seqs = query.size(0);
  600. int num_heads = query.size(1);
  601. int head_size = query.size(2);
  602. int max_num_blocks_per_seq = block_tables.size(1);
  603. int q_stride = query.stride(0);
  604. int kv_block_stride = key_cache.stride(0);
  605. int kv_head_stride = key_cache.stride(1);
  606. int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  607. assert(head_size % thread_group_size == 0);
  608. // NOTE: alibi_slopes is optional.
  609. const float* alibi_slopes_ptr = alibi_slopes ?
  610. reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
  611. : nullptr;
  612. T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  613. T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  614. CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  615. CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  616. int* block_tables_ptr = block_tables.data_ptr<int>();
  617. int* context_lens_ptr = context_lens.data_ptr<int>();
  618. constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  619. int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  620. int logits_size = padded_max_context_len * sizeof(float);
  621. int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  622. // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
  623. // Keep that in sync with the logic here!
  624. int shared_mem_size = std::max(logits_size, outputs_size);
  625. dim3 grid(num_heads, num_seqs, 1);
  626. dim3 block(NUM_THREADS);
  627. const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  628. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  629. switch (head_size) {
  630. // NOTE: To reduce the compilation time, we only compile for the
  631. // head sizes that we use in the model. However, we can easily extend this
  632. // to support any head size which is a multiple of 16.
  633. case 64:
  634. LAUNCH_PAGED_ATTENTION_V1(64);
  635. break;
  636. case 80:
  637. LAUNCH_PAGED_ATTENTION_V1(80);
  638. break;
  639. case 96:
  640. LAUNCH_PAGED_ATTENTION_V1(96);
  641. break;
  642. case 112:
  643. LAUNCH_PAGED_ATTENTION_V1(112);
  644. break;
  645. case 128:
  646. LAUNCH_PAGED_ATTENTION_V1(128);
  647. break;
  648. case 256:
  649. LAUNCH_PAGED_ATTENTION_V1(256);
  650. break;
  651. default:
  652. TORCH_CHECK(false, "Unsupported head size: ", head_size);
  653. break;
  654. }
  655. }
  656. #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE) \
  657. paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>( \
  658. out, \
  659. query, \
  660. key_cache, \
  661. value_cache, \
  662. num_kv_heads, \
  663. scale, \
  664. block_tables, \
  665. context_lens, \
  666. max_context_len, \
  667. alibi_slopes, \
  668. k_scale, \
  669. k_zp, \
  670. v_scale, \
  671. v_zp);
  672. // NOTE: To reduce the compilation time, we omitted block sizes
  673. // 1, 2, 4, 64, 128, 256.
  674. #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE) \
  675. switch (block_size) { \
  676. case 8: \
  677. CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE); \
  678. break; \
  679. case 16: \
  680. CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE); \
  681. break; \
  682. case 32: \
  683. CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE); \
  684. break; \
  685. default: \
  686. TORCH_CHECK(false, "Unsupported block size: ", block_size); \
  687. break; \
  688. }
  689. void paged_attention_v1(
  690. torch::Tensor& out, // [num_seqs, num_heads, head_size]
  691. torch::Tensor& query, // [num_seqs, num_heads, head_size]
  692. torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
  693. torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
  694. int num_kv_heads, // [num_heads]
  695. float scale,
  696. torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
  697. torch::Tensor& context_lens, // [num_seqs]
  698. int block_size,
  699. int max_context_len,
  700. const c10::optional<torch::Tensor>& alibi_slopes,
  701. const std::string& kv_cache_dtype,
  702. const float k_scale = 1.0f,
  703. const float k_zp = 0.0f,
  704. const float v_scale = 1.0f,
  705. const float v_zp = 0.0f) {
  706. if (kv_cache_dtype == "auto") {
  707. if (query.dtype() == at::ScalarType::Float) {
  708. CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
  709. } else if (query.dtype() == at::ScalarType::Half) {
  710. CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
  711. } else if (query.dtype() == at::ScalarType::BFloat16) {
  712. CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
  713. } else {
  714. TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  715. }
  716. #ifdef ENABLE_FP8_E5M2
  717. } else if (kv_cache_dtype == "fp8_e5m2") {
  718. if (query.dtype() == at::ScalarType::Float) {
  719. CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
  720. } else if (query.dtype() == at::ScalarType::Half) {
  721. CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
  722. } else if (query.dtype() == at::ScalarType::BFloat16) {
  723. CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
  724. } else {
  725. TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  726. }
  727. #endif
  728. } else if (kv_cache_dtype == "int8") {
  729. if (query.dtype() == at::ScalarType::Float) {
  730. CALL_V1_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
  731. } else if (query.dtype() == at::ScalarType::Half) {
  732. CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
  733. } else if (query.dtype() == at::ScalarType::BFloat16) {
  734. CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
  735. } else {
  736. TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  737. }
  738. } else {
  739. TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  740. }
  741. }
  742. #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
  743. aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
  744. KV_CACHE_DTYPE, PARTITION_SIZE> \
  745. <<<grid, block, shared_mem_size, stream>>>( \
  746. exp_sums_ptr, \
  747. max_logits_ptr, \
  748. tmp_out_ptr, \
  749. query_ptr, \
  750. key_cache_ptr, \
  751. value_cache_ptr, \
  752. num_kv_heads, \
  753. scale, \
  754. block_tables_ptr, \
  755. context_lens_ptr, \
  756. max_num_blocks_per_seq, \
  757. alibi_slopes_ptr, \
  758. q_stride, \
  759. kv_block_stride, \
  760. kv_head_stride, \
  761. k_scale, \
  762. k_zp, \
  763. v_scale, \
  764. v_zp); \
  765. aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
  766. <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
  767. out_ptr, \
  768. exp_sums_ptr, \
  769. max_logits_ptr, \
  770. tmp_out_ptr, \
  771. context_lens_ptr, \
  772. max_num_partitions);
  773. template<
  774. typename T,
  775. typename CACHE_T,
  776. int BLOCK_SIZE,
  777. kv_cache_dtype KV_CACHE_DTYPE,
  778. int NUM_THREADS = 128,
  779. int PARTITION_SIZE = 512>
  780. void paged_attention_v2_launcher(
  781. torch::Tensor& out,
  782. torch::Tensor& exp_sums,
  783. torch::Tensor& max_logits,
  784. torch::Tensor& tmp_out,
  785. torch::Tensor& query,
  786. torch::Tensor& key_cache,
  787. torch::Tensor& value_cache,
  788. int num_kv_heads,
  789. float scale,
  790. torch::Tensor& block_tables,
  791. torch::Tensor& context_lens,
  792. int max_context_len,
  793. const c10::optional<torch::Tensor>& alibi_slopes,
  794. const float k_scale,
  795. const float k_zp,
  796. const float v_scale,
  797. const float v_zp) {
  798. int num_seqs = query.size(0);
  799. int num_heads = query.size(1);
  800. int head_size = query.size(2);
  801. int max_num_blocks_per_seq = block_tables.size(1);
  802. int q_stride = query.stride(0);
  803. int kv_block_stride = key_cache.stride(0);
  804. int kv_head_stride = key_cache.stride(1);
  805. int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  806. assert(head_size % thread_group_size == 0);
  807. // NOTE: alibi_slopes is optional.
  808. const float* alibi_slopes_ptr = alibi_slopes ?
  809. reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
  810. : nullptr;
  811. T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  812. float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  813. float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  814. T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  815. T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  816. CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  817. CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  818. int* block_tables_ptr = block_tables.data_ptr<int>();
  819. int* context_lens_ptr = context_lens.data_ptr<int>();
  820. constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  821. int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  822. int logits_size = PARTITION_SIZE * sizeof(float);
  823. int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  824. // For paged attention v2 kernel.
  825. dim3 grid(num_heads, num_seqs, max_num_partitions);
  826. int shared_mem_size = std::max(logits_size, outputs_size);
  827. // For paged attention v2 reduce kernel.
  828. dim3 reduce_grid(num_heads, num_seqs);
  829. int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
  830. dim3 block(NUM_THREADS);
  831. const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  832. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  833. switch (head_size) {
  834. // NOTE: To reduce the compilation time, we only compile for the
  835. // head sizes that we use in the model. However, we can easily extend this
  836. // to support any head size which is a multiple of 16.
  837. case 64:
  838. LAUNCH_PAGED_ATTENTION_V2(64);
  839. break;
  840. case 80:
  841. LAUNCH_PAGED_ATTENTION_V2(80);
  842. break;
  843. case 96:
  844. LAUNCH_PAGED_ATTENTION_V2(96);
  845. break;
  846. case 112:
  847. LAUNCH_PAGED_ATTENTION_V2(112);
  848. break;
  849. case 128:
  850. LAUNCH_PAGED_ATTENTION_V2(128);
  851. break;
  852. case 256:
  853. LAUNCH_PAGED_ATTENTION_V2(256);
  854. break;
  855. default:
  856. TORCH_CHECK(false, "Unsupported head size: ", head_size);
  857. break;
  858. }
  859. }
  860. #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE) \
  861. paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>( \
  862. out, \
  863. exp_sums, \
  864. max_logits, \
  865. tmp_out, \
  866. query, \
  867. key_cache, \
  868. value_cache, \
  869. num_kv_heads, \
  870. scale, \
  871. block_tables, \
  872. context_lens, \
  873. max_context_len, \
  874. alibi_slopes, \
  875. k_scale, \
  876. k_zp, \
  877. v_scale, \
  878. v_zp);
  879. // NOTE: To reduce the compilation time, we omitted block sizes
  880. // 1, 2, 4, 64, 128, 256.
  881. #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE) \
  882. switch (block_size) { \
  883. case 8: \
  884. CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE); \
  885. break; \
  886. case 16: \
  887. CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE); \
  888. break; \
  889. case 32: \
  890. CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE); \
  891. break; \
  892. default: \
  893. TORCH_CHECK(false, "Unsupported block size: ", block_size); \
  894. break; \
  895. }
  896. void paged_attention_v2(
  897. torch::Tensor& out, // [num_seqs, num_heads, head_size]
  898. torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
  899. torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
  900. torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
  901. torch::Tensor& query, // [num_seqs, num_heads, head_size]
  902. torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
  903. torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
  904. int num_kv_heads, // [num_heads]
  905. float scale,
  906. torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
  907. torch::Tensor& context_lens, // [num_seqs]
  908. int block_size,
  909. int max_context_len,
  910. const c10::optional<torch::Tensor>& alibi_slopes,
  911. const std::string& kv_cache_dtype,
  912. const float k_scale = 1.0f,
  913. const float k_zp = 0.0f,
  914. const float v_scale = 1.0f,
  915. const float v_zp = 0.0f) {
  916. if (kv_cache_dtype == "auto") {
  917. if (query.dtype() == at::ScalarType::Float) {
  918. CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
  919. } else if (query.dtype() == at::ScalarType::Half) {
  920. CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
  921. } else if (query.dtype() == at::ScalarType::BFloat16) {
  922. CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
  923. } else {
  924. TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  925. }
  926. #ifdef ENABLE_FP8_E5M2
  927. } else if (kv_cache_dtype == "fp8_e5m2") {
  928. if (query.dtype() == at::ScalarType::Float) {
  929. CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
  930. } else if (query.dtype() == at::ScalarType::Half) {
  931. CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
  932. } else if (query.dtype() == at::ScalarType::BFloat16) {
  933. CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
  934. } else {
  935. TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  936. }
  937. #endif
  938. } else if (kv_cache_dtype == "int8") {
  939. if (query.dtype() == at::ScalarType::Float) {
  940. CALL_V2_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
  941. } else if (query.dtype() == at::ScalarType::Half) {
  942. CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
  943. } else if (query.dtype() == at::ScalarType::BFloat16) {
  944. CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
  945. } else {
  946. TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  947. }
  948. } else {
  949. TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  950. }
  951. }
  952. #undef WARP_SIZE
  953. #undef MAX
  954. #undef MIN
  955. #undef DIVIDE_ROUND_UP