// seq_len.h
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
#include <cstdint>      // int32_t / int64_t used in the layout shape/stride types
#include <type_traits>  // std::conditional_t

#include <cutlass/cutlass.h>
#include <cute/layout.hpp>
  7. namespace flash {
  8. static constexpr int kMaxTileSize = 128;
  9. template <bool UseVarSeqLen> class SeqLenTraits {
  10. public:
  11. // Total number of queries / keys. Unpadded.
  12. int sum_s = 0;
  13. // seq len offsets.
  14. int *cu_seq_len = nullptr;
  15. // actual seq len array.
  16. int *seq_used = nullptr;
  17. // seq len of the current batch.
  18. int actual_seq_len = -1;
  19. // Whether this is for fixed-seq-len or var-seq-len.
  20. static constexpr bool kUseVarSeqLen = UseVarSeqLen;
  21. using ShapeT = std::conditional_t<
  22. UseVarSeqLen,
  23. cute::Shape<int32_t, int32_t, int32_t>,
  24. cute::Shape<int32_t, int32_t, int32_t, int32_t>
  25. >;
  26. using StrideT = std::conditional_t<
  27. UseVarSeqLen,
  28. cute::Shape<int64_t, _1, int64_t>,
  29. cute::Shape<int64_t, _1, int64_t, int64_t>
  30. >;
  31. using LayoutT = cute::Layout<ShapeT, StrideT>;
  32. using ShapeLseT = std::conditional_t<
  33. UseVarSeqLen,
  34. cute::Shape<int32_t, int32_t>,
  35. cute::Shape<int32_t, int32_t, int32_t>
  36. >;
  37. using StrideLseT = std::conditional_t<
  38. UseVarSeqLen,
  39. cute::Shape<int64_t, _1>,
  40. cute::Shape<int64_t, int64_t, _1>
  41. >;
  42. using LayoutLseT = cute::Layout<ShapeLseT, StrideLseT>;
  43. CUTLASS_HOST SeqLenTraits() {}
  44. CUTLASS_HOST SeqLenTraits(
  45. int sum_s, int max_seq_len, int *cu_seq_len = nullptr, int *seq_used = nullptr):
  46. sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used), actual_seq_len(max_seq_len) {}
  47. // Returns the layout of a tensor in MKHB format in global memory.
  48. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  49. CUTLASS_HOST_DEVICE auto get_gmem_layout(
  50. int m, int k, int h, int b,
  51. int64_t m_stride, int64_t h_stride, int64_t b_stride,
  52. bool padded = false) const {
  53. static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen.");
  54. return make_layout(make_shape(m, k, h, b),
  55. make_stride(m_stride, cute::_1{}, h_stride, b_stride));
  56. }
  57. // Returns the layout of a tensor in MKHB format in global memory.
  58. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  59. CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(
  60. int m, int h, int b, bool padded = false) const {
  61. static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen.");
  62. return make_layout(make_shape(b, h, m),
  63. make_stride(int64_t(h * m), int64_t(m), cute::_1()));
  64. }
  65. CUTLASS_DEVICE void init(int bidb) {}
  66. template <typename MTensor, typename Shape>
  67. CUTLASS_DEVICE auto get_local_tile_tensor(
  68. const MTensor &m_tensor, const Shape &tile_shape,
  69. int bidh, int bidb, bool padded = false) const {
  70. auto g_tensor = local_tile(
  71. m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{}));
  72. return g_tensor;
  73. }
  74. template <typename MTensor, typename Shape>
  75. CUTLASS_DEVICE auto get_lse_local_tile_tensor(
  76. const MTensor &m_tensor, const Shape &tile_shape,
  77. int bidh, int bidb, bool padded = false) const {
  78. auto g_tensor = local_tile(m_tensor(bidb, bidh, _), tile_shape, make_coord(_));
  79. return g_tensor;
  80. }
  81. };
  82. using FixedSeqLenTraits = SeqLenTraits<false>;
  83. using VarSeqLenTraits = SeqLenTraits<true>;
  84. // Returns the static layout of a var-seq-len tensor in global memory based on
  85. // max_seq_len and max_batch_size.
  86. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  87. // When padded is True, use B_M + kMaxTileSize * B as the total B_M.
  88. template <>
  89. CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout(
  90. int m, int k, int h, int b,
  91. int64_t m_stride, int64_t h_stride, int64_t b_stride,
  92. bool padded) const {
  93. return make_layout(
  94. make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h),
  95. make_stride(m_stride, cute::_1{}, h_stride));
  96. }
  97. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  98. // When padded is True, use B_M + kMaxTileSize * B as the total B_M.
  99. template <>
  100. CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_lse_gmem_layout(
  101. int m, int h, int b, bool padded) const {
  102. return make_layout(
  103. make_shape(h, sum_s + (padded ? kMaxTileSize * b : 0)),
  104. make_stride(int64_t(sum_s + (padded ? kMaxTileSize * b : 0)), cute::_1()));
  105. }
  106. template <>
  107. CUTLASS_DEVICE void VarSeqLenTraits::init(int bidb) {
  108. actual_seq_len =
  109. seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]);
  110. }
// Var-seq-len view of head bidh / batch bidb, tiled along M.
template <>
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto VarSeqLenTraits::get_local_tile_tensor(
    const MTensor &m_tensor, const Shape &tile_shape,
    int bidh, int bidb, bool padded) const {
  // Step 1: slice head bidh and offset to this batch's first packed row.
  // cu_seq_len[bidb] is the row offset; when padded, each earlier batch
  // also reserved kMaxTileSize extra rows.
  auto g_offset = local_tile(
      m_tensor(_, _, bidh),
      cute::make_shape(1, get<1>(tile_shape)),
      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{}));
  // Step 2: re-view from that offset as (actual_seq_len, K), reusing the
  // source strides so only this batch's rows are visible.
  auto g_sequence = make_tensor(
      g_offset.data(),
      make_layout(
          cute::make_shape(actual_seq_len, get<1>(tile_shape)),
          g_offset.stride()
      ));
  // Step 3: partition the per-batch view into M tiles.
  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
  return g_tensor;
}
// Var-seq-len LSE vector for head bidh / batch bidb, tiled along M.
template <>
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor(
    const MTensor &m_tensor, const Shape &tile_shape,
    int bidh, int bidb, bool padded) const {
  // Offset into the packed (H, total_M) LSE tensor at this batch's first
  // element; earlier batches reserve kMaxTileSize extra slots when padded.
  auto g_offset = local_tile(
      m_tensor(bidh, _), cute::make_shape(_1{}),
      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0)));
  // View exactly actual_seq_len contiguous elements from that offset.
  // NOTE(review): the second argument is make_shape used as a stride tuple
  // (elements are _1) — appears intentional since Shape/Stride are both
  // cute tuples, but make_stride would state the intent; confirm upstream.
  auto g_sequence = make_tensor(
      g_offset.data(),
      make_layout(cute::make_shape(actual_seq_len), cute::make_shape(_1{})));
  // Partition the per-batch vector into M tiles.
  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
  return g_tensor;
}
  143. ////////////////////////////////////////////////////////////////////////////////////////////////////
  144. } // namespace flash