// seqlen.h
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
  4. #pragma once
  5. namespace flash {
  6. // We consolidate all the info related to sequence length here. This is so that we can do all
  7. // the gmem reads once at the beginning of each tile, rather than having to repeat these reads
  8. // to compute various things like n_block_min, n_block_max, etc.
  9. template <bool Varlen, int kBlock>
  10. struct SeqlenInfo {
  11. int const offset, offset_padded;
  12. int const seqlen;
  13. CUTLASS_DEVICE
  14. SeqlenInfo(int const bidb, int const seqlen_static, int const* const cu_seqlens, int const* const seqused)
  15. : offset(!Varlen || cu_seqlens == nullptr ? 0 : cu_seqlens[bidb])
  16. , offset_padded(!Varlen || cu_seqlens == nullptr ? 0 : (cu_seqlens[bidb] + bidb * kBlock) / kBlock * kBlock)
  17. , seqlen(!Varlen
  18. ? seqlen_static
  19. : (seqused ? seqused[bidb] : (cu_seqlens ? cu_seqlens[bidb + 1] - cu_seqlens[bidb] : seqlen_static)))
  20. {
  21. }
  22. };
  23. template <bool Varlen, int kBlockM>
  24. struct SeqlenInfoQK {
  25. int const offset_q, offset_k, offset_q_padded;
  26. int const seqlen_q, seqlen_k;
  27. CUTLASS_DEVICE
  28. SeqlenInfoQK(int const bidb, int const seqlen_q_static, int const seqlen_k_static,
  29. int const* const cu_seqlens_q, int const* const cu_seqlens_k,
  30. int const* const seqused_q, int const* const seqused_k
  31. )
  32. : offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb])
  33. , offset_k(!Varlen || cu_seqlens_k == nullptr ? 0 : cu_seqlens_k[bidb])
  34. // If varlen, the layout for dPSum, LSE_log2, and dQaccum is that we pad each sequence in the batch
  35. // by an extra kBlockM, so that the write for each sequence doesn't touch the next sequence.
  36. // Sequence i starts at cu_seqlens[i] + i * kBlockM and ends at cu_seqlens[i + 1] + i * kBlockM
  37. // However, the start must align to multiples of kBlockM.
  38. , offset_q_padded(!Varlen || cu_seqlens_q == nullptr ? 0 : (cu_seqlens_q[bidb] + bidb * kBlockM) / kBlockM * kBlockM)
  39. , seqlen_q(!Varlen
  40. ? seqlen_q_static
  41. : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static)))
  42. , seqlen_k(!Varlen
  43. ? seqlen_k_static
  44. : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)))
  45. {
  46. }
  47. };
  48. template <bool Varlen, bool AppendKV>
  49. struct SeqlenInfoQKNewK {
  50. static_assert(!(AppendKV && !Varlen), "AppendKV is only supported with Varlen");
  51. int const leftpad_k;
  52. int const offset_q, offset_k, offset_k_new;
  53. int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k;
  54. CUTLASS_DEVICE
  55. SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0,
  56. int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,
  57. int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k
  58. )
  59. : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0)
  60. , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb])
  61. , offset_k(!Varlen ? 0 : (cu_seqlens_k ? cu_seqlens_k[bidb] : 0) + leftpad_k)
  62. , offset_k_new(!AppendKV || cu_seqlens_k_new == nullptr ? 0 : cu_seqlens_k_new[bidb])
  63. , seqlen_q(!Varlen
  64. ? seqlen_q_static
  65. : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static)))
  66. , seqlen_k_og(!Varlen
  67. ? seqlen_k_static
  68. : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)) - leftpad_k)
  69. , seqlen_k_new(!AppendKV
  70. ? 0
  71. : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0))
  72. , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new)
  73. {
  74. }
  75. };
  76. } // namespace flash