block_info.h 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. /******************************************************************************
  2. * Copyright (c) 2023, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. namespace flash {
  6. ////////////////////////////////////////////////////////////////////////////////////////////////////
  7. template <bool Varlen = true>
  8. struct BlockInfo {
  9. template <typename Params>
  10. __device__ BlockInfo(const Params& params, const int bidb)
  11. : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr
  12. ? -1
  13. : params.cu_seqlens_q[bidb]),
  14. sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ||
  15. !params.is_seqlens_k_cumulative
  16. ? -1
  17. : params.cu_seqlens_k[bidb]),
  18. actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr
  19. ? params.seqlen_q
  20. : params.cu_seqlens_q[bidb + 1] - sum_s_q)
  21. // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] -
  22. // cu_seqlens_k[bidb]. Otherwise it's cu_seqlens_k[bidb], i.e., we use
  23. // cu_seqlens_k to store the sequence lengths of K.
  24. ,
  25. seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr
  26. ? params.seqlen_k
  27. : (params.is_seqlens_k_cumulative
  28. ? params.cu_seqlens_k[bidb + 1] - sum_s_k
  29. : params.cu_seqlens_k[bidb])),
  30. actual_seqlen_k(params.seqused_k
  31. ? params.seqused_k[bidb]
  32. : seqlen_k_cache + (params.knew_ptr == nullptr
  33. ? 0
  34. : params.seqlen_knew)) {}
  35. template <typename index_t>
  36. __forceinline__ __device__ index_t q_offset(const index_t batch_stride,
  37. const index_t row_stride,
  38. const int bidb) const {
  39. return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
  40. }
  41. template <typename index_t>
  42. __forceinline__ __device__ index_t k_offset(const index_t batch_stride,
  43. const index_t row_stride,
  44. const int bidb) const {
  45. return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
  46. }
  47. const int sum_s_q;
  48. const int sum_s_k;
  49. const int actual_seqlen_q;
  50. // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise
  51. // actual_seqlen_k is set to 0.
  52. const int seqlen_k_cache;
  53. const int actual_seqlen_k;
  54. };
  55. ////////////////////////////////////////////////////////////////////////////////////////////////////
  56. } // namespace flash