/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>

#include "attention_dtypes.h"
#include "attention_utils.cuh"

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
  #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
  #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif

#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace aphrodite {

// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < NUM_WARPS) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
  return APHRODITE_SHFL_SYNC(sum, 0);
}
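
// Illustrative sketch of block_sum (descriptive comment only; the values are
// an assumed example, not anything computed here): with WARP_SIZE == 32 and
// NUM_WARPS == 4, the first XOR-shuffle butterfly (masks 16, 8, 4, 2, 1)
// leaves every lane of warp w holding that warp's total s_w, and lane 0
// writes s_w to red_smem[w]. After __syncthreads(), lanes 0..3 reload
// red_smem[0..3] and a second butterfly (masks 2, 1) produces
// s_0 + s_1 + s_2 + s_3, which the final shuffle broadcasts from lane 0 to
// the whole warp. The caller must therefore pass a red_smem slice holding at
// least NUM_WARPS floats.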
// TODO: Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,
          int PARTITION_SIZE = 0>  // Zero means no partitioning.
__device__ void paged_attention_kernel(
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
                                 // head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float k_scale, const float v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int seq_len = seq_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
  const int num_blocks_per_partition =
      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx =
      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx =
      MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx =
      MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int NUM_THREAD_GROUPS =
      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE
                                        // divides NUM_THREADS
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope =
      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread
  // group fetch or compute 16 bytes at a time. For example, if the size of a
  // thread group is 4 and the data type is half, then the vector size is 16 /
  // (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
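
  // Illustrative parameter sketch (assumed example, not a configuration fixed
  // by this file): for scalar_t == half, HEAD_SIZE == 128, BLOCK_SIZE == 16,
  // NUM_THREADS == 128 and WARP_SIZE == 32:
  //   THREAD_GROUP_SIZE           = MAX(32 / 16, 1)         = 2
  //   VEC_SIZE                    = MAX(16 / (2 * 2), 1)    = 4
  //   NUM_THREAD_GROUPS           = 128 / 2                 = 64
  //   NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(16, 32) = 1
  //   NUM_ELEMS_PER_THREAD        = 128 / 2                 = 64
  //   NUM_VECS_PER_THREAD         = 64 / 4                  = 16
  // i.e. a thread group of two threads fetches 16 bytes of a key per step and
  // each thread owns sixteen 4-element vectors of its head.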
  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in
  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE: Because
  // q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
       i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] =
        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads();  // TODO: possible speedup if this is replaced with a
                    // memory wall right before we use q_vecs

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE: We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(cache_t);
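
  // Illustrative addressing sketch (assumed example values): with
  // cache_t == half, x == 16 / 2 == 8, so the key cache is laid out as
  // [num_blocks, num_kv_heads, HEAD_SIZE / 8, BLOCK_SIZE, 8]. Relative to the
  // (block, kv head) base, the element at head dim d of in-block token t sits
  // at (d / 8) * BLOCK_SIZE * 8 + t * 8 + (d % 8); in the key loop below the
  // t * 8 term is folded into k_ptr via physical_block_offset * x, while
  // offset1 and offset2 supply d / x and d % x.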
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;

  // blocksparse specific vars
  [[maybe_unused]] int bs_block_offset;
  [[maybe_unused]] int q_bs_block_id;
  if constexpr (IS_BLOCK_SPARSE) {
    // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
    // blocksparse_block_size);
    q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
    if (blocksparse_head_sliding_step >= 0)
      // sliding on q heads
      bs_block_offset =
          (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
    else
      // sliding on kv heads
      bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
                            (-blocksparse_head_sliding_step) +
                        1;
  }

  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE: The block number is stored in int32. However, we cast it to
    // int64 because int32 can lead to overflow when this variable is multiplied
    // by large numbers (e.g., kv_block_stride).
    // For blocksparse attention: skip computation on blocks that are not
    // attended
    if constexpr (IS_BLOCK_SPARSE) {
      const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
      const bool is_remote =
          ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
      const bool is_local =
          (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
      if (!is_remote && !is_local) {
        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
          const int physical_block_offset =
              (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
          const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;

          if (thread_group_offset == 0) {
            // NOTE: assign a very negative logit (-FLT_MAX) to skipped tokens
            // so that they do not contribute to the sum-exp softmax
            // normalizer. The value is never used when computing
            // sum(softmax * v) because the blocks are skipped there as well.
            logits[token_idx - start_token_idx] = -FLT_MAX;
          }
        }
        continue;
      }
    }
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in
    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
    // has 1, 5, 9, ... th vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset =
          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const cache_t* k_ptr =
            k_cache + physical_block_number * kv_block_stride +
            kv_head_idx * kv_head_stride + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        } else {
          // Vector conversion from Quant_vec to K_vec.
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
              k_vec_quant, k_scale);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
                             q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE: It is required to zero out the masked logits.
        const bool mask = token_idx >= seq_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO: Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits +
                            seq_idx * num_heads * max_num_partitions +
                            head_idx * max_num_partitions + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
                          head_idx * max_num_partitions + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD =
      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
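
  // Illustrative sketch of the value-phase constants (assumed example values):
  // for scalar_t == half, HEAD_SIZE == 128, BLOCK_SIZE == 16, WARP_SIZE == 32:
  //   V_VEC_SIZE          = MIN(16 / 2, 16)          = 8
  //   NUM_V_VECS_PER_ROW  = 16 / 8                   = 2
  //   NUM_ROWS_PER_ITER   = 32 / 2                   = 16
  //   NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(128, 16) = 8
  // so a warp covers 16 head-size rows of a value block per iteration and
  // every thread accumulates 8 output rows in registers.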
  // NOTE: We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE: The block number is stored in int32. However, we cast it to
    // int64 because int32 can lead to overflow when this variable is multiplied
    // by large numbers (e.g., kv_block_stride).
    // For blocksparse attention: skip computation on blocks that are not
    // attended
    if constexpr (IS_BLOCK_SPARSE) {
      int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
      if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
          !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
        continue;
      }
    }
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
                                                           start_token_idx));

    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                           kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        } else {
          V_quant_vec v_quant_vec =
              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
                                                                    v_scale);
        }
        if (block_idx == num_seq_blocks - 1) {
          // NOTE: When v_vec contains the tokens that are out of the
          // context, we should explicitly zero out the values since they may
          // contain NaNs.
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE: A barrier is required because the shared memory space for
  // logits is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }
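
  // Illustrative sketch of the tree reduction above (descriptive comment only;
  // NUM_WARPS == 4 is an assumed example): the loop runs with i = 4 and then
  // i = 2. In the first pass warps 2 and 3 spill their accumulators to shared
  // memory and warps 0 and 1 add them in; in the second pass warp 1 spills and
  // warp 0 adds, so warp 0 ends up holding the sum over all four warps for
  // every output row.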
  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr =
        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}

// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads,
                                           // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads,
                                           // head_size, block_size]
    const int num_kv_heads,                // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float k_scale, const float v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_DTYPE, IS_BLOCK_SPARSE>(
      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
      v_cache, num_kv_heads, scale, block_tables, seq_lens,
      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
      blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                     // max_num_partitions, head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float k_scale, const float v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
      kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
}

// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int seq_len = seq_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr =
        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr =
        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits +
                                seq_idx * num_heads * max_num_partitions +
                                head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums =
      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums +
                              seq_idx * num_heads * max_num_partitions +
                              head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr =
      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr =
      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
             inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}
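
// Illustrative sketch of the merge performed by the reduce kernel above
// (notation is assumed here, not defined elsewhere in this file): partition p
// produces a locally normalized partial output o_p, its max logit m_p, and
// its normalizer s_p = sum_i exp(l_i - m_p). With m = max_p m_p, the kernel
// rescales each normalizer to s_p' = s_p * exp(m_p - m) and combines
//   out = sum_p o_p * s_p' / sum_p s_p',
// which matches the softmax-weighted sum a single unpartitioned pass would
// have produced.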
}  // namespace aphrodite

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                               \
  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(               \
      ((void*)aphrodite::paged_attention_v1_kernel<                        \
          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE,        \
          IS_BLOCK_SPARSE>),                                               \
      shared_mem_size);                                                    \
  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,  \
                                       NUM_THREADS, KV_DTYPE,              \
                                       IS_BLOCK_SPARSE>                    \
      <<<grid, block, shared_mem_size, stream>>>(                          \
          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,              \
          num_kv_heads, scale, block_tables_ptr, seq_lens_ptr,             \
          max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,              \
          kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,      \
          blocksparse_local_blocks, blocksparse_vert_stride,               \
          blocksparse_block_size, blocksparse_head_sliding_step);

// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
          aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
          int NUM_THREADS = 128>
void paged_attention_v1_launcher(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
    float v_scale, const int tp_rank, const int blocksparse_local_blocks,
    const int blocksparse_vert_stride, const int blocksparse_block_size,
    const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_seq_len =
      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_seq_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in
  // aphrodite.worker.worker._check_if_can_support_max_seq_len Keep that
  // in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);
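
  // Illustrative sizing sketch (assumed example values): with NUM_THREADS ==
  // 128 (NUM_WARPS == 4), head_size == 128, BLOCK_SIZE == 16 and
  // max_seq_len == 4096:
  //   logits_size  = 4096 * 4 bytes          = 16 KiB
  //   outputs_size = (4 / 2) * 128 * 4 bytes = 1 KiB
  // so shared_mem_size is dominated by the logits buffer, which grows with the
  // sequence length; the partitioned v2 path below instead caps this buffer at
  // PARTITION_SIZE floats.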
  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE: To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 120:
      LAUNCH_PAGED_ATTENTION_V1(120);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 192:
      LAUNCH_PAGED_ATTENTION_V1(192);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)  \
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,              \
                              IS_BLOCK_SPARSE>(                              \
      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
      seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank,        \
      blocksparse_local_blocks, blocksparse_vert_stride,                     \
      blocksparse_block_size, blocksparse_head_sliding_step);

#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE)   \
  switch (is_block_sparse) {                                                 \
    case true:                                                               \
      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);       \
      break;                                                                 \
    case false:                                                              \
      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);      \
      break;                                                                 \
  }

// NOTE: To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)                    \
  switch (block_size) {                                                      \
    case 8:                                                                  \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);                    \
      break;                                                                 \
    case 16:                                                                 \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);                   \
      break;                                                                 \
    case 32:                                                                 \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);                   \
      break;                                                                 \
    default:                                                                 \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);            \
      break;                                                                 \
  }

void paged_attention_v1(
    torch::Tensor& out,    // [num_seqs, num_heads, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);

  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                             CALL_V1_LAUNCHER_BLOCK_SIZE)
}
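
// Illustrative expansion of the dispatch above (a sketch; the exact type
// mapping lives in DISPATCH_BY_KV_CACHE_DTYPE, not in this file): for a
// half-precision query with an "auto" KV cache, the macro is expected to
// select a (T, CACHE_T, KV_DTYPE) triple such as
// (uint16_t, uint16_t, Fp8KVCacheDataType::kAuto); CALL_V1_LAUNCHER_BLOCK_SIZE
// then switches on block_size (8/16/32), CALL_V1_LAUNCHER_SPARSITY switches on
// is_block_sparse, and CALL_V1_LAUNCHER finally instantiates
// paged_attention_v1_launcher with the arguments captured from this function's
// scope.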
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                 \
  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,    \
                                       NUM_THREADS, KV_DTYPE,                \
                                       IS_BLOCK_SPARSE, PARTITION_SIZE>      \
      <<<grid, block, shared_mem_size, stream>>>(                            \
          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr,              \
          key_cache_ptr, value_cache_ptr, num_kv_heads, scale,               \
          block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,            \
          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,       \
          k_scale, v_scale, tp_rank, blocksparse_local_blocks,               \
          blocksparse_vert_stride, blocksparse_block_size,                   \
          blocksparse_head_sliding_step);                                    \
  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,     \
                                              PARTITION_SIZE>                \
      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(              \
          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr,  \
          max_num_partitions);

template <typename T, typename CACHE_T, int BLOCK_SIZE,
          aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
    float v_scale, const int tp_rank, const int blocksparse_local_blocks,
    const int blocksparse_vert_stride, const int blocksparse_block_size,
    const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
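
  // Illustrative sizing sketch (assumed example values): with max_seq_len ==
  // 4096 and PARTITION_SIZE == 512, max_num_partitions == 8, so the main
  // kernel runs a (num_heads, num_seqs, 8) grid with a 2 KiB logits buffer per
  // block (512 * 4 bytes), while the reduce kernel needs only
  // 2 * 8 * 4 == 64 bytes of dynamic shared memory for the per-partition max
  // logits and rescaled exp sums.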
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE: To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 120:
      LAUNCH_PAGED_ATTENTION_V2(120);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 192:
      LAUNCH_PAGED_ATTENTION_V2(192);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)   \
  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,               \
                              IS_BLOCK_SPARSE>(                               \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
      num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
      k_scale, v_scale, tp_rank, blocksparse_local_blocks,                    \
      blocksparse_vert_stride, blocksparse_block_size,                        \
      blocksparse_head_sliding_step);

#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE)    \
  switch (is_block_sparse) {                                                  \
    case true:                                                                \
      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);        \
      break;                                                                  \
    case false:                                                               \
      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);       \
      break;                                                                  \
  }

// NOTE: To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)                     \
  switch (block_size) {                                                       \
    case 8:                                                                   \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);                     \
      break;                                                                  \
    case 16:                                                                  \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);                    \
      break;                                                                  \
    case 32:                                                                  \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);                    \
      break;                                                                  \
    default:                                                                  \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);             \
      break;                                                                  \
  }

void paged_attention_v2(
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor&
        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                             CALL_V2_LAUNCHER_BLOCK_SIZE)
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP