attention.cpp 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758
  #include "cpu_types.hpp"

  #include <cstdlib>
  #include <limits>
  #include <memory>
  2. namespace {
  3. template <typename scalar_t>
  4. struct KernelVecType {
  5. using q_load_vec_type = void;
  6. using q_vec_type = void;
  7. using k_load_vec_type = void;
  8. using k_vec_type = void;
  9. using qk_acc_vec_type = void;
  10. using v_load_vec_type = void;
  11. };
  12. template <>
  13. struct KernelVecType<float> {
  14. using q_load_vec_type = vec_op::FP32Vec4;
  15. using q_vec_type = vec_op::FP32Vec16;
  16. using k_load_vec_type = vec_op::FP32Vec16;
  17. using k_vec_type = vec_op::FP32Vec16;
  18. using qk_acc_vec_type = vec_op::FP32Vec16;
  19. using v_load_vec_type = vec_op::FP32Vec16;
  20. };
  21. #ifdef __AVX512BF16__
  22. template <>
  23. struct KernelVecType<c10::BFloat16> {
  24. using q_load_vec_type = vec_op::BF16Vec8;
  25. using q_vec_type = vec_op::BF16Vec32;
  26. using k_load_vec_type = vec_op::BF16Vec32;
  27. using k_vec_type = vec_op::BF16Vec32;
  28. using qk_acc_vec_type = vec_op::FP32Vec16;
  29. using v_load_vec_type = vec_op::BF16Vec16;
  30. };
  31. #else
  32. template <>
  33. struct KernelVecType<c10::BFloat16> {
  34. using q_load_vec_type = vec_op::BF16Vec8;
  35. using q_vec_type = vec_op::FP32Vec16;
  36. using k_load_vec_type = vec_op::BF16Vec16;
  37. using k_vec_type = vec_op::FP32Vec16;
  38. using qk_acc_vec_type = vec_op::FP32Vec16;
  39. using v_load_vec_type = vec_op::BF16Vec16;
  40. };
  41. #endif
  42. template <typename T>
  43. FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
  44. const int capacity) {
  45. T max = data[0];
  46. for (int i = 1; i < size; ++i) {
  47. max = max >= data[i] ? max : data[i];
  48. }
  49. T sum = 0;
  50. for (int i = 0; i < size; ++i) {
  51. data[i] = std::exp(data[i] - max);
  52. sum += data[i];
  53. }
  54. int i = 0;
  55. for (; i < size; ++i) {
  56. data[i] /= sum;
  57. }
  58. for (; i < capacity; ++i) {
  59. data[i] = 0;
  60. }
  61. return {max, sum};
  62. }
  63. template <typename T>
  64. FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
  65. const int capacity,
  66. const float alibi_slope,
  67. const int start_index,
  68. const int seq_len) {
  69. data[0] += alibi_slope * (start_index - seq_len + 1);
  70. T max = data[0];
  71. for (int i = 1; i < size; ++i) {
  72. T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
  73. data[i] = qk;
  74. max = max >= qk ? max : qk;
  75. }
  76. T sum = 0;
  77. for (int i = 0; i < size; ++i) {
  78. data[i] = std::exp(data[i] - max);
  79. sum += data[i];
  80. }
  81. int i = 0;
  82. for (; i < size; ++i) {
  83. data[i] /= sum;
  84. }
  85. for (; i < capacity; ++i) {
  86. data[i] = 0;
  87. }
  88. return {max, sum};
  89. }
  90. template <typename T>
  91. FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data,
  92. const int size) {
  93. T max = max_data[0];
  94. for (int i = 1; i < size; ++i) {
  95. max = max >= max_data[i] ? max : max_data[i];
  96. }
  97. T rescaled_sum = 0;
  98. for (int i = 0; i < size; ++i) {
  99. T rescale_factor = std::exp(max_data[i] - max);
  100. rescaled_sum += rescale_factor * sum_data[i];
  101. sum_data[i] *= rescale_factor;
  102. }
  103. for (int i = 0; i < size; ++i) {
  104. sum_data[i] /= rescaled_sum + 1e-8;
  105. }
  106. }
  // Computes scaled Q.K logits for the tokens of one paged KV-cache block:
  //   logits[t] = scale * dot(q[0:HEAD_SIZE], key of token t)
  //
  // Layout: within a block the keys are read as HEAD_SIZE/x chunks of
  // [BLOCK_SIZE tokens][x elements]. Each q_offset step advances k_block by
  // x * BLOCK_SIZE, consistent with the [head_size/x, block_size, x] cache
  // layout documented at the call sites.
  //
  // Tokens are processed in groups of TOKEN_PER_GROUP. One k_load_vec_type
  // load covers the x-element slices of TOKEN_PER_GROUP consecutive tokens,
  // and group_accums[g] accumulates their products. Each token's partial sums
  // occupy qk_acc_vec_type::get_elem_num() / TOKEN_PER_GROUP accumulator lanes,
  // which reduce_sub_sum folds per token. That lane grouping is assumed to
  // match vec_op::fma's output layout; confirm in cpu_types.hpp.
  template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int x>
  struct reduceQKBlockKernel {
    using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
    using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
    using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
    using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
    using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;

    // Tokens whose x-element key slices fit into one k_load_vec_type.
    constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
    // Groups per block; the literal 16 matches the only supported BLOCK_SIZE.
    constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
    // Unroll factor for the partial-block path.
    constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;

    static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
    static_assert(k_load_vec_type::get_elem_num() % x == 0);
    static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);

    // q:         HEAD_SIZE query elements for one head.
    // k_block:   start of this head's K data inside one cache block.
    // logits:    output. Up to group_num * TOKEN_PER_GROUP entries are
    //            written, which can exceed token_num. Entries past token_num
    //            come from unused block slots and are expected to be
    //            overwritten or zeroed by the caller's softmax padding.
    // token_num: valid tokens in this block (BLOCK_SIZE for a full block).
    FORCE_INLINE static void call(const scalar_t* __restrict__ q,
                                  const scalar_t* __restrict__ k_block,
                                  float* __restrict__ logits, float scale,
                                  const int token_num) {
      const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
      // Relies on qk_acc_vec_type's default constructor zero-initializing the
      // lanes (TODO confirm in cpu_types.hpp).
      qk_acc_vec_type group_accums[MAX_GROUP_NUM];
      if (token_num == BLOCK_SIZE) {
        // Full block: every group is live, so unroll over all of them.
        for (int q_offset = 0; q_offset < HEAD_SIZE;
             q_offset += x, k_block += x * BLOCK_SIZE) {
          q_load_vec_type q_load_group_vec(q + q_offset);
          q_vec_type q_group_vec(q_load_group_vec);
          vec_op::unroll_loop<int, MAX_GROUP_NUM>(
              [k_block, &q_group_vec, &group_accums](int token_group_idx) {
                k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
                                                               TOKEN_PER_GROUP);
                k_vec_type k_group_vec(k_load_group_vec);
                vec_op::fma(group_accums[token_group_idx], q_group_vec,
                            k_group_vec);
                // Prefetch the same token group in the next x-chunk of the head.
                vec_op::prefetch(k_block + x * BLOCK_SIZE +
                                 token_group_idx * x * TOKEN_PER_GROUP);
              });
        }
      } else {
        // Partial block: walk the live groups UNROLL_GROUP_NUM at a time. The
        // last step may also touch groups past group_num. Those stay inside
        // the block (MAX_GROUP_NUM is a multiple of UNROLL_GROUP_NUM), and
        // their results are never stored.
        for (int q_offset = 0; q_offset < HEAD_SIZE;
             q_offset += x, k_block += x * BLOCK_SIZE) {
          q_load_vec_type q_load_group_vec(q + q_offset);
          q_vec_type q_group_vec(q_load_group_vec);
          for (int token_group_start = 0; token_group_start < group_num;
               token_group_start += UNROLL_GROUP_NUM) {
            vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
                [token_group_start, k_block, &q_group_vec,
                 &group_accums](int token_group_idx) {
                  token_group_idx += token_group_start;
                  k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
                                                                 TOKEN_PER_GROUP);
                  k_vec_type k_group_vec(k_load_group_vec);
                  vec_op::fma(group_accums[token_group_idx], q_group_vec,
                              k_group_vec);
                  vec_op::prefetch(k_block + x * BLOCK_SIZE +
                                   token_group_idx * x * TOKEN_PER_GROUP);
                });
          }
        }
      }
      // Fold each token's lanes into a scalar dot product and apply the scale.
      for (int token_group_idx = 0; token_group_idx < group_num;
           ++token_group_idx) {
        vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
            [&group_accums, logits, scale, token_group_idx](int token_idx) {
              float dot_v =
                  group_accums[token_group_idx]
                      .template reduce_sub_sum<qk_acc_vec_type::get_elem_num() /
                                               TOKEN_PER_GROUP>(token_idx);
              logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
                  dot_v * scale;
            });
      }
    }
  };
  178. template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
  179. int HEAD_PARTITION_SIZE, typename acc_t>
  180. FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
  181. acc_t&& acc) {
  182. using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
  183. constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
  184. static_assert(BLOCK_SIZE == ELEM_NUM);
  185. vec_op::FP32Vec16 prob_vec(prob);
  186. vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
  187. v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
  188. vec_op::FP32Vec16 fp32_v_vec(v_vec);
  189. acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
  190. });
  191. }
  192. }; // namespace
  193. // Paged attention v1
  194. namespace {
  195. template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
  196. struct paged_attention_v1_impl {
  197. static void call(
  198. scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
  199. const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
  200. const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
  201. // head_size/x, block_size, x]
  202. const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
  203. // head_size, block_size]
  204. const int num_kv_heads, const float scale,
  205. const int* __restrict__ block_tables, // [num_seqs,
  206. // max_num_blocks_per_seq]
  207. const int* __restrict__ seq_lens, // [num_seqs]
  208. const int max_num_blocks_per_seq,
  209. const float* __restrict__ alibi_slopes, // [num_heads]
  210. const int q_stride, const int kv_block_stride, const int kv_head_stride,
  211. const int num_seqs, const int num_heads) {
  212. constexpr int x = 16 / sizeof(scalar_t);
  213. const int num_queries_per_kv = num_heads / num_kv_heads;
  214. static_assert(BLOCK_SIZE == 16);
  215. int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
  216. int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
  217. TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
  218. const int parallel_work_item_num = omp_get_max_threads();
  219. size_t logits_bytes =
  220. parallel_work_item_num * max_seq_len_padded * sizeof(float);
  221. float* logits = (float*)std::aligned_alloc(
  222. 64, logits_bytes); // Cacheline alignment for each context token.
  223. // [parallel_work_item_num, max_seq_len_padded]
  224. #pragma omp parallel for collapse(2) schedule(dynamic, 1)
  225. for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
  226. for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
  227. int seq_len = seq_lens[seq_idx];
  228. const int* seq_block_table =
  229. block_tables + max_num_blocks_per_seq * seq_idx;
  230. const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
  231. const int64_t kv_head_idx = head_idx / num_queries_per_kv;
  232. const scalar_t* __restrict__ q_vec_ptr =
  233. q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  234. const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
  235. float* __restrict__ thread_block_logits =
  236. logits + omp_get_thread_num() * max_seq_len_padded;
  237. // Compute logits
  238. for (int block_idx = 0; block_idx < block_num; ++block_idx) {
  239. const int64_t physical_block_idx = seq_block_table[block_idx];
  240. const scalar_t* __restrict__ k_block_cache_ptr =
  241. k_cache + physical_block_idx * kv_block_stride +
  242. kv_head_idx * kv_head_stride;
  243. float* __restrict__ head_block_logits =
  244. thread_block_logits + block_idx * BLOCK_SIZE;
  245. reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
  246. q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
  247. block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
  248. }
  249. // Compute softmax
  250. if (alibi_slopes) {
  251. reduceSoftmaxAlibi(thread_block_logits, seq_len,
  252. block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
  253. seq_len);
  254. } else {
  255. reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
  256. }
  257. // Compute value
  258. constexpr int head_elem_num_per_partition = 16;
  259. constexpr int head_partition_num =
  260. HEAD_SIZE / head_elem_num_per_partition;
  261. for (int head_part_idx = 0; head_part_idx < head_partition_num;
  262. ++head_part_idx) {
  263. vec_op::FP32Vec16 accums[head_elem_num_per_partition];
  264. scalar_t* __restrict__ out_ptr =
  265. out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
  266. head_part_idx * head_elem_num_per_partition;
  267. for (int block_idx = 0; block_idx < block_num; ++block_idx) {
  268. const int64_t physical_block_idx = seq_block_table[block_idx];
  269. const float* __restrict__ prob_vec_ptr =
  270. thread_block_logits + block_idx * BLOCK_SIZE;
  271. const scalar_t* __restrict__ v_block_cache_ptr =
  272. v_cache + physical_block_idx * kv_block_stride +
  273. kv_head_idx * kv_head_stride +
  274. BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
  275. reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
  276. head_elem_num_per_partition>(
  277. prob_vec_ptr, v_block_cache_ptr, accums);
  278. if (block_idx != block_num - 1) {
  279. const int64_t next_physical_block_idx =
  280. seq_block_table[block_idx + 1];
  281. const scalar_t* __restrict__ next_v_block_cache_ptr =
  282. v_cache + next_physical_block_idx * kv_block_stride +
  283. kv_head_idx * kv_head_stride +
  284. BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
  285. vec_op::unroll_loop<int, head_elem_num_per_partition>(
  286. [&](int head_elem_idx) {
  287. if (head_elem_idx % 2 == 0) {
  288. vec_op::prefetch(next_v_block_cache_ptr +
  289. BLOCK_SIZE * head_elem_idx);
  290. }
  291. });
  292. }
  293. }
  294. vec_op::unroll_loop<int, head_elem_num_per_partition>(
  295. [&](int head_elem_idx) {
  296. float value = accums[head_elem_idx].reduce_sum();
  297. vec_op::storeFP32(value, out_ptr + head_elem_idx);
  298. });
  299. }
  300. }
  301. }
  302. std::free(logits);
  303. }
  304. };
  // Runs paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call. The
  // expansion reads its arguments by name from the expansion site (out_ptr,
  // query_ptr, key_cache_ptr, ..., num_seqs, num_heads), so it is only usable
  // where all of those locals are declared, as in
  // paged_attention_v1_impl_launcher.
  #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
    paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
        out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
        block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
        alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
        num_heads);
  311. template <typename T, int BLOCK_SIZE>
  312. void paged_attention_v1_impl_launcher(
  313. torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
  314. torch::Tensor& value_cache, int num_kv_heads, float scale,
  315. torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
  316. const c10::optional<torch::Tensor>& alibi_slopes) {
  317. int num_seqs = query.size(0);
  318. int num_heads = query.size(1);
  319. int head_size = query.size(2);
  320. int max_num_blocks_per_seq = block_tables.size(1);
  321. int q_stride = query.stride(0);
  322. int kv_block_stride = key_cache.stride(0);
  323. int kv_head_stride = key_cache.stride(1);
  324. // NOTE: alibi_slopes is optional.
  325. const float* alibi_slopes_ptr =
  326. alibi_slopes
  327. ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
  328. : nullptr;
  329. T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  330. T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  331. T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  332. T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  333. int* block_tables_ptr = block_tables.data_ptr<int>();
  334. int* seq_lens_ptr = seq_lens.data_ptr<int>();
  335. switch (head_size) {
  336. case 64:
  337. LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
  338. break;
  339. case 80:
  340. LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
  341. break;
  342. case 96:
  343. LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
  344. break;
  345. case 112:
  346. LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
  347. break;
  348. case 128:
  349. LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
  350. break;
  351. case 192:
  352. LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
  353. break;
  354. case 256:
  355. LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
  356. break;
  357. default:
  358. TORCH_CHECK(false, "Unsupported head size: ", head_size);
  359. break;
  360. }
  361. }
  // Calls paged_attention_v1_impl_launcher<T, BLOCK_SIZE> with paged_attention_v1's
  // parameters. Those names must be in scope at the expansion site.
  #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
    paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
        seq_lens, max_seq_len, alibi_slopes);

  // Maps the runtime `block_size` to a compiled BLOCK_SIZE. Only 16 is built;
  // any other value fails through TORCH_CHECK.
  #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
    switch (block_size) { \
      case 16: \
        CALL_V1_KERNEL_LAUNCHER(T, 16); \
        break; \
      default: \
        TORCH_CHECK(false, "Unsupported block size: ", block_size); \
        break; \
    }
  375. } // namespace
  // Entry point for single-pass paged attention on the CPU backend.
  //
  // Requires k_scale == v_scale == 1.0 and rejects blocksparse attention
  // (blocksparse_vert_stride > 1). kv_cache_dtype, tp_rank and the remaining
  // blocksparse_* parameters are accepted for signature compatibility but not
  // read here. Dispatches on query's floating dtype and on block_size (only
  // 16 is compiled).
  void paged_attention_v1(
      torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
      torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
      torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
      int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
      const std::string& kv_cache_dtype, double k_scale, double v_scale,
      const int64_t tp_rank, const int64_t blocksparse_local_blocks,
      const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
      const int64_t blocksparse_head_sliding_step) {
    TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
    TORCH_CHECK(blocksparse_vert_stride <= 1,
                "CPU backend does not support blocksparse attention yet.");
    // `scalar_t` is not declared in this function; it is presumably introduced
    // by the dispatch macro for each supported dtype, so the lambda must stay
    // inline in the macro call.
    APHRODITE_DISPATCH_FLOATING_TYPES(
        query.scalar_type(), "paged_attention_v1_impl", [&] {
          CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
          CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
          CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
        });
  }
  395. // Paged attention v2
  396. namespace {
  397. template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
  398. struct paged_attention_v2_impl {
  399. static void call(
  400. scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
  401. float* __restrict__ exp_sums, // [num_seqs, num_heads,
  402. // max_num_partitions]
  403. float* __restrict__ max_logits, // [num_seqs, num_heads,
  404. // max_num_partitions]
  405. scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
  406. // max_num_partitions, head_size]
  407. const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
  408. const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
  409. // head_size/x, block_size, x]
  410. const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
  411. // head_size, block_size]
  412. const int num_kv_heads, const float scale,
  413. const int* __restrict__ block_tables, // [num_seqs,
  414. // max_num_blocks_per_seq]
  415. const int* __restrict__ seq_lens, // [num_seqs]
  416. const int max_num_blocks_per_seq,
  417. const float* __restrict__ alibi_slopes, // [num_heads]
  418. const int q_stride, const int kv_block_stride, const int kv_head_stride,
  419. const int num_seqs, const int num_heads, const int max_num_partitions) {
  420. constexpr int x = 16 / sizeof(scalar_t);
  421. const int num_queries_per_kv = num_heads / num_kv_heads;
  422. static_assert(BLOCK_SIZE == 16);
  423. static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
  424. static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);
  425. #pragma omp parallel for collapse(3) schedule(static, 1)
  426. for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
  427. for (int partition_idx = 0; partition_idx < max_num_partitions;
  428. ++partition_idx) {
  429. for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
  430. const int seq_len = seq_lens[seq_idx];
  431. const int start_token_idx = partition_idx * PARTITION_SIZE;
  432. if (start_token_idx >= seq_len) continue;
  433. const int partition_num =
  434. (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
  435. const bool no_reduce = (partition_num == 1);
  436. const int token_num =
  437. (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
  438. start_token_idx);
  439. const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
  440. const int last_block_token_num =
  441. token_num - (block_num - 1) * BLOCK_SIZE;
  442. const int* seq_block_table = block_tables +
  443. max_num_blocks_per_seq * seq_idx +
  444. start_token_idx / BLOCK_SIZE;
  445. const int64_t kv_head_idx = head_idx / num_queries_per_kv;
  446. const scalar_t* __restrict__ q_vec_ptr =
  447. q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  448. float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
  449. // Compute logits
  450. for (int block_idx = 0; block_idx < block_num; ++block_idx) {
  451. const int64_t physical_block_idx = seq_block_table[block_idx];
  452. const scalar_t* __restrict__ k_block_cache_ptr =
  453. k_cache + physical_block_idx * kv_block_stride +
  454. kv_head_idx * kv_head_stride;
  455. float* __restrict__ head_block_logits =
  456. logits + block_idx * BLOCK_SIZE;
  457. reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
  458. q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
  459. block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
  460. }
  461. std::pair<float, float> max_and_sum;
  462. if (alibi_slopes) {
  463. max_and_sum = reduceSoftmaxAlibi(
  464. logits, token_num, block_num * BLOCK_SIZE,
  465. alibi_slopes[head_idx], start_token_idx, seq_len);
  466. } else {
  467. max_and_sum =
  468. reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
  469. }
  470. auto&& [max_logit, exp_sum] = max_and_sum;
  471. scalar_t* __restrict__ output_buffer = nullptr;
  472. if (!no_reduce) {
  473. auto idx = seq_idx * num_heads * max_num_partitions +
  474. head_idx * max_num_partitions + partition_idx;
  475. max_logits[idx] = max_logit;
  476. exp_sums[idx] = exp_sum;
  477. output_buffer =
  478. tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
  479. head_idx * max_num_partitions * HEAD_SIZE +
  480. partition_idx * HEAD_SIZE;
  481. } else {
  482. output_buffer =
  483. out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
  484. }
  485. // Compute value
  486. constexpr int head_elem_num_per_partition = 16;
  487. constexpr int head_partition_num =
  488. HEAD_SIZE / head_elem_num_per_partition;
  489. for (int head_part_idx = 0; head_part_idx < head_partition_num;
  490. ++head_part_idx) {
  491. vec_op::FP32Vec16 accums[head_elem_num_per_partition];
  492. scalar_t* __restrict__ out_ptr =
  493. output_buffer + head_part_idx * head_elem_num_per_partition;
  494. for (int block_idx = 0; block_idx < block_num; ++block_idx) {
  495. const int64_t physical_block_idx = seq_block_table[block_idx];
  496. const float* __restrict__ prob_vec_ptr =
  497. logits + block_idx * BLOCK_SIZE;
  498. const scalar_t* __restrict__ v_block_cache_ptr =
  499. v_cache + physical_block_idx * kv_block_stride +
  500. kv_head_idx * kv_head_stride +
  501. BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
  502. reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
  503. head_elem_num_per_partition>(
  504. prob_vec_ptr, v_block_cache_ptr, accums);
  505. if (block_idx != block_num - 1) {
  506. const int64_t next_physical_block_idx =
  507. seq_block_table[block_idx + 1];
  508. const scalar_t* __restrict__ next_v_block_cache_ptr =
  509. v_cache + next_physical_block_idx * kv_block_stride +
  510. kv_head_idx * kv_head_stride +
  511. BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
  512. vec_op::unroll_loop<int, head_elem_num_per_partition>(
  513. [&](int head_elem_idx) {
  514. if (head_elem_idx % 2 == 0) {
  515. vec_op::prefetch(next_v_block_cache_ptr +
  516. BLOCK_SIZE * head_elem_idx);
  517. }
  518. });
  519. }
  520. }
  521. vec_op::unroll_loop<int, head_elem_num_per_partition>(
  522. [&](int head_elem_idx) {
  523. float value = accums[head_elem_idx].reduce_sum();
  524. vec_op::storeFP32(value, out_ptr + head_elem_idx);
  525. });
  526. }
  527. }
  528. }
  529. }
  530. // Rescale partition softmax and store the factors to exp_sums
  531. #pragma omp parallel for collapse(2) schedule(static, 1)
  532. for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
  533. for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
  534. const int seq_len = seq_lens[seq_idx];
  535. const int partition_num =
  536. (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
  537. if (partition_num == 1) continue;
  538. reducePartitonSoftmax(
  539. max_logits + seq_idx * num_heads * max_num_partitions +
  540. head_idx * max_num_partitions,
  541. exp_sums + seq_idx * num_heads * max_num_partitions +
  542. head_idx * max_num_partitions,
  543. partition_num);
  544. }
  545. }
  546. // Reduce values
  547. using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
  548. static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
  549. constexpr int head_elem_num_per_group =
  550. 16; // Note: didn't align with the cacheline size, due to some
  551. // HEAD_SIZE didn't align with 64 bytes
  552. static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
  553. constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
  554. const float* __restrict__ rescale_factors = exp_sums;
  555. #pragma omp parallel for collapse(3) schedule(static, 1)
  556. for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
  557. for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
  558. for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
  559. const int seq_len = seq_lens[seq_idx];
  560. const int partition_num =
  561. (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
  562. if (partition_num == 1) continue;
  563. const float* __restrict__ seq_head_rescale_factors =
  564. rescale_factors + seq_idx * num_heads * max_num_partitions +
  565. head_idx * max_num_partitions;
  566. const scalar_t* __restrict__ seq_head_tmp_out =
  567. tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
  568. head_idx * max_num_partitions * HEAD_SIZE +
  569. group_idx * head_elem_num_per_group;
  570. scalar_t* __restrict__ seq_head_output =
  571. out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
  572. group_idx * head_elem_num_per_group;
  573. vec_op::FP32Vec16 acc;
  574. for (int i = 0; i < partition_num; ++i) {
  575. vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
  576. v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
  577. vec_op::FP32Vec16 fp32_value(value);
  578. acc = acc + fp32_value * rescale_factor;
  579. }
  580. v_load_vec_type cast_acc(acc);
  581. cast_acc.save(seq_head_output);
  582. }
  583. }
  584. }
  585. }
  586. };
  // Runs paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call.
  // PARTITION_SIZE and every argument (out_ptr, exp_sums_ptr, ...,
  // max_num_partitions) are resolved by name at the expansion site, so this is
  // only usable inside paged_attention_v2_impl_launcher-style scopes.
  #define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
    paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
        out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
        key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
        seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
        kv_block_stride, kv_head_stride, num_seqs, num_heads, \
        max_num_partitions);
  594. template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
  595. void paged_attention_v2_impl_launcher(
  596. torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
  597. torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
  598. torch::Tensor& value_cache, int num_kv_heads, float scale,
  599. torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
  600. int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
  601. int num_seqs = query.size(0);
  602. int num_heads = query.size(1);
  603. int head_size = query.size(2);
  604. int max_num_blocks_per_seq = block_tables.size(1);
  605. int q_stride = query.stride(0);
  606. int kv_block_stride = key_cache.stride(0);
  607. int kv_head_stride = key_cache.stride(1);
  608. int max_num_partitions = exp_sums.size(-1);
  609. // NOTE: alibi_slopes is optional.
  610. const float* alibi_slopes_ptr =
  611. alibi_slopes
  612. ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
  613. : nullptr;
  614. T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  615. float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  616. float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  617. T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  618. T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  619. T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  620. T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  621. int* block_tables_ptr = block_tables.data_ptr<int>();
  622. int* seq_lens_ptr = seq_lens.data_ptr<int>();
  623. switch (head_size) {
  624. case 64:
  625. LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
  626. break;
  627. case 80:
  628. LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
  629. break;
  630. case 96:
  631. LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
  632. break;
  633. case 112:
  634. LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
  635. break;
  636. case 128:
  637. LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
  638. break;
  639. case 192:
  640. LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
  641. break;
  642. case 256:
  643. LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
  644. break;
  645. default:
  646. TORCH_CHECK(false, "Unsupported head size: ", head_size);
  647. break;
  648. }
  649. }
  // Calls paged_attention_v2_impl_launcher<T, BLOCK_SIZE> (default
  // PARTITION_SIZE) with paged_attention_v2's parameters. Those names must be
  // in scope at the expansion site.
  #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
    paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
        out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
        alibi_slopes);

  // Maps the runtime `block_size` to a compiled BLOCK_SIZE. Only 16 is built;
  // any other value fails through TORCH_CHECK.
  #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
    switch (block_size) { \
      case 16: \
        CALL_V2_KERNEL_LAUNCHER(T, 16); \
        break; \
      default: \
        TORCH_CHECK(false, "Unsupported block size: ", block_size); \
        break; \
    }
  664. } // namespace
  // Entry point for partitioned (two-pass) paged attention on the CPU backend.
  //
  // exp_sums, max_logits and tmp_out are scratch buffers, used when a sequence
  // spans more than one partition; the launcher takes max_num_partitions from
  // exp_sums.size(-1). Requires k_scale == v_scale == 1.0 and rejects
  // blocksparse attention (blocksparse_vert_stride > 1). kv_cache_dtype,
  // tp_rank and the remaining blocksparse_* parameters are accepted for
  // signature compatibility but not read here. Dispatches on query's floating
  // dtype and on block_size (only 16 is compiled).
  void paged_attention_v2(
      torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
      torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
      torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
      torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
      int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
      const std::string& kv_cache_dtype, double k_scale, double v_scale,
      const int64_t tp_rank, const int64_t blocksparse_local_blocks,
      const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
      const int64_t blocksparse_head_sliding_step) {
    TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
    TORCH_CHECK(blocksparse_vert_stride <= 1,
                "CPU backend does not support blocksparse attention yet.");
    // `scalar_t` is not declared in this function; it is presumably introduced
    // by the dispatch macro for each supported dtype, so the lambda must stay
    // inline in the macro call.
    APHRODITE_DISPATCH_FLOATING_TYPES(
        query.scalar_type(), "paged_attention_v2_impl", [&] {
          CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
          CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
          CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
        });
  }