attention.cpp

#include "cpu_types.hpp"

namespace {
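// Maps each scalar type to the vec_op vector types used by the kernels below:
// separate load/compute types for Q and K, an FP32 accumulator for the Q*K dot
// products, and a load type for V. The BF16 specialization uses native BF16
// compute vectors when AVX512-BF16 is available and falls back to FP32
// compute vectors otherwise.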
template <typename scalar_t> struct KernelVecType {
  using q_load_vec_type = void;
  using q_vec_type = void;
  using k_load_vec_type = void;
  using k_vec_type = void;
  using qk_acc_vec_type = void;
  using v_load_vec_type = void;
};

template <> struct KernelVecType<float> {
  using q_load_vec_type = vec_op::FP32Vec4;
  using q_vec_type = vec_op::FP32Vec16;
  using k_load_vec_type = vec_op::FP32Vec16;
  using k_vec_type = vec_op::FP32Vec16;
  using qk_acc_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::FP32Vec16;
};

#ifdef __AVX512BF16__
template <> struct KernelVecType<c10::BFloat16> {
  using q_load_vec_type = vec_op::BF16Vec8;
  using q_vec_type = vec_op::BF16Vec32;
  using k_load_vec_type = vec_op::BF16Vec32;
  using k_vec_type = vec_op::BF16Vec32;
  using qk_acc_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::BF16Vec16;
};
#else
template <> struct KernelVecType<c10::BFloat16> {
  using q_load_vec_type = vec_op::BF16Vec8;
  using q_vec_type = vec_op::FP32Vec16;
  using k_load_vec_type = vec_op::BF16Vec16;
  using k_vec_type = vec_op::FP32Vec16;
  using qk_acc_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
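// In-place, numerically stable softmax over the first `size` entries of
// `data`; the remaining entries up to `capacity` are zeroed so that padded
// block tails contribute nothing to the value reduction. Returns the {max,
// exp-sum} pair, which paged attention v2 later uses to combine partitions.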
template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
                                           const int capacity) {
  T max = data[0];
  for (int i = 1; i < size; ++i) {
    max = max >= data[i] ? max : data[i];
  }

  T sum = 0;
  for (int i = 0; i < size; ++i) {
    data[i] = std::exp(data[i] - max);
    sum += data[i];
  }

  int i = 0;
  for (; i < size; ++i) {
    data[i] /= sum;
  }

  for (; i < capacity; ++i) {
    data[i] = 0;
  }

  return {max, sum};
}
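// Same as reduceSoftmax, but first adds the ALiBi bias
// alibi_slope * (token_position - context_len + 1) to each logit, where the
// token position is start_index + i within the current partition.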
template <typename T>
FORCE_INLINE std::pair<T, T>
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
                   const float alibi_slope, const int start_index,
                   const int context_len) {
  data[0] += alibi_slope * (start_index - context_len + 1);
  T max = data[0];
  for (int i = 1; i < size; ++i) {
    T qk = data[i] + alibi_slope * (start_index + i - context_len + 1);
    data[i] = qk;
    max = max >= qk ? max : qk;
  }

  T sum = 0;
  for (int i = 0; i < size; ++i) {
    data[i] = std::exp(data[i] - max);
    sum += data[i];
  }

  int i = 0;
  for (; i < size; ++i) {
    data[i] /= sum;
  }

  for (; i < capacity; ++i) {
    data[i] = 0;
  }

  return {max, sum};
}
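// Combines per-partition softmax statistics (paged attention v2): given the
// per-partition maxima and exp-sums, computes the global max, rescales each
// partial sum by exp(partition_max - global_max), and normalizes so that
// sum_data[i] becomes the weight of partition i in the final output.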
template <typename T>
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data,
                                        const int size) {
  T max = max_data[0];
  for (int i = 1; i < size; ++i) {
    max = max >= max_data[i] ? max : max_data[i];
  }

  T rescaled_sum = 0;
  for (int i = 0; i < size; ++i) {
    T rescale_factor = std::exp(max_data[i] - max);
    rescaled_sum += rescale_factor * sum_data[i];
    sum_data[i] *= rescale_factor;
  }
  for (int i = 0; i < size; ++i) {
    sum_data[i] /= rescaled_sum + 1e-8;
  }
}
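// Computes scaled Q*K^T logits for one key block. The key cache stores each
// block as [HEAD_SIZE/x, BLOCK_SIZE, x], so x head elements of TOKEN_PER_GROUP
// consecutive tokens are contiguous and can be consumed by a single vector
// FMA. Tokens are processed in groups of TOKEN_PER_GROUP, with a fully
// unrolled fast path when the block is full (token_num == BLOCK_SIZE) and a
// partially unrolled loop otherwise; the per-group accumulators are reduced
// to scalar logits at the end.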
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int x>
struct reduceQKBlockKernel {
  using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
  using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
  using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
  using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
  using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;

  constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
  constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
  constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;

  static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
  static_assert(k_load_vec_type::get_elem_num() % x == 0);
  static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);

  FORCE_INLINE static void call(const scalar_t *__restrict__ q,
                                const scalar_t *__restrict__ k_block,
                                float *__restrict__ logits, float scale,
                                const int token_num) {
    const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;

    qk_acc_vec_type group_accums[MAX_GROUP_NUM];
    if (token_num == BLOCK_SIZE) {
      for (int q_offset = 0; q_offset < HEAD_SIZE;
           q_offset += x, k_block += x * BLOCK_SIZE) {
        q_load_vec_type q_load_group_vec(q + q_offset);
        q_vec_type q_group_vec(q_load_group_vec);

        vec_op::unroll_loop<int, MAX_GROUP_NUM>(
            [k_block, &q_group_vec, &group_accums](int token_group_idx) {
              k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
                                                             TOKEN_PER_GROUP);
              k_vec_type k_group_vec(k_load_group_vec);
              vec_op::fma(group_accums[token_group_idx], q_group_vec,
                          k_group_vec);
              vec_op::prefetch(k_block + x * BLOCK_SIZE +
                               token_group_idx * x * TOKEN_PER_GROUP);
            });
      }
    } else {
      for (int q_offset = 0; q_offset < HEAD_SIZE;
           q_offset += x, k_block += x * BLOCK_SIZE) {
        q_load_vec_type q_load_group_vec(q + q_offset);
        q_vec_type q_group_vec(q_load_group_vec);
        for (int token_group_start = 0; token_group_start < group_num;
             token_group_start += UNROLL_GROUP_NUM) {
          vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
              [token_group_start, k_block, &q_group_vec,
               &group_accums](int token_group_idx) {
                token_group_idx += token_group_start;
                k_load_vec_type k_load_group_vec(
                    k_block + token_group_idx * x * TOKEN_PER_GROUP);
                k_vec_type k_group_vec(k_load_group_vec);
                vec_op::fma(group_accums[token_group_idx], q_group_vec,
                            k_group_vec);
                vec_op::prefetch(k_block + x * BLOCK_SIZE +
                                 token_group_idx * x * TOKEN_PER_GROUP);
              });
        }
      }
    }

    for (int token_group_idx = 0; token_group_idx < group_num;
         ++token_group_idx) {
      vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
          [&group_accums, logits, scale, token_group_idx](int token_idx) {
            float dot_v =
                group_accums[token_group_idx]
                    .template reduce_sub_sum<qk_acc_vec_type::get_elem_num() /
                                             TOKEN_PER_GROUP>(token_idx);
            logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
                dot_v * scale;
          });
    }
  }
};
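// Accumulates prob * V for one value block: `prob` holds BLOCK_SIZE softmax
// weights, `v_block` holds BLOCK_SIZE values for each of HEAD_PARTITION_SIZE
// head elements, and `acc` collects one FP32 accumulator vector per head
// element.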
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
          int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
                                   acc_t &&acc) {
  using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
  constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
  static_assert(BLOCK_SIZE == ELEM_NUM);
  vec_op::FP32Vec16 prob_vec(prob);

  vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
    v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
    vec_op::FP32Vec16 fp32_v_vec(v_vec);
    acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
  });
}
}; // namespace
// Paged attention v1
namespace {
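// Single-pass paged attention: each (sequence, head) pair is handled by one
// OpenMP work item that (1) computes Q*K^T logits block by block into a
// per-thread scratch buffer, (2) applies softmax (optionally with ALiBi bias)
// over the whole context, and (3) accumulates the probability-weighted values
// into `out`, 16 head elements at a time.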
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl {
  static void
  call(scalar_t *__restrict__ out,           // [num_seqs, num_heads, head_size]
       const scalar_t *__restrict__ q,       // [num_seqs, num_heads, head_size]
       const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
                                             // head_size/x, block_size, x]
       const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
                                             // head_size, block_size]
       const int num_kv_heads, const float scale,
       const int
           *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
       const int *__restrict__ context_lens, // [num_seqs]
       const int max_num_blocks_per_seq,
       const float *__restrict__ alibi_slopes, // [num_heads]
       const int q_stride, const int kv_block_stride, const int kv_head_stride,
       const int num_seqs, const int num_heads) {
    constexpr int x = 16 / sizeof(scalar_t);
    const int num_queries_per_kv = num_heads / num_kv_heads;

    static_assert(BLOCK_SIZE == 16);

    int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
    int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
    TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0);

    const int parallel_work_item_num = omp_get_max_threads();

    size_t logits_bytes =
        parallel_work_item_num * max_context_len_padded * sizeof(float);
    float *logits = (float *)std::aligned_alloc(
        64, logits_bytes); // Cacheline alignment for each context token.
                           // [parallel_work_item_num, max_context_len_padded]

#pragma omp parallel for collapse(2) schedule(dynamic, 1)
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
        int context_len = context_lens[seq_idx];
        const int *seq_block_table =
            block_tables + max_num_blocks_per_seq * seq_idx;
        const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
        const int64_t kv_head_idx = head_idx / num_queries_per_kv;
        const scalar_t *__restrict__ q_vec_ptr =
            q + seq_idx * q_stride + head_idx * HEAD_SIZE;
        const int last_block_token_num =
            context_len - (block_num - 1) * BLOCK_SIZE;
        float *__restrict__ thread_block_logits =
            logits + omp_get_thread_num() * max_context_len_padded;

        // Compute logits
        for (int block_idx = 0; block_idx < block_num; ++block_idx) {
          const int64_t physical_block_idx = seq_block_table[block_idx];
          const scalar_t *__restrict__ k_block_cache_ptr =
              k_cache + physical_block_idx * kv_block_stride +
              kv_head_idx * kv_head_stride;
          float *__restrict__ head_block_logits =
              thread_block_logits + block_idx * BLOCK_SIZE;

          reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
              q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
              block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
        }

        // Compute softmax
        if (alibi_slopes) {
          reduceSoftmaxAlibi(thread_block_logits, context_len,
                             block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
                             context_len);
        } else {
          reduceSoftmax(thread_block_logits, context_len,
                        block_num * BLOCK_SIZE);
        }

        // Compute value
        constexpr int head_elem_num_per_partition = 16;
        constexpr int head_partition_num =
            HEAD_SIZE / head_elem_num_per_partition;
        for (int head_part_idx = 0; head_part_idx < head_partition_num;
             ++head_part_idx) {
          vec_op::FP32Vec16 accums[head_elem_num_per_partition];
          scalar_t *__restrict__ out_ptr =
              out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
              head_part_idx * head_elem_num_per_partition;
          for (int block_idx = 0; block_idx < block_num; ++block_idx) {
            const int64_t physical_block_idx = seq_block_table[block_idx];
            const float *__restrict__ prob_vec_ptr =
                thread_block_logits + block_idx * BLOCK_SIZE;
            const scalar_t *__restrict__ v_block_cache_ptr =
                v_cache + physical_block_idx * kv_block_stride +
                kv_head_idx * kv_head_stride +
                BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
            reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
                             head_elem_num_per_partition>(
                prob_vec_ptr, v_block_cache_ptr, accums);

            if (block_idx != block_num - 1) {
              const int64_t next_physical_block_idx =
                  seq_block_table[block_idx + 1];
              const scalar_t *__restrict__ next_v_block_cache_ptr =
                  v_cache + next_physical_block_idx * kv_block_stride +
                  kv_head_idx * kv_head_stride +
                  BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
              vec_op::unroll_loop<int, head_elem_num_per_partition>(
                  [&](int head_elem_idx) {
                    if (head_elem_idx % 2 == 0) {
                      vec_op::prefetch(next_v_block_cache_ptr +
                                       BLOCK_SIZE * head_elem_idx);
                    }
                  });
            }
          }

          vec_op::unroll_loop<int, head_elem_num_per_partition>(
              [&](int head_elem_idx) {
                float value = accums[head_elem_idx].reduce_sum();
                vec_op::storeFP32(value, out_ptr + head_elem_idx);
              });
        }
      }
    }
    std::free(logits);
  }
};
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                   \
  paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call(                     \
      out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
      block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,              \
      alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs,   \
      num_heads);
template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher(
    torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
    torch::Tensor &value_cache, int num_kv_heads, float scale,
    torch::Tensor &block_tables, torch::Tensor &context_lens,
    int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  // NOTE: alibi_slopes is optional.
  const float *alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
          : nullptr;

  T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
  T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
  T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
  T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
  int *block_tables_ptr = block_tables.data_ptr<int>();
  int *context_lens_ptr = context_lens.data_ptr<int>();

  switch (head_size) {
  case 64:
    LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
    break;
  case 80:
    LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
    break;
  case 96:
    LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
    break;
  case 112:
    LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
    break;
  case 128:
    LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
    break;
  case 256:
    LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
    break;
  default:
    TORCH_CHECK(false, "Unsupported head size: ", head_size);
    break;
  }
}
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE)                                 \
  paged_attention_v1_impl_launcher<T, BLOCK_SIZE>(                             \
      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,   \
      context_lens, max_context_len, alibi_slopes);

#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T)                                  \
  switch (block_size) {                                                        \
  case 16:                                                                     \
    CALL_V1_KERNEL_LAUNCHER(T, 16);                                            \
    break;                                                                     \
  default:                                                                     \
    TORCH_CHECK(false, "Unsupported block size: ", block_size);                \
    break;                                                                     \
  }
} // namespace
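// Public entry point for paged attention v1. Dispatches on the query scalar
// type and the KV-cache block size; kv_scale must be 1.0 and only
// block_size == 16 is supported by this CPU backend.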
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
                        torch::Tensor &key_cache, torch::Tensor &value_cache,
                        int num_kv_heads, float scale,
                        torch::Tensor &block_tables,
                        torch::Tensor &context_lens, int block_size,
                        int max_context_len,
                        const c10::optional<torch::Tensor> &alibi_slopes,
                        const std::string &kv_cache_dtype, float kv_scale) {
  TORCH_CHECK(kv_scale == 1.0f);
  APHRODITE_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "paged_attention_v1_impl", [&] {
        CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
        CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
        CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
      });
}
// Paged attention v2
namespace {
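// Two-pass paged attention for long contexts. Pass 1 splits each context into
// PARTITION_SIZE-token partitions and runs the v1-style logits/softmax/value
// reduction independently per (sequence, partition, head), writing partial
// outputs to tmp_out together with each partition's max logit and exp-sum.
// Pass 2 rescales the per-partition exp-sums against the global max
// (reducePartitonSoftmax) and combines the partial outputs into `out` as a
// weighted sum. Sequences that fit in a single partition skip the reduction
// and write directly to `out`.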
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl {
  static void call(
      scalar_t *__restrict__ out,   // [num_seqs, num_heads, head_size]
      float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
      float
          *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
      scalar_t *__restrict__ tmp_out,       // [num_seqs, num_heads,
                                            // max_num_partitions, head_size]
      const scalar_t *__restrict__ q,       // [num_seqs, num_heads, head_size]
      const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
                                            // head_size/x, block_size, x]
      const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
                                            // head_size, block_size]
      const int num_kv_heads, const float scale,
      const int
          *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
      const int *__restrict__ context_lens, // [num_seqs]
      const int max_num_blocks_per_seq,
      const float *__restrict__ alibi_slopes, // [num_heads]
      const int q_stride, const int kv_block_stride, const int kv_head_stride,
      const int num_seqs, const int num_heads, const int max_num_partitions) {
    constexpr int x = 16 / sizeof(scalar_t);
    const int num_queries_per_kv = num_heads / num_kv_heads;

    static_assert(BLOCK_SIZE == 16);
    static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
    static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);

#pragma omp parallel for collapse(3) schedule(static, 1)
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
      for (int partition_idx = 0; partition_idx < max_num_partitions;
           ++partition_idx) {
        for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
          const int context_len = context_lens[seq_idx];
          const int start_token_idx = partition_idx * PARTITION_SIZE;

          if (start_token_idx >= context_len)
            continue;

          const int partition_num =
              (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
          const bool no_reduce = (partition_num == 1);
          const int context_token_num =
              (std::min(context_len, start_token_idx + PARTITION_SIZE) -
               start_token_idx);
          const int block_num =
              (context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
          const int last_block_token_num =
              context_token_num - (block_num - 1) * BLOCK_SIZE;
          const int *seq_block_table = block_tables +
                                       max_num_blocks_per_seq * seq_idx +
                                       start_token_idx / BLOCK_SIZE;
          const int64_t kv_head_idx = head_idx / num_queries_per_kv;
          const scalar_t *__restrict__ q_vec_ptr =
              q + seq_idx * q_stride + head_idx * HEAD_SIZE;

          float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};

          // Compute logits
          for (int block_idx = 0; block_idx < block_num; ++block_idx) {
            const int64_t physical_block_idx = seq_block_table[block_idx];
            const scalar_t *__restrict__ k_block_cache_ptr =
                k_cache + physical_block_idx * kv_block_stride +
                kv_head_idx * kv_head_stride;
            float *__restrict__ head_block_logits =
                logits + block_idx * BLOCK_SIZE;

            reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
                q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
                block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
          }

          std::pair<float, float> max_and_sum;
          if (alibi_slopes) {
            max_and_sum = reduceSoftmaxAlibi(
                logits, context_token_num, block_num * BLOCK_SIZE,
                alibi_slopes[head_idx], start_token_idx, context_len);
          } else {
            max_and_sum = reduceSoftmax(logits, context_token_num,
                                        block_num * BLOCK_SIZE);
          }

          auto &&[max_logit, exp_sum] = max_and_sum;

          scalar_t *__restrict__ output_buffer = nullptr;
          if (!no_reduce) {
            auto idx = seq_idx * num_heads * max_num_partitions +
                       head_idx * max_num_partitions + partition_idx;
            max_logits[idx] = max_logit;
            exp_sums[idx] = exp_sum;
            output_buffer =
                tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                head_idx * max_num_partitions * HEAD_SIZE +
                partition_idx * HEAD_SIZE;
          } else {
            output_buffer =
                out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
          }

          // Compute value
          constexpr int head_elem_num_per_partition = 16;
          constexpr int head_partition_num =
              HEAD_SIZE / head_elem_num_per_partition;
          for (int head_part_idx = 0; head_part_idx < head_partition_num;
               ++head_part_idx) {
            vec_op::FP32Vec16 accums[head_elem_num_per_partition];
            scalar_t *__restrict__ out_ptr =
                output_buffer + head_part_idx * head_elem_num_per_partition;
            for (int block_idx = 0; block_idx < block_num; ++block_idx) {
              const int64_t physical_block_idx = seq_block_table[block_idx];
              const float *__restrict__ prob_vec_ptr =
                  logits + block_idx * BLOCK_SIZE;
              const scalar_t *__restrict__ v_block_cache_ptr =
                  v_cache + physical_block_idx * kv_block_stride +
                  kv_head_idx * kv_head_stride +
                  BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
              reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
                               head_elem_num_per_partition>(
                  prob_vec_ptr, v_block_cache_ptr, accums);

              if (block_idx != block_num - 1) {
                const int64_t next_physical_block_idx =
                    seq_block_table[block_idx + 1];
                const scalar_t *__restrict__ next_v_block_cache_ptr =
                    v_cache + next_physical_block_idx * kv_block_stride +
                    kv_head_idx * kv_head_stride +
                    BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
                vec_op::unroll_loop<int, head_elem_num_per_partition>(
                    [&](int head_elem_idx) {
                      if (head_elem_idx % 2 == 0) {
                        vec_op::prefetch(next_v_block_cache_ptr +
                                         BLOCK_SIZE * head_elem_idx);
                      }
                    });
              }
            }

            vec_op::unroll_loop<int, head_elem_num_per_partition>(
                [&](int head_elem_idx) {
                  float value = accums[head_elem_idx].reduce_sum();
                  vec_op::storeFP32(value, out_ptr + head_elem_idx);
                });
          }
        }
      }
    }

    // Rescale partition softmax and store the factors to exp_sums
#pragma omp parallel for collapse(2) schedule(static, 1)
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
        const int context_len = context_lens[seq_idx];
        const int partition_num =
            (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

        if (partition_num == 1)
          continue;

        reducePartitonSoftmax(
            max_logits + seq_idx * num_heads * max_num_partitions +
                head_idx * max_num_partitions,
            exp_sums + seq_idx * num_heads * max_num_partitions +
                head_idx * max_num_partitions,
            partition_num);
      }
    }
    // Reduce values
    using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
    static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
    constexpr int head_elem_num_per_group =
        16; // Note: not aligned to the cacheline size, since some HEAD_SIZE
            // values are not multiples of 64 bytes.
    static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
    constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
    const float *__restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1)
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
        for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
          const int context_len = context_lens[seq_idx];
          const int partition_num =
              (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

          if (partition_num == 1)
            continue;

          const float *__restrict__ seq_head_rescale_factors =
              rescale_factors + seq_idx * num_heads * max_num_partitions +
              head_idx * max_num_partitions;
          const scalar_t *__restrict__ seq_head_tmp_out =
              tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
              head_idx * max_num_partitions * HEAD_SIZE +
              group_idx * head_elem_num_per_group;
          scalar_t *__restrict__ seq_head_output =
              out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
              group_idx * head_elem_num_per_group;

          vec_op::FP32Vec16 acc;
          for (int i = 0; i < partition_num; ++i) {
            vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
            v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
            vec_op::FP32Vec16 fp32_value(value);
            acc = acc + fp32_value * rescale_factor;
          }
          v_load_vec_type cast_acc(acc);
          cast_acc.save(seq_head_output);
        }
      }
    }
  }
};
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                   \
  paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call(     \
      out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr,           \
      key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr,   \
      context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,    \
      kv_block_stride, kv_head_stride, num_seqs, num_heads,                    \
      max_num_partitions);
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher(
    torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
    torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
    torch::Tensor &value_cache, int num_kv_heads, float scale,
    torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size,
    int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);
  int max_num_partitions = exp_sums.size(-1);

  // NOTE: alibi_slopes is optional.
  const float *alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
          : nullptr;

  T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
  float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr());
  float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr());
  T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr());
  T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
  T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
  T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
  int *block_tables_ptr = block_tables.data_ptr<int>();
  int *context_lens_ptr = context_lens.data_ptr<int>();

  switch (head_size) {
  case 64:
    LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
    break;
  case 80:
    LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
    break;
  case 96:
    LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
    break;
  case 112:
    LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
    break;
  case 128:
    LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
    break;
  case 256:
    LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
    break;
  default:
    TORCH_CHECK(false, "Unsupported head size: ", head_size);
    break;
  }
}
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE)                                 \
  paged_attention_v2_impl_launcher<T, BLOCK_SIZE>(                             \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,       \
      num_kv_heads, scale, block_tables, context_lens, block_size,             \
      max_context_len, alibi_slopes);

#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T)                                  \
  switch (block_size) {                                                        \
  case 16:                                                                     \
    CALL_V2_KERNEL_LAUNCHER(T, 16);                                            \
    break;                                                                     \
  default:                                                                     \
    TORCH_CHECK(false, "Unsupported block size: ", block_size);                \
    break;                                                                     \
  }
} // namespace
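// Public entry point for paged attention v2. Same dispatch as v1 (query
// scalar type, block_size == 16, kv_scale == 1.0), but additionally takes the
// exp_sums, max_logits, and tmp_out scratch tensors used by the two-pass
// partition reduction, with a default PARTITION_SIZE of 512 tokens.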
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
                        torch::Tensor &max_logits, torch::Tensor &tmp_out,
                        torch::Tensor &query, torch::Tensor &key_cache,
                        torch::Tensor &value_cache, int num_kv_heads,
                        float scale, torch::Tensor &block_tables,
                        torch::Tensor &context_lens, int block_size,
                        int max_context_len,
                        const c10::optional<torch::Tensor> &alibi_slopes,
                        const std::string &kv_cache_dtype, float kv_scale) {
  TORCH_CHECK(kv_scale == 1.0f);
  APHRODITE_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "paged_attention_v2_impl", [&] {
        CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
        CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
        CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
      });
}