// seq_len.h
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
#pragma once

#include <type_traits>  // std::conditional_t

#include <cutlass/cutlass.h>
#include <cute/layout.hpp>
  7. namespace flash {
  8. static constexpr int kMaxTileSize = 128;
// Traits describing how Q/K/O tensors and the LSE tensor are laid out in
// global memory, parameterized on whether sequences are fixed-length (dense
// M-K-H-B tensors) or variable-length (all batches packed along M, described
// by cumulative offsets).  The primary template implements the fixed-seq-len
// case; the var-seq-len members are explicitly specialized below.
template <bool UseVarSeqLen> class SeqLenTraits {
public:
  // Total number of queries / keys. Unpadded.
  int sum_s = 0;
  // seq len offsets: cumulative starts per batch.
  // NOTE(review): presumably length batch+1 and a device pointer — confirm against callers.
  int *cu_seq_len = nullptr;
  // actual seq len array (optional per-batch used lengths).
  int *seq_used = nullptr;
  // seq len of the current batch.
  int actual_seq_len = -1;
  // max seq len per batch.
  int max_seq_len = -1;
  // Whether this is for fixed-seq-len or var-seq-len.
  static constexpr bool kUseVarSeqLen = UseVarSeqLen;

  // Var-seq-len tensors drop the batch mode: shape is (sum_s, K, H) instead
  // of (M, K, H, B).  K is always contiguous (stride _1).
  // NOTE(review): unqualified `_1`/`_0` here rely on a `using namespace cute;`
  // not visible in this chunk — verify it exists upstream.
  using ShapeT = std::conditional_t<
      UseVarSeqLen,
      cute::Shape<int32_t, int32_t, int32_t>,
      cute::Shape<int32_t, int32_t, int32_t, int32_t>
  >;
  using StrideT = std::conditional_t<
      UseVarSeqLen,
      cute::Shape<int64_t, _1, int64_t>,
      cute::Shape<int64_t, _1, int64_t, int64_t>
  >;
  using LayoutT = cute::Layout<ShapeT, StrideT>;

  // LSE is (H, sum_s) for var-seq-len, (B, H, M) for fixed-seq-len; the last
  // mode is contiguous in both cases.
  using ShapeLseT = std::conditional_t<
      UseVarSeqLen,
      cute::Shape<int32_t, int32_t>,
      cute::Shape<int32_t, int32_t, int32_t>
  >;
  using StrideLseT = std::conditional_t<
      UseVarSeqLen,
      cute::Shape<int64_t, _1>,
      cute::Shape<int64_t, int64_t, _1>
  >;
  using LayoutLseT = cute::Layout<ShapeLseT, StrideLseT>;

  CUTLASS_HOST SeqLenTraits() {}
  // actual_seq_len starts at max_seq_len; var-seq-len overwrites it per batch
  // in init().
  CUTLASS_HOST SeqLenTraits(
      int sum_s, int max_seq_len, int *cu_seq_len = nullptr, int *seq_used = nullptr):
      sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used),
      actual_seq_len(max_seq_len), max_seq_len(max_seq_len) {}

  // Returns the layout of a tensor in MKHB format in global memory.
  // padded: only useful for var-seq-len for dq_accum and softmax_d
  // (ignored here in the fixed-seq-len default).
  CUTLASS_HOST_DEVICE auto get_gmem_layout(
      int m, int k, int h, int b,
      int64_t m_stride, int64_t h_stride, int64_t b_stride,
      bool padded = false) const {
    static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen.");
    return make_layout(make_shape(m, k, h, b),
                       make_stride(m_stride, cute::_1{}, h_stride, b_stride));
  }

  // Returns the layout of the LSE tensor, (B, H, M), fully packed
  // (strides derived from h and m).
  // padded: only useful for var-seq-len for dq_accum and softmax_d.
  CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(
      int m, int h, int b, bool padded = false) const {
    static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen.");
    return make_layout(make_shape(b, h, m),
                       make_stride(int64_t(h * m), int64_t(m), cute::_1()));
  }

  // Fixed-seq-len: nothing per-batch to compute; every batch is processed.
  CUTLASS_DEVICE bool init(int bidb) {return true;}

  // Slice out (head bidh, batch bidb) and tile along M; K is taken whole.
  template <typename MTensor, typename Shape>
  CUTLASS_DEVICE auto get_local_tile_tensor(
      const MTensor &m_tensor, const Shape &tile_shape,
      int bidh, int bidb, bool padded = false) const {
    auto g_tensor = local_tile(
        m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{}));
    return g_tensor;
  }

  // Same as above for the (B, H, M) LSE tensor: slice (bidb, bidh), tile M.
  template <typename MTensor, typename Shape>
  CUTLASS_DEVICE auto get_lse_local_tile_tensor(
      const MTensor &m_tensor, const Shape &tile_shape,
      int bidh, int bidb, bool padded = false) const {
    auto g_tensor = local_tile(m_tensor(bidb, bidh, _), tile_shape, make_coord(_));
    return g_tensor;
  }
};
// Convenience aliases for the two instantiations used throughout the kernels.
using FixedSeqLenTraits = SeqLenTraits<false>;
using VarSeqLenTraits = SeqLenTraits<true>;
  87. // Returns the static layout of a var-seq-len tensor in global memory based on
  88. // max_seq_len and max_batch_size.
  89. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  90. // When padded is True, use B_M + kMaxTileSize * B as the total B_M.
  91. template <>
  92. CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout(
  93. int m, int k, int h, int b,
  94. int64_t m_stride, int64_t h_stride, int64_t b_stride,
  95. bool padded) const {
  96. return make_layout(
  97. make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h),
  98. make_stride(m_stride, cute::_1{}, h_stride));
  99. }
  100. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  101. // When padded is True, use B_M + kMaxTileSize * B as the total B_M.
  102. template <>
  103. CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_lse_gmem_layout(
  104. int m, int h, int b, bool padded) const {
  105. return make_layout(
  106. make_shape(h, sum_s + (padded ? kMaxTileSize * b : 0)),
  107. make_stride(int64_t(sum_s + (padded ? kMaxTileSize * b : 0)), cute::_1()));
  108. }
  109. template <>
  110. CUTLASS_DEVICE bool VarSeqLenTraits::init(int bidb) {
  111. actual_seq_len =
  112. seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]);
  113. return cu_seq_len[bidb] < max_seq_len;
  114. }
// Var-seq-len: view the packed (sum_s, K, H) tensor as this batch's
// (actual_seq_len, K) slice, then tile it along M.
template <>
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto VarSeqLenTraits::get_local_tile_tensor(
    const MTensor &m_tensor, const Shape &tile_shape,
    int bidh, int bidb, bool padded) const {
  // Locate the first row of batch bidb: cu_seq_len[bidb], plus the
  // kMaxTileSize * bidb reserved rows when the buffer was padded
  // (must match get_gmem_layout's padded total).
  auto g_offset = local_tile(
      m_tensor(_, _, bidh),
      cute::make_shape(1, get<1>(tile_shape)),
      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{}));
  // Re-wrap the offset pointer as an (actual_seq_len, K) tensor, reusing the
  // source tensor's strides so the element addressing is unchanged.
  auto g_sequence = make_tensor(
      g_offset.data(),
      make_layout(
          cute::make_shape(actual_seq_len, get<1>(tile_shape)),
          g_offset.stride()
      ));
  // Tile along M only; the K mode is taken whole (coordinate fixed at _0).
  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
  return g_tensor;
}
// Var-seq-len LSE: view the packed (H, sum_s) tensor as this batch's
// 1-D slice of length actual_seq_len, then tile it.
template <>
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor(
    const MTensor &m_tensor, const Shape &tile_shape,
    int bidh, int bidb, bool padded) const {
  // Offset to this batch's first element, accounting for padded rows exactly
  // as get_lse_gmem_layout reserves them.
  auto g_offset = local_tile(
      m_tensor(bidh, _), cute::make_shape(_1{}),
      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0)));
  // NOTE(review): `make_shape(_1{})` is passed where a stride is expected;
  // cute accepts any IntTuple as a stride so this behaves as stride-1, but
  // `make_stride(_1{})` would state the intent — confirm and consider fixing.
  auto g_sequence = make_tensor(
      g_offset.data(),
      make_layout(cute::make_shape(actual_seq_len), cute::make_shape(_1{})));
  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
  return g_tensor;
}
  147. ////////////////////////////////////////////////////////////////////////////////////////////////////
  148. } // namespace flash