block_info.h 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  /******************************************************************************
   * Copyright (c) 2023, Tri Dao.
   ******************************************************************************/
  #pragma once
  namespace flash {

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // Sequence bookkeeping for one batch index `bidb`, derived from a kernel `Params` object.
  //
  // Template parameter:
  //   Varlen - when false, params.cu_seqlens_q / params.cu_seqlens_k are never dereferenced and
  //            the fixed params.seqlen_q / params.seqlen_k are used instead. Note that
  //            params.seqused_k and params.knew_ptr are consulted regardless of Varlen.
  //
  // The constructor reads the following fields of `Params`: cu_seqlens_q, cu_seqlens_k,
  // is_seqlens_k_cumulative, seqlen_q, seqlen_k, seqused_k, knew_ptr, seqlen_knew.
  template<bool Varlen=true>
  struct BlockInfo {

      // Builds the info for batch index `bidb`.
      //
      // NOTE: several initializers read members initialized earlier in the list
      // (actual_seqlen_q reads sum_s_q, seqlen_k_cache reads sum_s_k, actual_seqlen_k reads
      // seqlen_k_cache). C++ initializes members in *declaration* order, so the order of the
      // member declarations at the bottom of this struct must not be changed.
      template<typename Params>
      __device__ BlockInfo(const Params &params, const int bidb)
          // cu_seqlens_q[bidb], or the sentinel -1 when cu_seqlens_q is not used.
          : sum_s_q((!Varlen || params.cu_seqlens_q == nullptr) ? -1 : params.cu_seqlens_q[bidb])
          // cu_seqlens_k[bidb], or the sentinel -1 when cu_seqlens_k is not used or does not
          // hold cumulative values.
          , sum_s_k((!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative)
                        ? -1
                        : params.cu_seqlens_k[bidb])
          // Fixed seqlen_q, or the difference of consecutive cu_seqlens_q entries.
          , actual_seqlen_q((!Varlen || params.cu_seqlens_q == nullptr)
                                ? params.seqlen_q
                                : params.cu_seqlens_q[bidb + 1] - sum_s_q)
          // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
          // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
          , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr)
                               ? params.seqlen_k
                               : (params.is_seqlens_k_cumulative
                                      ? params.cu_seqlens_k[bidb + 1] - sum_s_k
                                      : params.cu_seqlens_k[bidb]))
          // seqused_k[bidb] takes precedence when seqused_k is non-null; otherwise the cached
          // length, extended by seqlen_knew when knew_ptr is non-null.
          , actual_seqlen_k(params.seqused_k
                                ? params.seqused_k[bidb]
                                : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
      {
      }

      // Offset of this batch index's Q data.
      // Returns bidb * batch_stride when sum_s_q is the -1 sentinel; otherwise
      // sum_s_q * row_stride, with sum_s_q converted to uint32_t before the multiplication.
      template <typename index_t>
      __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
          return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
      }

      // Offset of this batch index's K data.
      // Returns bidb * batch_stride when sum_s_k is the -1 sentinel; otherwise
      // sum_s_k * row_stride, with sum_s_k converted to uint32_t before the multiplication.
      template <typename index_t>
      __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
          return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
      }

      // cu_seqlens_q[bidb], or -1 when cu_seqlens_q is not used (see constructor).
      const int sum_s_q;
      // cu_seqlens_k[bidb], or -1 when cu_seqlens_k is not used as a cumulative array.
      const int sum_s_k;
      // Q sequence length for this batch index.
      const int actual_seqlen_q;
      // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
      // K sequence length before any seqused_k override or seqlen_knew extension.
      const int seqlen_k_cache;
      // K sequence length after applying seqused_k / seqlen_knew (see constructor).
      const int actual_seqlen_k;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  }  // namespace flash