/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
#endif

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "attention_dtypes.h"
#include "attention_utils.cuh"

#include <algorithm>

#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace aphrodite {

// Utility function for attention softmax.
template<int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
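  // This is a butterfly (XOR-shuffle) reduction: with WARP_SIZE == 32 the loop
  // applies masks 16, 8, 4, 2, 1 in turn, so after five shuffle steps every lane
  // in the warp holds the warp-wide sum.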
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < NUM_WARPS) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
  return APHRODITE_SHFL_SYNC(sum, 0);
}
// TODO: Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int context_len = context_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int kv_head_idx = head_mapping[head_idx];
  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread group
  // fetch or compute 16 bytes at a time.
  // For example, if the size of a thread group is 4 and the data type is half,
  // then the vector size is 16 / (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
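
  // Worked example (assuming scalar_t == half, BLOCK_SIZE == 16, HEAD_SIZE == 128,
  // NUM_THREADS == 128, WARP_SIZE == 32):
  //   THREAD_GROUP_SIZE = 32 / 16 = 2, so two threads cooperate on each token.
  //   VEC_SIZE = 16 / (2 * 2) = 4 halves, i.e. 8 bytes per vector per thread.
  //   NUM_ELEMS_PER_THREAD = 128 / 2 = 64 and NUM_VECS_PER_THREAD = 64 / 4 = 16.
  //   NUM_THREAD_GROUPS = 128 / 2 = 64 and NUM_TOKENS_PER_THREAD_GROUP = 1.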

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in the group
  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
  // th vectors of the query, and so on.
  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads(); // TODO: possible speedup if this is replaced with a memory fence right before we use q_vecs

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE: We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(scalar_t);
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
    const int physical_block_number = block_table[block_idx];

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in the group
    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
    // vectors of the key, and so on.
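    // The key cache for this head is laid out as [head_size/x, block_size, x]
    // (with x == 16 / sizeof(scalar_t)), i.e. groups of x consecutive head
    // elements are stored contiguously per token. The offset1/offset2 indices
    // computed below select the group (offset1) and the position within the
    // group (offset2) for each vector this thread is responsible for.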
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                        + kv_head_idx * kv_head_stride
                                        + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;
        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE: It is required to zero out the masked logits.
        const bool mask = token_idx >= context_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO: Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
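
  // The softmax below is computed in a numerically stable way over the logits
  // held in shared memory:
  //   logits[i] = exp(qk_i - qk_max) / (sum_j exp(qk_j - qk_max) + 1e-6),
  // where subtracting qk_max before exponentiating avoids overflow and the small
  // epsilon guards against division by zero.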

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions
                                       + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                   + head_idx * max_num_partitions
                                   + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

  // NOTE: We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
    const int physical_block_number = block_table[block_idx];
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                    + kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        if (block_idx == num_context_blocks - 1) {
          // NOTE: When v_vec contains the tokens that are out of the context,
          // we should explicitly zero out the values since they may contain NaNs.
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE: A barrier is required because the shared memory space for logits
  // is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
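  // Worked example of the tree reduction below (assuming NUM_WARPS == 4): on the
  // first iteration (i == 4, mid == 2) warps 2 and 3 write their partial rows to
  // out_smem while warps 0 and 1 accumulate them; on the second iteration (i == 2,
  // mid == 1) warp 1 writes and warp 0 accumulates, leaving warp 0 with the final
  // per-row sums that it writes to global memory.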
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                            + head_idx * max_num_partitions * HEAD_SIZE
                            + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}

// Grid: (num_heads, num_seqs, 1).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS>
__global__ void paged_attention_v1_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
    /* exp_sums */ nullptr, /* max_logits */ nullptr,
    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
    q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                          + head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                           + head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();

  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
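
  // The per-partition softmaxes are merged below with the usual log-sum-exp trick:
  // each partition p stored exp_sum_p relative to its own max_p, so it is rescaled
  // by exp(max_p - max_logit) before summing, and the final output is
  //   out = sum_p tmp_out_p * (exp_sum_p * exp(max_p - max_logit)) / global_exp_sum.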

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                        + head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

} // namespace aphrodite

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                               \
  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                               \
    ((void*)aphrodite::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>),  \
    shared_mem_size);                                                                      \
  aphrodite::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>              \
  <<<grid, block, shared_mem_size, stream>>>(                                              \
    out_ptr,                                                                               \
    query_ptr,                                                                             \
    key_cache_ptr,                                                                         \
    value_cache_ptr,                                                                       \
    head_mapping_ptr,                                                                      \
    scale,                                                                                 \
    block_tables_ptr,                                                                      \
    context_lens_ptr,                                                                      \
    max_num_blocks_per_seq,                                                                \
    alibi_slopes_ptr,                                                                      \
    q_stride,                                                                              \
    kv_block_stride,                                                                       \
    kv_head_stride);

// TODO: Tune NUM_THREADS.
template<
  typename T,
  int BLOCK_SIZE,
  int NUM_THREADS = 128>
void paged_attention_v1_launcher(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);
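  // Worked example (assuming max_context_len == 4096, BLOCK_SIZE == 16,
  // NUM_THREADS == 128, WARP_SIZE == 32, head_size == 128):
  //   padded_max_context_len = 4096, so logits_size = 4096 * 4 B = 16 KiB;
  //   NUM_WARPS = 4, so outputs_size = 2 * 128 * 4 B = 1 KiB;
  //   shared_mem_size = max(16 KiB, 1 KiB) = 16 KiB of dynamic shared memory.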

  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE: To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V1_LAUNCHER(T, BLOCK_SIZE)       \
  paged_attention_v1_launcher<T, BLOCK_SIZE>( \
    out,                                      \
    query,                                    \
    key_cache,                                \
    value_cache,                              \
    head_mapping,                             \
    scale,                                    \
    block_tables,                             \
    context_lens,                             \
    max_context_len,                          \
    alibi_slopes);

// NOTE: To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T)                            \
  switch (block_size) {                                           \
    case 8:                                                       \
      CALL_V1_LAUNCHER(T, 8);                                     \
      break;                                                      \
    case 16:                                                      \
      CALL_V1_LAUNCHER(T, 16);                                    \
      break;                                                      \
    case 32:                                                      \
      CALL_V1_LAUNCHER(T, 32);                                    \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

void paged_attention_v1(
  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,    // [num_heads]
  float scale,
  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  if (query.dtype() == at::ScalarType::Float) {
    CALL_V1_LAUNCHER_BLOCK_SIZE(float);
  } else if (query.dtype() == at::ScalarType::Half) {
    CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}

#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
  aphrodite::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
  <<<grid, block, shared_mem_size, stream>>>(                                                 \
    exp_sums_ptr,                                                                             \
    max_logits_ptr,                                                                           \
    tmp_out_ptr,                                                                              \
    query_ptr,                                                                                \
    key_cache_ptr,                                                                            \
    value_cache_ptr,                                                                          \
    head_mapping_ptr,                                                                         \
    scale,                                                                                    \
    block_tables_ptr,                                                                         \
    context_lens_ptr,                                                                         \
    max_num_blocks_per_seq,                                                                   \
    alibi_slopes_ptr,                                                                         \
    q_stride,                                                                                 \
    kv_block_stride,                                                                          \
    kv_head_stride);                                                                          \
  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>      \
  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
    out_ptr,                                                                                  \
    exp_sums_ptr,                                                                             \
    max_logits_ptr,                                                                           \
    tmp_out_ptr,                                                                              \
    context_lens_ptr,                                                                         \
    max_num_partitions);

template<
  typename T,
  int BLOCK_SIZE,
  int NUM_THREADS = 128,
  int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
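  // Worked example (assuming max_context_len == 4096, PARTITION_SIZE == 512,
  // NUM_THREADS == 128, head_size == 128):
  //   max_num_partitions = 4096 / 512 = 8, so each (head, seq) pair launches 8 blocks;
  //   logits_size = 512 * 4 B = 2 KiB and outputs_size = 2 * 128 * 4 B = 1 KiB,
  //   giving shared_mem_size = 2 KiB for the main kernel;
  //   reduce_shared_mem_size = 2 * 8 * 4 B = 64 B for the reduce kernel.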

  dim3 block(NUM_THREADS);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE: To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V2_LAUNCHER(T, BLOCK_SIZE)       \
  paged_attention_v2_launcher<T, BLOCK_SIZE>( \
    out,                                      \
    exp_sums,                                 \
    max_logits,                               \
    tmp_out,                                  \
    query,                                    \
    key_cache,                                \
    value_cache,                              \
    head_mapping,                             \
    scale,                                    \
    block_tables,                             \
    context_lens,                             \
    max_context_len,                          \
    alibi_slopes);

// NOTE: To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T)                            \
  switch (block_size) {                                           \
    case 8:                                                       \
      CALL_V2_LAUNCHER(T, 8);                                     \
      break;                                                      \
    case 16:                                                      \
      CALL_V2_LAUNCHER(T, 16);                                    \
      break;                                                      \
    case 32:                                                      \
      CALL_V2_LAUNCHER(T, 32);                                    \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

void paged_attention_v2(
  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,    // [num_heads]
  float scale,
  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  if (query.dtype() == at::ScalarType::Float) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
  } else if (query.dtype() == at::ScalarType::Half) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP