/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <array>
#include <algorithm>
#include <type_traits>

#include <cutlass/cutlass.h>
#include <cute/layout.hpp>

namespace flash {

using namespace cute;

static constexpr int kMaxTileSize = 128;
template <bool UseVarSeqLen_, bool UsePagedKV_, bool UseGQAPacking_> class SeqLenTraits {
public:
  static_assert(!UsePagedKV_ || UseVarSeqLen_,
                "PagedKV is only supported for VarSeqLen.");
  static_assert(!(UseVarSeqLen_ && UseGQAPacking_),
                "Variable sequence length with GQA parallelization not implemented yet.");

  // Total number of queries / keys, unpadded.
  int sum_s = 0;
  // Cumulative sequence length offsets (prefix sums), one entry per batch plus one.
  int *cu_seq_len = nullptr;
  // Actual sequence length per batch, if provided.
  int *seq_used = nullptr;
  // Sequence length of the current batch.
  int actual_seq_len = -1;

  // Whether this is for fixed-seq-len or var-seq-len.
  static constexpr bool UseVarSeqLen = UseVarSeqLen_;
  static constexpr bool UseGQAPacking = UseGQAPacking_;
  static constexpr bool UsePagedKV = UsePagedKV_;

  using ShapeT = std::conditional_t<
      UseVarSeqLen,
      std::conditional_t<
          !UsePagedKV,
          cute::Shape<int32_t, int32_t, int32_t>,
          cute::Shape<int32_t, int32_t, int32_t, int32_t>>,
      std::conditional_t<
          UseGQAPacking,
          cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>,
          cute::Shape<int32_t, int32_t, int32_t, int32_t>
      >
  >;
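  // Concretely (see the get_gmem_layout definitions and specializations below),
  // the modes of ShapeT are:
  //   fixed seq-len:         (seqlen, headdim, nheads, batch)
  //   var seq-len (unpaged): (total_tokens, headdim, nheads)
  //   var seq-len (paged):   (page_block_size, headdim, nheads, num_blocks)
  //   GQA-packed:            (seqlen, nheads/nheads_kv, headdim, nheads_kv, batch)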
  using VirtualShapeT = std::conditional_t<
      UsePagedKV,
      cute::Shape<int32_t, int32_t, int32_t, int32_t>,
      ShapeT
  >;

  using StrideT = std::conditional_t<
      UseVarSeqLen,
      std::conditional_t<
          !UsePagedKV,
          cute::Stride<int64_t, _1, int64_t>,
          cute::Stride<int64_t, _1, int64_t, int64_t>>,
      std::conditional_t<
          UseGQAPacking,
          cute::Stride<int64_t, int64_t, _1, int64_t, int64_t>,
          cute::Stride<int64_t, _1, int64_t, int64_t>
      >
  >;
  using LayoutT = cute::Layout<ShapeT, StrideT>;

  using ShapeLseT = std::conditional_t<
      UseVarSeqLen,
      cute::Shape<int32_t, int32_t>,
      cute::Shape<int32_t, int32_t, int32_t>
  >;
  using StrideLseT = std::conditional_t<
      UseVarSeqLen,
      cute::Stride<int64_t, _1>,
      cute::Stride<int64_t, int64_t, _1>
  >;
  using LayoutLseT = cute::Layout<ShapeLseT, StrideLseT>;

  // Not used for var-seq-len.
  using ShapeOAccumT = std::conditional_t<
      UseGQAPacking,
      cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t>,
      cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>
  >;
  using StrideOAccumT = std::conditional_t<
      UseGQAPacking,
      cute::Stride<int64_t, int64_t, _1, int64_t, int64_t, int64_t>,
      cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>
  >;
  using LayoutOAccumT = cute::Layout<ShapeOAccumT, StrideOAccumT>;

  using ShapeLseAccumT = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
  using StrideLseAccumT = cute::Stride<int64_t, int64_t, int64_t, _1>;
  using LayoutLseAccumT = cute::Layout<ShapeLseAccumT, StrideLseAccumT>;

  CUTLASS_HOST SeqLenTraits() {}

  CUTLASS_HOST SeqLenTraits(
      int sum_s, int max_seq_len, int *cu_seq_len = nullptr, int *seq_used = nullptr):
      sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used), actual_seq_len(max_seq_len) {}

  CUTLASS_DEVICE void init(int bidb) {
    // TODO: add leftpad, seqlen_new for kv cache support
    if (seq_used) {
      actual_seq_len = seq_used[bidb];
    }
  }
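  // Example for init() (hypothetical values): with seq_used = {500, 1000, 250, 750},
  // init(2) sets actual_seq_len = 250; when seq_used == nullptr, actual_seq_len
  // keeps the max_seq_len value set at construction.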
  CUTLASS_DEVICE void init_no_guard(int bidb) {
    actual_seq_len = seq_used[bidb];
  }

  // Returns the layout of a tensor in MKHB format in global memory.
  // padded: only useful for var-seq-len for dq_accum and softmax_d.
  CUTLASS_HOST_DEVICE auto get_gmem_layout(
      int m, int k, int h, int b,
      int64_t m_stride, int64_t h_stride, int64_t b_stride,
      int page_block_size, int num_blocks,
      bool padded = false) const {
    static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen.");
    // static_assert(!UseGQAPacking, "Specialize default implementation for UseGQAPacking.");
    return make_layout(make_shape(m, k, h, b),
                       make_stride(m_stride, cute::_1{}, h_stride, b_stride));
  }
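  // Illustrative call of the layout above (hypothetical sizes; fixed seq-len,
  // contiguous (batch, seqlen, nheads, headdim) storage, so m_stride = h*k,
  // h_stride = k, b_stride = m*h*k):
  //   auto layout = traits.get_gmem_layout(
  //       /*m=*/1024, /*k=*/64, /*h=*/16, /*b=*/2,
  //       /*m_stride=*/16 * 64, /*h_stride=*/64, /*b_stride=*/1024 * 16 * 64,
  //       /*page_block_size=*/0, /*num_blocks=*/0);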
  // Returns the shape of a tensor in MKHB format in the virtual memory space,
  // which is mapped to global memory via the block table when paged attention is used.
  CUTLASS_HOST_DEVICE VirtualShapeT get_virtual_shape(
      int m, int k, int h_k, int b, int h_h_k_ratio, bool padded) const {
    return make_shape(m, k, h_k, b);
  }

  // Returns the layout of a tensor in MKHB format in global memory.
  // padded: only useful for var-seq-len for dq_accum and softmax_d.
  // Overload that separates h into h_k and h/h_k.
  CUTLASS_HOST_DEVICE auto get_gmem_layout(
      int m, int k, int h_k, int b, int h_h_k_ratio,
      int64_t m_stride, int64_t h_stride, int64_t b_stride,
      bool padded = false) const {
    static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen.");
    static_assert(!UseGQAPacking, "Specialize default implementation for UseGQAPacking.");
    return make_layout(make_shape(m, k, h_k * h_h_k_ratio, b),
                       make_stride(m_stride, cute::_1{}, h_stride, b_stride));
  }

  // Returns the layout of a tensor in MKHBT format in global memory,
  // where T is the number of splits.
  CUTLASS_HOST_DEVICE auto get_oaccum_gmem_layout(
      int m, int k, int h, int b, int num_splits,
      int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride,
      bool padded = false) const {
    return make_layout(make_shape(m, k, h, b, num_splits),
                       make_stride(m_stride, cute::_1{}, h_stride, b_stride, split_stride));
  }

  // Returns the layout of a tensor in MKHBT format in global memory,
  // where T is the number of splits.
  // Overload that separates h into h_k and h/h_k.
  CUTLASS_HOST_DEVICE auto get_oaccum_gmem_layout(
      int m, int k, int h_k, int b, int h_h_k_ratio, int num_splits,
      int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride,
      bool padded = false) const {
    return make_layout(make_shape(m, k, h_k * h_h_k_ratio, b, num_splits),
                       make_stride(m_stride, cute::_1{}, h_stride, b_stride, split_stride));
  }
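  // The extra T mode holds per-split partial results: in split-KV attention each
  // split writes its own O accumulator (and LSE accumulator below), which a
  // subsequent combine step reduces over T.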
  // Returns the layout of the LSE tensor in BHM format in global memory.
  // padded: only useful for var-seq-len for dq_accum and softmax_d.
  CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(
      int m, int h, int b, bool padded = false) const {
    static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen.");
    return make_layout(make_shape(b, h, m),
                       make_stride(int64_t(h * m), int64_t(m), cute::_1()));
  }

  // Returns the layout of the LSE tensor in TBHM format in global memory,
  // where T is the number of splits.
  CUTLASS_HOST_DEVICE auto get_lseaccum_gmem_layout(
      int m, int h, int b, int num_splits, bool padded = false) const {
    return make_layout(make_shape(num_splits, b, h, m),
                       make_stride(int64_t(b * h * m), int64_t(h * m), int64_t(m), cute::_1()));
  }

  template <typename MTensor, typename Shape>
  CUTLASS_DEVICE auto get_local_tile_tensor(
      const MTensor &m_tensor, const Shape &tile_shape,
      int bidh, int bidb, bool padded = false) const {
    // m_tensor has shape (M, K, H, B).
    // Expect tile shape (bM, K).
    // Returns g_tensor of shape (bM, K, ceil_div(M, bM)).
    auto g_tensor = local_tile(
        m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{}));
    return g_tensor;
  }

  template <bool Is_split, typename MTensor, typename Shape>
  CUTLASS_DEVICE auto get_lse_local_tile_tensor(
      const MTensor &m_tensor, const Shape &tile_shape,
      int bidh, int bidb, int n_split_idx, bool padded = false) const {
    // m_tensor has shape (B, H, M) or (splits, B, H, M).
    // Expect tile shape (bM).
    // Returns g_tensor of shape (bM, ceil_div(M, bM)).
    if constexpr (!Is_split) {
      auto g_tensor = local_tile(m_tensor(bidb, bidh, _), tile_shape, make_coord(_));
      return g_tensor;
    } else {
      auto g_tensor = local_tile(m_tensor(n_split_idx, bidb, bidh, _), tile_shape, make_coord(_));
      return g_tensor;
    }
  }
  template <bool Is_split, typename MTensor, typename Shape>
  CUTLASS_DEVICE auto get_o_local_tile_tensor(
      const MTensor &m_tensor, const Shape &tile_shape,
      int bidh, int bidb, int split_idx, bool padded = false) const {
    // static_assert(!UseVarSeqLen, "Don't use get_o_local_tile_tensor with VarSeqLen.");
    // m_tensor has shape (M, K, H, B) or (M, K, H, B, splits).
    // Expect tile shape (bM, K).
    // Returns g_tensor of shape (bM, K, ceil_div(M, bM)).
    if constexpr (!Is_split) {
      auto g_tensor = local_tile(
          m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{}));
      return g_tensor;
    } else {
      auto g_tensor = local_tile(
          m_tensor(_, _, bidh, bidb, split_idx), tile_shape, make_coord(_, _0{}));
      return g_tensor;
    }
  }
};

using FixedSeqLenTraits = SeqLenTraits<false, false, false>;
using VarSeqLenTraits = SeqLenTraits<true, false, false>;
using PagedSeqLenTraits = SeqLenTraits<true, true, false>;
using FixedGQASeqLenTraits = SeqLenTraits<false, false, true>;
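// Illustrative host-side construction (hypothetical variable names): for a
// var-seq-len forward pass over total_q packed query tokens,
//   VarSeqLenTraits seqlen_traits_q(total_q, max_seqlen_q, cu_seqlens_q, seqused_q);
// and for fixed-length inputs, where sum_s is simply batch * seqlen,
//   FixedSeqLenTraits seqlen_traits_q(batch_size * seqlen_q, seqlen_q);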
template <>
CUTLASS_DEVICE void VarSeqLenTraits::init(int bidb) {
  actual_seq_len =
      seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]);
}
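// Worked example (hypothetical values): cu_seq_len = {0, 512, 812, 1836} packs
// three sequences of lengths 512, 300, and 1024; init(1) computes
// actual_seq_len = cu_seq_len[2] - cu_seq_len[1] = 300, unless seq_used overrides it.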
template <>
CUTLASS_DEVICE void FixedGQASeqLenTraits::init(int bidb) {
  // No-op: actual_seq_len was already set to max_seq_len at construction.
}
// Returns the layout of a var-seq-len tensor in global memory, packed along a
// single token dimension of extent sum_s.
// padded: only useful for var-seq-len for dq_accum and softmax_d.
// When padded is true, use sum_s + kMaxTileSize * b as the total number of rows.
template <>
CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout(
    int m, int k, int h, int b,
    int64_t m_stride, int64_t h_stride, int64_t b_stride,
    int page_block_size, int num_blocks,
    bool padded) const {
  return make_layout(
      make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h),
      make_stride(m_stride, cute::_1{}, h_stride));
}

// Overload that separates h into h_k and h/h_k.
template <>
CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout(
    int m, int k, int h_k, int b, int h_h_k_ratio,
    int64_t m_stride, int64_t h_stride, int64_t b_stride,
    bool padded) const {
  return make_layout(
      make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h_k * h_h_k_ratio),
      make_stride(m_stride, cute::_1{}, h_stride));
}
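// Note that in the var-seq-len layouts above the batch dimension disappears:
// all sequences share one token dimension of extent sum_s, and per-batch access
// goes through the cu_seq_len offsets rather than a batch stride.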
template <>
CUTLASS_HOST_DEVICE VarSeqLenTraits::VirtualShapeT VarSeqLenTraits::get_virtual_shape(
    int m, int k, int h, int b, int h_h_k_ratio,
    bool padded) const {
  return make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h);
}

// padded: only useful for var-seq-len for dq_accum and softmax_d.
// When padded is true, use sum_s + kMaxTileSize * b as the total number of rows.
template <>
CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_lse_gmem_layout(
    int m, int h, int b, bool padded) const {
  return make_layout(
      make_shape(h, sum_s + (padded ? kMaxTileSize * b : 0)),
      make_stride(int64_t(sum_s + (padded ? kMaxTileSize * b : 0)), cute::_1()));
}
template <>
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto VarSeqLenTraits::get_local_tile_tensor(
    const MTensor &m_tensor, const Shape &tile_shape,
    int bidh, int bidb, bool padded) const {
  // Offset to the first packed row of this batch's sequence.
  auto g_offset = local_tile(
      m_tensor(_, _, bidh),
      cute::make_shape(1, get<1>(tile_shape)),
      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{}));
  // Reinterpret as a tensor covering exactly actual_seq_len rows.
  auto g_sequence = make_tensor(
      g_offset.data(),
      make_layout(
          cute::make_shape(actual_seq_len, get<1>(tile_shape)),
          g_offset.stride()
      ));
  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
  return g_tensor;
}
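// Illustrative walk-through (hypothetical values): with cu_seq_len = {0, 512, 812},
// bidb = 1, actual_seq_len = 300, and tile_shape (128, K), g_sequence starts at
// packed row 512 and spans 300 rows, yielding ceil_div(300, 128) = 3 row tiles.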
// TODO: restructure to not duplicate code
template <>
template <bool Is_split, typename MTensor, typename Shape>
CUTLASS_DEVICE auto VarSeqLenTraits::get_o_local_tile_tensor(
    const MTensor &m_tensor, const Shape &tile_shape,
    int bidh, int bidb, int n_split_idx, bool padded) const {
  static_assert(!Is_split, "Don't currently support split kv kernel with VarSeqLenTraits");
  auto g_offset = local_tile(
      m_tensor(_, _, bidh),
      cute::make_shape(1, get<1>(tile_shape)),
      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{}));
  auto g_sequence = make_tensor(
      g_offset.data(),
      make_layout(
          cute::make_shape(actual_seq_len, get<1>(tile_shape)),
          g_offset.stride()
      ));
  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
  return g_tensor;
}
template <>
template <bool Is_split, typename MTensor, typename Shape>
CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor(
    const MTensor &m_tensor, const Shape &tile_shape,
    int bidh, int bidb, int n_split_idx, bool padded) const {
  static_assert(!Is_split, "Don't currently support split kv kernel with VarSeqLenTraits");
  auto g_offset = local_tile(
      m_tensor(bidh, _), cute::make_shape(_1{}),
      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0)));
  auto g_sequence = make_tensor(
      g_offset.data(),
      make_layout(cute::make_shape(actual_seq_len), cute::make_stride(_1{})));
  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
  return g_tensor;
}
// Returns the layout of a QO tensor in (M, H/HK, K, HK, B) format in global memory.
template <>
CUTLASS_HOST_DEVICE auto FixedGQASeqLenTraits::get_gmem_layout(
    int m, int k, int h_k, int b, int h_h_k_ratio,
    int64_t m_stride, int64_t h_stride, int64_t b_stride, bool padded) const {
  return make_layout(make_shape(m, h_h_k_ratio, k, h_k, b),
                     make_stride(m_stride, h_stride, cute::_1{},
                                 h_stride * h_h_k_ratio, b_stride));
}

template <>
CUTLASS_HOST_DEVICE FixedGQASeqLenTraits::VirtualShapeT FixedGQASeqLenTraits::get_virtual_shape(
    int m, int k, int h_k, int b, int h_h_k_ratio,
    bool padded) const {
  return make_shape(m, h_h_k_ratio, k, h_k, b);
}
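// Illustrative example (hypothetical values): with 32 query heads and 8 KV heads,
// h_h_k_ratio = 4, so the head dimension is regrouped as (4, 8): the query heads
// that share a KV head become an adjacent mode and can be tiled together.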
// Returns the layout of an Oaccum tensor in (M, H/HK, K, HK, B, T) format in global memory.
template <>
CUTLASS_HOST_DEVICE auto FixedGQASeqLenTraits::get_oaccum_gmem_layout(
    int m, int k, int h_k, int b, int h_h_k_ratio, int num_splits,
    int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride,
    bool padded) const {
  return make_layout(make_shape(m, h_h_k_ratio, k, h_k, b, num_splits),
                     make_stride(m_stride, h_stride, cute::_1{},
                                 h_stride * h_h_k_ratio, b_stride,
                                 split_stride));
}
template <>
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto FixedGQASeqLenTraits::get_local_tile_tensor(
    const MTensor &m_tensor, const Shape &tile_shape,
    int bidh_kv, int bidb, bool padded) const {
  // m_tensor has shape (M, H/H_K, K, H_K, B).
  // Expect tile_shape (bM/bH, bH, K).
  // Returns g_tensor of shape (bM/bH, bH, K, ceil_div(M, bM/bH), ceil_div(H/H_K, bH)).
  auto g_tensor = local_tile(
      m_tensor(_, _, _, bidh_kv, bidb), tile_shape, make_coord(_, _, _0{}));
  return g_tensor;
}

template <>
template <bool Is_split, typename MTensor, typename Shape>
CUTLASS_DEVICE auto FixedGQASeqLenTraits::get_o_local_tile_tensor(
    const MTensor &m_tensor, const Shape &tile_shape,
    int bidh_kv, int bidb, int split_idx, bool padded) const {
  // m_tensor has shape (M, H/H_K, K, H_K, B) or (M, H/H_K, K, H_K, B, splits).
  // Expect tile_shape (bM/bH, bH, K).
  // Returns g_tensor of shape (bM/bH, bH, K, ceil_div(M, bM/bH), ceil_div(H/H_K, bH)).
  if constexpr (!Is_split) {
    auto g_tensor = local_tile(
        m_tensor(_, _, _, bidh_kv, bidb), tile_shape, make_coord(_, _, _0{}));
    return g_tensor;
  } else {
    auto g_tensor = local_tile(
        m_tensor(_, _, _, bidh_kv, bidb, split_idx), tile_shape, make_coord(_, _, _0{}));
    return g_tensor;
  }
}
/////////////// PagedSeqLenTraits /////////////////

// Returns the layout of the physical KV tensor in MKHB format in global memory,
// where M is the page size and B is the number of pages.
template <>
CUTLASS_HOST_DEVICE auto PagedSeqLenTraits::get_gmem_layout(
    int m, int k, int h, int b,
    int64_t m_stride, int64_t h_stride, int64_t b_stride,
    int page_block_size, int num_blocks,
    bool padded) const {
  return static_cast<PagedSeqLenTraits::LayoutT>(
      make_layout(make_shape((int)page_block_size, k, h, (int)num_blocks),
                  make_stride(m_stride, cute::_1{}, h_stride, b_stride)));
}
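// Illustrative example (hypothetical values): a KV cache with page_block_size = 256
// and num_blocks = 1024 is laid out as (256, k, h, 1024). The m and b arguments are
// unused here; a sequence's logical token positions map onto (page, offset) pairs
// through the block table.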
template <>
CUTLASS_DEVICE void PagedSeqLenTraits::init(int bidb) {
  actual_seq_len =
      seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]);
}
template <>
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto PagedSeqLenTraits::get_local_tile_tensor(
    const MTensor &m_tensor, const Shape &tile_shape,
    int bidh, int bidb, bool padded) const {
  auto g_slice = m_tensor(_, _, bidh, bidb);  // = m_tensor[:, :, head_idx, batch_idx]
  // Restrict to the first actual_seq_len rows: m_tensor[:actual_seq_len, :, head_idx, batch_idx].
  auto g_seq_slice = make_tensor(
      g_slice.data(),
      make_layout(cute::make_shape(actual_seq_len, get<1>(g_slice.layout().shape())),
                  g_slice.layout().stride()));
  // Slice up into tiles.
  auto g_tensor = local_tile(g_seq_slice, tile_shape, make_coord(_, _0{}));
  return g_tensor;
}
////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash