attention_kernels.cu
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "attention_dtypes.h"
#include "attention_utils.cuh"

#include <algorithm>

#define WARP_SIZE 32
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

namespace aphrodite {

// Utility function for attention softmax.
template<int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < NUM_WARPS) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Broadcast to other threads.
  return __shfl_sync(uint32_t(-1), sum, 0);
}
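
// Worked example for the reduction above (a sketch, not additional code): with
// WARP_SIZE == 32 the butterfly reduction XORs lane ids with masks 16, 8, 4, 2, 1,
// so after five steps every lane holds its warp's sum. With NUM_WARPS == 4, lanes
// 0..3 then reload the four per-warp partials from red_smem (which must hold at
// least NUM_WARPS floats) and repeat the shuffle with masks 2, 1; the final
// __shfl_sync broadcasts lane 0's block-wide sum to all lanes of the calling warp.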

// TODO: Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int context_len = context_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS.
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int kv_head_idx = head_mapping[head_idx];
  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread group
  // fetch or compute 16 bytes at a time.
  // For example, if the size of a thread group is 4 and the data type is half,
  // then the vector size is 16 / (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
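
  // Worked example of the sizing above (assumed template values, not a constraint):
  // for scalar_t == half, HEAD_SIZE == 128, BLOCK_SIZE == 16:
  //   THREAD_GROUP_SIZE    = MAX(32 / 16, 1)      = 2
  //   VEC_SIZE             = MAX(16 / (2 * 2), 1) = 4 halfs (16 bytes per group fetch)
  //   NUM_ELEMS_PER_THREAD = 128 / 2              = 64
  //   NUM_VECS_PER_THREAD  = 64 / 4               = 16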

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in the group
  // has the 0th, 4th, 8th, ... vectors of the query, the second thread has the
  // 1st, 5th, 9th, ... vectors, and so on.
  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads(); // TODO: possible speedup if this is replaced with a memory fence right before q_vecs is used.

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE: We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(scalar_t);
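
  // Key cache layout note (derived from the indexing in the loop below): with
  // scalar_t == half, x == 16 / 2 == 8, so each (block, kv_head) slab of k_cache
  // is laid out as [head_size / x, block_size, x]. The element for head dimension
  // d of the token at block offset t therefore lives at
  //   (d / x) * BLOCK_SIZE * x + t * x + (d % x)
  // relative to the slab start, which matches
  //   offset1 * BLOCK_SIZE * x + physical_block_offset * x + offset2.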
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
    const int physical_block_number = block_table[block_idx];

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in the group
    // has the 0th, 4th, 8th, ... vectors of the key, the second thread has the
    // 1st, 5th, 9th, ... vectors, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                        + kv_head_idx * kv_head_stride
                                        + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;
        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE: It is required to zero out the masked logits.
        const bool mask = token_idx >= context_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO: Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();
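
  // At this point logits[i] holds the numerically stable softmax weight
  //   exp(qk_i - qk_max) / (sum_j exp(qk_j - qk_max) + 1e-6),
  // where subtracting qk_max leaves the result unchanged (the exp(-qk_max)
  // factor cancels between numerator and denominator) and the 1e-6 term only
  // guards against division by zero.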

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions
                                       + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                   + head_idx * max_num_partitions
                                   + partition_idx;
    *exp_sums_ptr = exp_sum;
  }
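
  // The per-partition (qk_max, exp_sum) pair is what paged_attention_v2_reduce_kernel
  // needs to merge partitions later: each partition's exp_sum is rescaled by
  // exp(qk_max_partition - qk_max_global) before the partial outputs are combined,
  // i.e. the standard log-sum-exp trick applied across partitions.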

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
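
  // Worked example (same assumed values as above: half, HEAD_SIZE == 128,
  // BLOCK_SIZE == 16):
  //   V_VEC_SIZE          = MIN(16 / 2, 16) = 8
  //   NUM_V_VECS_PER_ROW  = 16 / 8          = 2  (threads per value-cache row)
  //   NUM_ROWS_PER_ITER   = 32 / 2          = 16 (head dimensions covered per warp iteration)
  //   NUM_ROWS_PER_THREAD = ceil(128 / 16)  = 8  (accumulators per thread)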

  // NOTE: We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
    const int physical_block_number = block_table[block_idx];
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                    + kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        if (block_idx == num_context_blocks - 1) {
          // NOTE: When v_vec contains the tokens that are out of the context,
          // we should explicitly zero out the values since they may contain NaNs.
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE: A barrier is required because the shared memory space for logits
  // is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }
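
  // Worked example of the tree reduction above with NUM_WARPS == 4:
  //   iteration i == 4: warps 2 and 3 write their rows to out_smem,
  //                     warps 0 and 1 add those rows into their accumulators;
  //   iteration i == 2: warp 1 writes, warp 0 accumulates.
  // After the loop, warp 0 holds the fully reduced output rows.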

  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                            + head_idx * max_num_partitions * HEAD_SIZE
                            + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}

// Grid: (num_heads, num_seqs, 1).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS>
__global__ void paged_attention_v1_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
    /* exp_sums */ nullptr, /* max_logits */ nullptr,
    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
    q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                          + head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                           + head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
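
  // Merging partitions uses the log-sum-exp identity: given per-partition pairs
  // (m_j, s_j) with s_j = sum_i exp(qk_i - m_j), the global denominator is
  //   sum_j s_j * exp(m_j - m),  where m = max_j m_j,
  // which is exactly what global_exp_sum accumulates above.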
  __syncthreads();

  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                        + head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

} // namespace aphrodite

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                        \
  cudaFuncSetAttribute(                                                             \
    aphrodite::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,    \
    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                  \
  aphrodite::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>       \
  <<<grid, block, shared_mem_size, stream>>>(                                       \
    out_ptr,                                                                        \
    query_ptr,                                                                      \
    key_cache_ptr,                                                                  \
    value_cache_ptr,                                                                \
    head_mapping_ptr,                                                               \
    scale,                                                                          \
    block_tables_ptr,                                                               \
    context_lens_ptr,                                                               \
    max_num_blocks_per_seq,                                                         \
    alibi_slopes_ptr,                                                               \
    q_stride,                                                                       \
    kv_block_stride,                                                                \
    kv_head_stride);

// TODO: Tune NUM_THREADS.
template<
  typename T,
  int BLOCK_SIZE,
  int NUM_THREADS = 128>
void paged_attention_v1_launcher(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len.
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);
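
  // Shared-memory sizing example (assumed values, not a requirement): with
  // max_context_len == 4096, BLOCK_SIZE == 16, NUM_THREADS == 128, head_size == 128:
  //   padded_max_context_len = 4096           -> logits_size  = 16 KB
  //   outputs_size           = (4 / 2) * 128 * 4 bytes = 1 KB
  //   shared_mem_size        = max(16 KB, 1 KB)        = 16 KB of dynamic shared memory.
  // For longer contexts this can exceed the default 48 KB limit, which is why the
  // launch macro raises cudaFuncAttributeMaxDynamicSharedMemorySize first.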

  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE: To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V1_LAUNCHER(T, BLOCK_SIZE)       \
  paged_attention_v1_launcher<T, BLOCK_SIZE>( \
    out,                                      \
    query,                                    \
    key_cache,                                \
    value_cache,                              \
    head_mapping,                             \
    scale,                                    \
    block_tables,                             \
    context_lens,                             \
    max_context_len,                          \
    alibi_slopes);

// NOTE: To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T)                            \
  switch (block_size) {                                           \
    case 8:                                                       \
      CALL_V1_LAUNCHER(T, 8);                                     \
      break;                                                      \
    case 16:                                                      \
      CALL_V1_LAUNCHER(T, 16);                                    \
      break;                                                      \
    case 32:                                                      \
      CALL_V1_LAUNCHER(T, 32);                                    \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

void paged_attention_v1(
  torch::Tensor& out,           // [num_seqs, num_heads, head_size]
  torch::Tensor& query,         // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,  // [num_heads]
  float scale,
  torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,  // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  if (query.dtype() == at::ScalarType::Float) {
    CALL_V1_LAUNCHER_BLOCK_SIZE(float);
  } else if (query.dtype() == at::ScalarType::Half) {
    CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}
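
// A minimal host-side usage sketch (illustrative only; tensor shapes follow the
// comments on paged_attention_v1 above, and num_blocks / block_tables /
// context_lens / max_context_len are placeholders the caller provides):
//
//   // num_seqs = 2, num_heads = 32, head_size = 128, block_size = 16, x = 8 for fp16.
//   auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
//   auto out         = torch::empty({2, 32, 128}, opts);
//   auto query       = torch::randn({2, 32, 128}, opts);
//   auto key_cache   = torch::randn({num_blocks, 32, 128 / 8, 16, 8}, opts);
//   auto value_cache = torch::randn({num_blocks, 32, 128, 16}, opts);
//   auto head_mapping = torch::arange(32, torch::dtype(torch::kInt).device(torch::kCUDA));
//   // block_tables: [2, max_num_blocks_per_seq] int, context_lens: [2] int.
//   paged_attention_v1(out, query, key_cache, value_cache, head_mapping,
//                      1.0f / std::sqrt(128.0f), block_tables, context_lens,
//                      /*block_size=*/16, max_context_len, /*alibi_slopes=*/c10::nullopt);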

#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                   \
  aphrodite::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>  \
  <<<grid, block, shared_mem_size, stream>>>(                                                  \
    exp_sums_ptr,                                                                              \
    max_logits_ptr,                                                                            \
    tmp_out_ptr,                                                                               \
    query_ptr,                                                                                 \
    key_cache_ptr,                                                                             \
    value_cache_ptr,                                                                           \
    head_mapping_ptr,                                                                          \
    scale,                                                                                     \
    block_tables_ptr,                                                                          \
    context_lens_ptr,                                                                          \
    max_num_blocks_per_seq,                                                                    \
    alibi_slopes_ptr,                                                                          \
    q_stride,                                                                                  \
    kv_block_stride,                                                                           \
    kv_head_stride);                                                                           \
  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>       \
  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                    \
    out_ptr,                                                                                   \
    exp_sums_ptr,                                                                              \
    max_logits_ptr,                                                                            \
    tmp_out_ptr,                                                                               \
    context_lens_ptr,                                                                          \
    max_num_partitions);

template<
  typename T,
  int BLOCK_SIZE,
  int NUM_THREADS = 128,
  int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
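
  // Sizing example (assumed values): with PARTITION_SIZE == 512 and
  // max_context_len == 4096, max_num_partitions == 8, so the main kernel runs on a
  // (num_heads, num_seqs, 8) grid with logits_size = 512 * 4 = 2 KB per block, and
  // the reduce kernel needs 2 * 8 * 4 = 64 bytes of dynamic shared memory for the
  // per-partition max logits and rescaled exp sums.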
  dim3 block(NUM_THREADS);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE: To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V2_LAUNCHER(T, BLOCK_SIZE)       \
  paged_attention_v2_launcher<T, BLOCK_SIZE>( \
    out,                                      \
    exp_sums,                                 \
    max_logits,                               \
    tmp_out,                                  \
    query,                                    \
    key_cache,                                \
    value_cache,                              \
    head_mapping,                             \
    scale,                                    \
    block_tables,                             \
    context_lens,                             \
    max_context_len,                          \
    alibi_slopes);

// NOTE: To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T)                            \
  switch (block_size) {                                           \
    case 8:                                                       \
      CALL_V2_LAUNCHER(T, 8);                                     \
      break;                                                      \
    case 16:                                                      \
      CALL_V2_LAUNCHER(T, 16);                                    \
      break;                                                      \
    case 32:                                                      \
      CALL_V2_LAUNCHER(T, 32);                                    \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

void paged_attention_v2(
  torch::Tensor& out,           // [num_seqs, num_heads, head_size]
  torch::Tensor& exp_sums,      // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& max_logits,    // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& tmp_out,       // [num_seqs, num_heads, max_num_partitions, head_size]
  torch::Tensor& query,         // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,  // [num_heads]
  float scale,
  torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,  // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  if (query.dtype() == at::ScalarType::Float) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
  } else if (query.dtype() == at::ScalarType::Half) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP