block.h 5.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. namespace flash {
  6. template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>
  7. struct BlockMN {
  8. static
  9. CUTLASS_DEVICE
  10. cute::tuple<int, int> get_n_block_min_max(
  11. SeqlenInfo_t const& seqlen_info,
  12. int const m_block, int const bidb, int const split_idx, int const num_splits,
  13. int const window_size_left, int const window_size_right,
  14. cutlass::FastDivmod const& qhead_per_khead_divmod) {
  15. int const seqlen_k = seqlen_info.seqlen_k;
  16. int const seqlen_q = seqlen_info.seqlen_q;
  17. int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
  18. if constexpr (Is_causal || Is_local) {
  19. int m_idx_max = (m_block + 1) * kBlockM;
  20. // TODO: check off-by-1 error
  21. if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
  22. n_block_max = std::min(n_block_max,
  23. cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN));
  24. }
  25. int n_block_min = 0;
  26. if constexpr (Is_local) {
  27. int m_idx_min = m_block * kBlockM;
  28. if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
  29. n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN);
  30. }
  31. // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
  32. if constexpr (Split) {
  33. uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
  34. int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
  35. int split_idx_actual = split_idx & 0x0000FFFF;
  36. int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
  37. int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);
  38. n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;
  39. n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
  40. // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
  41. }
  42. // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
  43. return {n_block_min, n_block_max};
  44. }
  45. static
  46. CUTLASS_DEVICE
  47. cute::tuple<int, int> get_n_block_k_new_min_max(
  48. SeqlenInfo_t const& seqlen_info,
  49. int const m_block, int const bidb, int const split_idx, int const num_splits,
  50. int const window_size_left, int const window_size_right,
  51. cutlass::FastDivmod const& qhead_per_khead_divmod) {
  52. auto [n_block_min, n_block_max] = get_n_block_min_max(
  53. seqlen_info, m_block, bidb, split_idx, num_splits,
  54. window_size_left, window_size_right, qhead_per_khead_divmod);
  55. int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
  56. int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
  57. int const n_block_new_min = idx_k_new_min / kBlockN;
  58. int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
  59. // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
  60. return {n_block_new_min, n_block_new_max};
  61. }
  62. static
  63. CUTLASS_DEVICE
  64. cute::tuple<int, int> get_m_block_min_max(
  65. SeqlenInfo_t const& seqlen_info,
  66. int const n_block, int const bidb,
  67. int const window_size_left, int const window_size_right, int const sink_token_length) {
  68. int const seqlen_q = seqlen_info.seqlen_q;
  69. int const seqlen_k = seqlen_info.seqlen_k;
  70. int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
  71. if constexpr (Is_local) {
  72. if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
  73. m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
  74. }
  75. }
  76. int m_block_min = 0;
  77. if constexpr (Is_causal || Is_local) {
  78. m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
  79. }
  80. return {m_block_min, m_block_max};
  81. }
  82. };
  83. } // namespace flash