seq_len.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <array>
  6. #include <algorithm>
  7. #include <cutlass/cutlass.h>
  8. #include <cute/layout.hpp>
  9. namespace flash {
  // Upper bound on the M-tile size. Var-seq-len layouts built with `padded == true`
  // reserve kMaxTileSize extra rows per batch (see the VarSeqLenTraits layout and
  // tiling helpers); presumably so whole-tile accesses at a batch's tail stay in
  // bounds -- confirm against callers.
  static constexpr int kMaxTileSize{128};
  11. template <bool UseVarSeqLen_, bool UseGQAPacking_> class SeqLenTraits {
  12. public:
  13. static_assert(!(UseVarSeqLen_ && UseGQAPacking_),
  14. "Variable sequence length with GQA parallelization not implemented yet.");
  15. // Total number of queries / keys. Unpadded.
  16. int sum_s = 0;
  17. // seq len offsets.
  18. int *cu_seq_len = nullptr;
  19. // actual seq len array.
  20. int *seq_used = nullptr;
  21. // seq len of the current batch.
  22. int actual_seq_len = -1;
  23. // Whether this is for fixed-seq-len or var-seq-len.
  24. static constexpr bool UseVarSeqLen = UseVarSeqLen_;
  25. static constexpr bool UseGQAPacking = UseGQAPacking_;
  26. using ShapeT = std::conditional_t<
  27. UseVarSeqLen,
  28. cute::Shape<int32_t, int32_t, int32_t>,
  29. std::conditional_t<
  30. UseGQAPacking,
  31. cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>,
  32. cute::Shape<int32_t, int32_t, int32_t, int32_t>
  33. >
  34. >;
  35. using StrideT = std::conditional_t<
  36. UseVarSeqLen,
  37. cute::Shape<int64_t, _1, int64_t>,
  38. std::conditional_t<
  39. UseGQAPacking,
  40. cute::Shape<int64_t, int64_t, _1, int64_t, int64_t>,
  41. cute::Shape<int64_t, _1, int64_t, int64_t>
  42. >
  43. >;
  44. using LayoutT = cute::Layout<ShapeT, StrideT>;
  45. using ShapeLseT = std::conditional_t<
  46. UseVarSeqLen,
  47. cute::Shape<int32_t, int32_t>,
  48. cute::Shape<int32_t, int32_t, int32_t>
  49. >;
  50. using StrideLseT = std::conditional_t<
  51. UseVarSeqLen,
  52. cute::Shape<int64_t, _1>,
  53. cute::Shape<int64_t, int64_t, _1>
  54. >;
  55. using LayoutLseT = cute::Layout<ShapeLseT, StrideLseT>;
  56. // Not used for varseqlen
  57. using ShapeOAccumT = std::conditional_t<
  58. UseGQAPacking,
  59. cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t>,
  60. cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>
  61. >;
  62. using StrideOAccumT = std::conditional_t<
  63. UseGQAPacking,
  64. cute::Shape<int64_t, int64_t, _1, int64_t, int64_t, int64_t>,
  65. cute::Shape<int64_t, _1, int64_t, int64_t, int64_t>
  66. >;
  67. using LayoutOAccumT = cute::Layout<ShapeOAccumT, StrideOAccumT>;
  68. using ShapeLseAccumT = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
  69. using StrideLseAccumT = cute::Shape<int64_t, int64_t, int64_t, _1>;
  70. using LayoutLseAccumT = cute::Layout<ShapeLseAccumT, StrideLseAccumT>;
  71. CUTLASS_HOST SeqLenTraits() {}
  72. CUTLASS_HOST SeqLenTraits(
  73. int sum_s, int max_seq_len, int *cu_seq_len = nullptr, int *seq_used = nullptr):
  74. sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used), actual_seq_len(max_seq_len) {}
  75. CUTLASS_DEVICE void init(int bidb) {
  76. // TODO: add leftpad, seqlen_new for kv cache support
  77. if (seq_used) {
  78. actual_seq_len = seq_used[bidb];
  79. }
  80. }
  81. CUTLASS_DEVICE void init_no_guard(int bidb) {
  82. actual_seq_len = seq_used[bidb];
  83. }
  84. // Returns the layout of a tensor in MKHB format in global memory.
  85. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  86. CUTLASS_HOST_DEVICE auto get_gmem_layout(
  87. int m, int k, int h, int b,
  88. int64_t m_stride, int64_t h_stride, int64_t b_stride,
  89. bool padded = false) const {
  90. static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen.");
  91. // static_assert(!UseGQAPacking, "Specialize default implementation for UseGQAPacking.");
  92. return make_layout(make_shape(m, k, h, b),
  93. make_stride(m_stride, cute::_1{}, h_stride, b_stride));
  94. }
  95. // Returns the layout of a tensor in MKHB format in global memory.
  96. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  97. // Overload that separates h into h_k and h/h_k.
  98. CUTLASS_HOST_DEVICE auto get_gmem_layout(
  99. int m, int k, int h_k, int b, int h_h_k_ratio,
  100. int64_t m_stride, int64_t h_stride, int64_t b_stride,
  101. bool padded = false) const {
  102. static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen.");
  103. static_assert(!UseGQAPacking, "Specialize default implementation for UseGQAPacking.");
  104. return make_layout(make_shape(m, k, h_k * h_h_k_ratio, b),
  105. make_stride(m_stride, cute::_1{}, h_stride, b_stride));
  106. }
  107. // Returns the layout of a tensor in MKHBT format in global memory,
  108. // where T is number of splits.
  109. CUTLASS_HOST_DEVICE auto get_oaccum_gmem_layout(
  110. int m, int k, int h, int b, int num_splits,
  111. int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride,
  112. bool padded = false) const {
  113. return make_layout(make_shape(m, k, h, b, num_splits),
  114. make_stride(m_stride, cute::_1{}, h_stride, b_stride, split_stride));
  115. }
  116. // Returns the layout of a tensor in MKHBT format in global memory,
  117. // where T is number of splits.
  118. // Overload that separates h into h_k and h/h_k.
  119. CUTLASS_HOST_DEVICE auto get_oaccum_gmem_layout(
  120. int m, int k, int h_k, int b, int h_h_k_ratio, int num_splits,
  121. int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride,
  122. bool padded = false) const {
  123. return make_layout(make_shape(m, k, h_k * h_h_k_ratio, b, num_splits),
  124. make_stride(m_stride, cute::_1{}, h_stride, b_stride, split_stride));
  125. }
  126. // Returns the layout of lse tensor in BHM format in global memory.
  127. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  128. CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(
  129. int m, int h, int b, bool padded = false) const {
  130. static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen.");
  131. return make_layout(make_shape(b, h, m),
  132. make_stride(int64_t(h * m), int64_t(m), cute::_1()));
  133. }
  134. // Returns the layout of lse tensor in TBHM format in global memory,
  135. // where T is number of splits.
  136. CUTLASS_HOST_DEVICE auto get_lseaccum_gmem_layout(
  137. int m, int h, int b, int num_splits, bool padded = false) const {
  138. return make_layout(make_shape(num_splits, b, h, m),
  139. make_stride(int64_t(b * h * m), int64_t(h * m), int64_t(m), cute::_1()));
  140. }
  141. template <typename MTensor, typename Shape>
  142. CUTLASS_DEVICE auto get_local_tile_tensor(
  143. const MTensor &m_tensor, const Shape &tile_shape,
  144. int bidh, int bidb, bool padded = false) const {
  145. auto g_tensor = local_tile(
  146. m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{}));
  147. return g_tensor;
  148. }
  149. template <bool Is_split, typename MTensor, typename Shape>
  150. CUTLASS_DEVICE auto get_lse_local_tile_tensor(
  151. const MTensor &m_tensor, const Shape &tile_shape,
  152. int bidh, int bidb, int n_split_idx, bool padded = false) const {
  153. // m_tensor has shape (B, H, M) or (splits, B, H, M)
  154. // Expect tile shape (bM)
  155. // Returns g_tensor of shape = (bM, ceil_div(M,bM))
  156. if constexpr(!Is_split) {
  157. auto g_tensor = local_tile(m_tensor(bidb, bidh, _), tile_shape, make_coord(_));
  158. return g_tensor;
  159. } else {
  160. auto g_tensor = local_tile(m_tensor(n_split_idx, bidb, bidh, _), tile_shape, make_coord(_));
  161. return g_tensor;
  162. }
  163. }
  164. template <bool Is_split, typename MTensor, typename Shape>
  165. CUTLASS_DEVICE auto get_o_local_tile_tensor(
  166. const MTensor &m_tensor, const Shape &tile_shape,
  167. int bidh, int bidb, int split_idx, bool padded = false) const {
  168. // static_assert(!UseVarSeqLen, "Don't use get_o_local_tile_tensor with VarSeqLen.");
  169. // m_tensor has shape (M, K, H, B) or (M, K, H, B, splits)
  170. // Expect tile shape (bM, K)
  171. // Returns g_tensor of shape = (bM, K, ceil_div(M,bM))
  172. if constexpr(!Is_split) {
  173. auto g_tensor = local_tile(
  174. m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{}));
  175. return g_tensor;
  176. } else {
  177. auto g_tensor = local_tile(
  178. m_tensor(_, _, bidh, bidb, split_idx), tile_shape, make_coord(_, _0{}));
  179. return g_tensor;
  180. }
  181. }
  182. };
  183. using FixedSeqLenTraits = SeqLenTraits<false, false>;
  184. using VarSeqLenTraits = SeqLenTraits<true, false>;
  185. using FixedGQASeqLenTraits = SeqLenTraits<false, true>;
  186. template <>
  187. CUTLASS_DEVICE void VarSeqLenTraits::init(int bidb) {
  188. actual_seq_len =
  189. seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]);
  190. }
  191. template <>
  192. CUTLASS_DEVICE void VarSeqLenTraits::init_no_guard(int bidb) {
  193. actual_seq_len =
  194. seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]);
  195. }
  196. template <>
  197. CUTLASS_DEVICE void FixedGQASeqLenTraits::init(int bidb) {
  198. // no op
  199. }
  200. // Returns the static layout of a var-seq-len tensor in global memory based on
  201. // max_seq_len and max_batch_size.
  202. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  203. // When padded is True, use B_M + kMaxTileSize * B as the total B_M.
  204. template <>
  205. CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout(
  206. int m, int k, int h, int b,
  207. int64_t m_stride, int64_t h_stride, int64_t b_stride,
  208. bool padded) const {
  209. return make_layout(
  210. make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h),
  211. make_stride(m_stride, cute::_1{}, h_stride));
  212. }
  213. template <>
  214. CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout(
  215. int m, int k, int h_k, int b, int h_h_k_ratio,
  216. int64_t m_stride, int64_t h_stride, int64_t b_stride,
  217. bool padded) const {
  218. return make_layout(
  219. make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h_k * h_h_k_ratio),
  220. make_stride(m_stride, cute::_1{}, h_stride));
  221. }
  222. // padded: only useful for var-seq-len for dq_accum and softmax_d.
  223. // When padded is True, use B_M + kMaxTileSize * B as the total B_M.
  224. template <>
  225. CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_lse_gmem_layout(
  226. int m, int h, int b, bool padded) const {
  227. return make_layout(
  228. make_shape(h, sum_s + (padded ? kMaxTileSize * b : 0)),
  229. make_stride(int64_t(sum_s + (padded ? kMaxTileSize * b : 0)), cute::_1()));
  230. }
  231. template <>
  232. template <typename MTensor, typename Shape>
  233. CUTLASS_DEVICE auto VarSeqLenTraits::get_local_tile_tensor(
  234. const MTensor &m_tensor, const Shape &tile_shape,
  235. int bidh, int bidb, bool padded) const {
  236. auto g_offset = local_tile(
  237. m_tensor(_, _, bidh),
  238. cute::make_shape(1, get<1>(tile_shape)),
  239. make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{}));
  240. auto g_sequence = make_tensor(
  241. g_offset.data(),
  242. make_layout(
  243. cute::make_shape(actual_seq_len, get<1>(tile_shape)),
  244. g_offset.stride()
  245. ));
  246. auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
  247. return g_tensor;
  248. }
  249. // TODO: restructure to not duplicate code
  250. template <>
  251. template <bool Is_split, typename MTensor, typename Shape>
  252. CUTLASS_DEVICE auto VarSeqLenTraits::get_o_local_tile_tensor(
  253. const MTensor &m_tensor, const Shape &tile_shape,
  254. int bidh, int bidb, int n_split_idx, bool padded) const {
  255. static_assert(!Is_split, "Don't currently support split kv kernel with VarSeqLenTraits");
  256. auto g_offset = local_tile(
  257. m_tensor(_, _, bidh),
  258. cute::make_shape(1, get<1>(tile_shape)),
  259. make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{}));
  260. auto g_sequence = make_tensor(
  261. g_offset.data(),
  262. make_layout(
  263. cute::make_shape(actual_seq_len, get<1>(tile_shape)),
  264. g_offset.stride()
  265. ));
  266. auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
  267. return g_tensor;
  268. }
  269. template <>
  270. template <bool Is_split, typename MTensor, typename Shape>
  271. CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor(
  272. const MTensor &m_tensor, const Shape &tile_shape,
  273. int bidh, int bidb, int n_split_idx, bool padded) const {
  274. static_assert(!Is_split, "Don't currently support split kv kernel with VarSeqLenTraits");
  275. auto g_offset = local_tile(
  276. m_tensor(bidh, _), cute::make_shape(_1{}),
  277. make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0)));
  278. auto g_sequence = make_tensor(
  279. g_offset.data(),
  280. make_layout(cute::make_shape(actual_seq_len), cute::make_shape(_1{})));
  281. auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
  282. return g_tensor;
  283. }
  284. // Returns layout of QO tensor in (M,H/HK,K,HK,B) format in global memory.
  285. template <>
  286. CUTLASS_HOST_DEVICE auto FixedGQASeqLenTraits::get_gmem_layout(
  287. int m, int k, int h_k, int b, int h_h_k_ratio,
  288. int64_t m_stride, int64_t h_stride, int64_t b_stride, bool padded) const {
  289. return make_layout(make_shape(m, h_h_k_ratio, k, h_k, b),
  290. make_stride(m_stride, h_stride, cute::_1{},
  291. h_stride * h_h_k_ratio, b_stride));
  292. }
  293. // Returns layout of Oaccum tensor in (M,H/HK,K,HK,B,T) format in global memory.
  294. template <>
  295. CUTLASS_HOST_DEVICE auto FixedGQASeqLenTraits::get_oaccum_gmem_layout(
  296. int m, int k, int h_k, int b, int h_h_k_ratio, int num_splits,
  297. int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride,
  298. bool padded) const {
  299. return make_layout(make_shape(m, h_h_k_ratio, k, h_k, b, num_splits),
  300. make_stride(m_stride, h_stride, cute::_1{},
  301. h_stride * h_h_k_ratio, b_stride,
  302. split_stride));
  303. }
  304. template <>
  305. template <typename MTensor, typename Shape>
  306. CUTLASS_DEVICE auto FixedGQASeqLenTraits::get_local_tile_tensor(
  307. const MTensor &m_tensor, const Shape &tile_shape,
  308. int bidh_kv, int bidb, bool padded) const {
  309. // m_tensor has shape (M, H/H_K, K, H_K, B)
  310. // Expect tile_shape (bM/bH, bH, K)
  311. // Returns g_tensor of shape (bM/bH, bH, K, ceil_div(M,bM/bH), ceil_div(H/H_K,bH))
  312. auto g_tensor = local_tile(
  313. m_tensor(_, _, _, bidh_kv, bidb), tile_shape, make_coord(_, _, _0{}));
  314. return g_tensor;
  315. }
  316. template <>
  317. template <bool Is_split, typename MTensor, typename Shape>
  318. CUTLASS_DEVICE auto FixedGQASeqLenTraits::get_o_local_tile_tensor(
  319. const MTensor &m_tensor, const Shape &tile_shape,
  320. int bidh_kv, int bidb, int split_idx, bool padded) const {
  321. // m_tensor has shape (M, H/H_K, K, H_K, B) or (M, H/H_K, K, H_K, B, splits)
  322. // Expect tile_shape (bM/bH, bH, K)
  323. // Returns g_tensor of shape (bM/bH, bH, K, ceil_div(M,bM/bH), ceil_div(H/H_K,bH))
  324. if constexpr(!Is_split) {
  325. auto g_tensor = local_tile(
  326. m_tensor(_, _, _, bidh_kv, bidb), tile_shape, make_coord(_, _, _0{}));
  327. return g_tensor;
  328. } else {
  329. auto g_tensor = local_tile(
  330. m_tensor(_, _, _, bidh_kv, bidb, split_idx), tile_shape, make_coord(_, _, _0{}));
  331. return g_tensor;
  332. }
  333. }
  334. ////////////////////////////////////////////////////////////////////////////////////////////////////
  335. } // namespace flash