tile_scheduler.hpp

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cutlass/fast_math.h"
#include "cutlass/arch/barrier.h"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {
///////////////////////////////////////////////////////////////////////////////

// Host side kernel arguments
struct TileSchedulerArguments {
    // num_head is num_head_q if not PackGQA, else num_head_k
    int const num_blocks, num_head, num_batch, num_splits;
    int const qhead_per_khead;
    int const seqlen;  // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr
    int const seqlen_k, headdim, element_size;  // Used to calculate L2 swizzling
    int* const tile_count_semaphore = nullptr;
    int* const cu_seqlens = nullptr;
    int* const seqused = nullptr;
    // int* const num_m_blocks_ptr = nullptr;
    int* const num_splits_dynamic_ptr = nullptr;
};
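
// A rough host-side sketch (hypothetical names and values, not part of this file) of how
// these arguments might be filled for a plain non-varlen, non-split forward pass:
//
//   TileSchedulerArguments args {
//       /*num_blocks=*/cute::ceil_div(seqlen_q, kBlockM),
//       /*num_head=*/num_heads_q, /*num_batch=*/batch_size, /*num_splits=*/1,
//       /*qhead_per_khead=*/num_heads_q / num_heads_kv,
//       /*seqlen=*/seqlen_q,
//       /*seqlen_k=*/seqlen_k, /*headdim=*/head_dim, /*element_size=*/2 /* e.g. bf16 */,
//       /*tile_count_semaphore=*/semaphore_ptr
//   };
//
// kBlockM, seqlen_q, semaphore_ptr, etc. stand in for whatever the calling kernel uses.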

///////////////////////////////////////////////////////////////////////////////

template<bool Varlen=false, bool Split=false, bool PackGQA=false, int kBlock=128>
class SingleTileScheduler {

public:

    using SharedStorage = int;

    // Device side kernel params
    struct Params {
        int const num_blocks, num_head, num_batch, num_splits;
        int const qhead_per_khead;
        int const seqlen;
        cutlass::FastDivmod nsplits_divmod;
        int* const cu_seqlens;
        int* const seqused;
    };

    static Params
    to_underlying_arguments(TileSchedulerArguments const& args) {
        return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits,
                args.qhead_per_khead, args.seqlen,
                cutlass::FastDivmod(!Split ? 1 : args.num_splits),
                !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused};
    }

    static dim3
    get_grid_shape(Params const& params, int num_sm) {
        return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)};
    }

    struct WorkTileInfo {
        int block_idx = 0;
        int bidh = 0;
        int bidb = 0;
        bool is_valid_tile = false;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return is_valid_tile;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            if constexpr (!Split) {
                return {block_idx, bidh, bidb, 0 /*split_idx*/};
            } else {
                int split_idx;
                int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh);
                return {block_idx, bidh_actual, bidb, split_idx};
            }
        }
    };
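
    // When Split, blockIdx.y enumerates (head, split) pairs with the split index varying
    // fastest: e.g. with num_splits = 4, bidh = 9 decodes to bidh_actual = 9 / 4 = 2 and
    // split_idx = 9 % 4 = 1, matching the num_splits * num_head grid dimension above.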

    CUTLASS_DEVICE
    SingleTileScheduler(SharedStorage* const smem_scheduler) { }

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work(Params const& params) const {
        WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
        if constexpr (Varlen) {
            int seqlen = params.seqused
                ? params.seqused[work_info.bidb]
                : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen);
            if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }
            work_info.is_valid_tile = work_info.block_idx * kBlock < seqlen;
        }
        return work_info;
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {}

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        return {-1, -1, -1, false};
    }

};

///////////////////////////////////////////////////////////////////////////////

template<bool Split=false>
class StaticPersistentTileScheduler {

public:

    using SharedStorage = int;

    // Device side kernel params
    struct Params {
        int total_blocks;
        cutlass::FastDivmod m_block_divmod, head_divmod;
        cutlass::FastDivmod nsplits_divmod;
    };

    static Params
    to_underlying_arguments(TileSchedulerArguments const& args) {
        return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits),
                cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)),
                cutlass::FastDivmod(!Split ? 1 : args.num_splits)};
    }

    static dim3
    get_grid_shape(Params const& params, int num_sm) {
        return {uint32_t(num_sm)};
    }

    struct WorkTileInfo {
        int tile_idx;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return tile_idx < params.total_blocks;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            int block, bidh, bidb;
            bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx));
            int split_idx = 0;
            if constexpr (Split) {
                bidh = params.nsplits_divmod.divmod(split_idx, bidh);
            }
            return {block, bidh, bidb, split_idx};
        }
    };
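
    // The flat tile_idx enumerates (block, head [* split], batch) with block varying fastest.
    // For example, with num_blocks = 8, num_head = 16, and Split disabled, tile_idx = 1000
    // decodes to block = 1000 % 8 = 0, bidh = (1000 / 8) % 16 = 13, bidb = 1000 / (8 * 16) = 7.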

    CUTLASS_DEVICE
    StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {};

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work(Params const& params) const {
        return {int(blockIdx.x)};
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {}

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        return {current_work.tile_idx + int(gridDim.x)};
    }

};

template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp,
         bool Split=false, bool PackGQA=false, bool WarpSpecialized=true>
class DynamicPersistentTileScheduler {

    // This scheduler targets the causal (or local) case where each tile takes a different
    // amount of time. We use longest-processing-time-first scheduling:
    // the longest remaining tile is assigned to the first SM that's free.
    // SMs indicate that they are free by incrementing a semaphore.
    // However, we have to make sure K & V still fit into L2 cache, so we perform scheduling
    // on "sections" of the head & batch dimension, each section consisting of e.g. 8 heads.
    // This is the L2 swizzling part. The size of each section is precomputed based on the
    // size of K & V and the L2 cache size.
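    // Illustrative numbers (not asserted anywhere in this file): with seqlen_k = 8192,
    // headdim = 128, and 2-byte elements, one K+V head occupies 8192 * 128 * 2 * 2 = 4 MB,
    // so with the 32 MB L2 budget below each section covers 32 / 4 = 8 heads (further
    // multiplied by qhead_per_khead when not PackGQA). The scheduler sweeps all m-blocks of
    // those heads before moving to the next section, keeping their K & V resident in L2.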
    static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads);
    static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads;

public:
    using SharedStorage = int;

protected:
    SharedStorage* const tile_count_smem;

public:

    // Device side kernel params
    struct Params {
        int const total_blocks;
        cutlass::FastDivmod const m_block_divmod, head_divmod;
        cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod;
        cutlass::FastDivmod const l2_minor_residual_divmod;
        int const num_hb_quotient;
        int* const tile_count_semaphore;
    };

    static Params
    to_underlying_arguments(TileSchedulerArguments const& args) {
        int const size_one_kv_head = args.seqlen_k * args.headdim * args.element_size * 2;
        int const size_l2 = 32 * 1024 * 1024;  // 32 MB for K & V
        // Swizzle is the size of each "section". Round swizzle to a power of 2
        // If not PackGQA already, the size of each section can increase by qhead_per_khead
        // Need to be careful about the case where only one head will fit
        int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << cutlass::find_log2(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead);
        // If we're in the last section (called residual), we don't want to divide by
        // swizzle. Instead we want to divide by the remainder.
        int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle;
        int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits);
        // printf("num_split_blocks = %d, num_head = %d, num_batch = %d, swizzle = %d, PackGQA = %d, qhead_per_khead = %d, num_hb_remainder = %d\n", num_split_blocks, args.num_head, args.num_batch, swizzle, int(PackGQA), args.qhead_per_khead, num_hb_remainder);
        assert(args.tile_count_semaphore != nullptr);
        return {num_split_blocks * args.num_head * args.num_batch,
                cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head),
                cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks),
                // don't divide by 0
                cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1),
                (args.num_head * args.num_batch) / swizzle,
                args.tile_count_semaphore};
    }

    static dim3
    get_grid_shape(Params const& params, int num_sm) {
        return {uint32_t(num_sm)};
    }

    struct WorkTileInfo {
        int tile_idx;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return tile_idx < params.total_blocks;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            int block, bidh, bidb;
            int l2_mod, bidhb, bidhb_residual;
            bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx);
            // If we're in the last section (called residual), we don't want to divide by
            // swizzle. Instead we want to divide by the remainder.
            if (bidhb < params.num_hb_quotient) {
                block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod);
            } else {
                block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod);
            }
            bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual);
            int split_idx = 0;
            if constexpr (Split) {
                split_idx = params.m_block_divmod.divmod(block, block);
            }
            // Longest-processing-time-first
            block = params.m_block_divmod.divisor - 1 - block;
            return {block, bidh, bidb, split_idx};
        }
    };
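
    // Decomposition sketch: tile_idx first splits into (section index bidhb, offset l2_mod
    // within the section); the offset then splits into (block, head-within-section), so
    // consecutive tiles cycle through the section's heads before advancing the m-block;
    // finally bidhb * swizzle + bidhb_residual is split by num_head into (bidb, bidh).
    // The closing "divisor - 1 - block" reverses the m-block order so the longest (causal)
    // tiles are handed out first.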

    CUTLASS_DEVICE
    DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {};

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work(Params const& params) const {
        return {int(blockIdx.x)};
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {
        if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) {
            flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
        }
    }

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
        if (threadIdx.x % NumProducerThreads == 0) {
            current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
        }
    }
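
    // get_next_work implements a producer/consumer hand-off through shared memory:
    // the producer warp waits on TileCountSmemEmpty (consumers are done reading the previous
    // tile index), writes the tile index fetched by prefetch_next_work into tile_count_smem,
    // then signals TileCountSmemFull; consumer warps wait on TileCountSmemFull, read the
    // index, and signal TileCountSmemEmpty again.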
    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        if constexpr (IsProducerWarp) {
            // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0
            int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);
            flash::named_barrier_sync(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
            if (threadIdx.x % NumProducerThreads == 0) {
                *tile_count_smem = current_work.tile_idx;
            }
            flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
            return {new_tile_idx};
        } else {
            flash::named_barrier_sync(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
            int tile_idx = *tile_count_smem;
            flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
            return {tile_idx};
        }
    }

};

template<int kBlock, int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp, bool Split=false, bool PackGQA=false, bool WarpSpecialized=true>
class VarlenDynamicPersistentTileScheduler {

    static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads);
    static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads;

public:
    using SharedStorage = int4;

protected:
    SharedStorage* const work_info_smem;

public:

    // Device side kernel params
    struct Params {
        int num_head, num_batch;
        int const qhead_per_khead;
        int const seqlen;
        cutlass::FastDivmod head_divmod;
        cutlass::FastDivmod nsplits_divmod;
        int* const tile_count_semaphore;
        int* const cu_seqlens;
        int* const seqused;
        // int* const num_m_blocks_ptr;
        int* const num_splits_dynamic_ptr;
    };

    static Params
    to_underlying_arguments(TileSchedulerArguments const& args) {
        // If Split, for the purpose of scheduling, we pretend that there are
        // (args.num_splits * args.num_head) heads instead.
        assert(args.tile_count_semaphore != nullptr);
        assert(!Split || args.num_splits_dynamic_ptr != nullptr);
        assert(args.num_head < (1 << 16));  // We use the top 16 bits to store num_splits & split_idx
        assert(!Split || args.num_splits < (1 << 8));  // We use the top 8 bits to store num_splits
        return {args.num_head, args.num_batch,
                args.qhead_per_khead, args.seqlen,
                cutlass::FastDivmod(args.num_head),
                cutlass::FastDivmod(!Split ? 1 : args.num_splits),
                args.tile_count_semaphore, args.cu_seqlens, args.seqused,
                // args.num_m_blocks_ptr, args.num_splits_dynamic_ptr};
                args.num_splits_dynamic_ptr};
    }

    static dim3
    get_grid_shape(Params const& params, int num_sm) {
        return {uint32_t(num_sm)};
    }

    struct WorkTileInfo {
        int tile_idx, block, bidh, bidb;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); }
            return bidb < params.num_batch;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            if constexpr (!Split) {
                return {block, bidh, bidb, 0 /*split_idx*/};
            } else {
                // The top 8 bits of bidh store num_splits and the next 8 bits store split_idx
                // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift
                uint32_t bidh_packed = reinterpret_cast<uint32_t const&>(bidh);
                uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF;
                int bidh_actual = reinterpret_cast<int&>(bidh_actual_u);
                // In the unpacked split_idx, the top 16 bits store num_splits and the bottom 16 bits store split_idx
                uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8);
                int split_idx = reinterpret_cast<int&>(split_idx_u);
                // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh);
                // if (threadIdx.x == 128) {
                //     printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx);
                // }
                return {block, bidh_actual, bidb, split_idx};
            }
        }
    };
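
    // Packed bidh layout when Split (produced in tile_idx_to_work_tile below):
    //   bits  0-15: actual head index (bidh_actual)
    //   bits 16-23: split_idx
    //   bits 24-31: num_splits for this batch
    // get_block_coord above unpacks this into (bidh_actual, split_idx), where the returned
    // split_idx itself carries num_splits in its upper 16 bits.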

    CUTLASS_DEVICE
    VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {};
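
    // Warp-cooperative mapping from a flat tile index to (block, bidh, bidb): each of the
    // first 31 lanes computes the number of m-blocks (times num_splits when Split) for one
    // batch, a warp prefix sum gives cumulative tile counts for that group of batches, and
    // the while loop advances the group of 31 batches until it contains next_tile_idx.
    // A ballot then picks the batch, and the remaining offset is split into (bidh, block).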
    CUTLASS_DEVICE
    WorkTileInfo
    tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const {
        int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
        auto get_num_m_blocks = [&] (int bidb_start) {
            int batch_idx = lane + bidb_start;
            int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead);
            if (seqlen > kBlock) {
                if (params.seqused) {
                    seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0;
                } else if (params.cu_seqlens) {
                    int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0;
                    int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
                    seqlen = next_cu_seqlen - cur_cu_seqlen;
                } else {
                    seqlen = params.seqlen;
                }
                if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }
            }
            return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
                ? cute::ceil_div(seqlen, kBlock) : 0;
                // ? params.num_m_blocks_ptr[batch_idx] : 0;
        };

        auto get_num_splits = [&] (int bidb_start) {
            int batch_idx = lane + bidb_start;
            return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
                ? (!Split ? 1 : params.num_splits_dynamic_ptr[batch_idx])
                : 0;
        };

        int num_m_blocks = get_num_m_blocks(current_work.bidb);  // Different for each lane
        int num_splits = get_num_splits(current_work.bidb);
        int num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits;
        // Cumulative number of blocks for the next 31 batches
        int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks);
        // Total number of blocks for the next 31 batches
        int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);
        // Only the lower 16 bits are the actual bidh
        int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF);
        int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head;  // Same for all lanes
        if constexpr (Split) {
            int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16;
            group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/);
        }
        int bidb = current_work.bidb;
        // if (blockIdx.x <= 9 && threadIdx.x == 0) {
        //     printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group);
        // }
        // if (threadIdx.x == 0 && blockIdx.x == 0) { printf("tile_idx = %d, group_end_tile = %d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\n", current_work.tile_idx, group_end_tile, num_m_blocks_cumulative, m_blocks_in_group); }
        while (group_end_tile <= next_tile_idx) {
            bidb += cutlass::NumThreadsPerWarp - 1;
            if (bidb >= params.num_batch) {
                // if (blockIdx.x <= 9 && threadIdx.x == 0) {
                //     printf("Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group);
                // }
                return {next_tile_idx, 0, 0, params.num_batch};
            }
            num_m_blocks = get_num_m_blocks(bidb);
            num_splits = get_num_splits(bidb);
            num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits;
            num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks);
            m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);
            group_end_tile += m_blocks_in_group * params.num_head;
            // if (blockIdx.x <= 9 && threadIdx.x == 0) {
            //     printf("Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group);
            // }
        }
        int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head;
        // The next problem to process is the first one whose ending tile position is greater
        // than the tile index, i.e. we count how many problems end at or before that index.
        int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx));
        // if (threadIdx.x == 31 || threadIdx.x == 0) { printf("blockIdx.x = %d, tidx %d, group_start_tile = %d, num_m_blocks_cumulative = %d, num_head = %d, next_tile_idx = %d, ballot = %x, batch_idx_in_group = %d\n", blockIdx.x, threadIdx.x, group_start_tile, num_m_blocks_cumulative, params.num_head, next_tile_idx, tmp, batch_idx_in_group); }
        bidb += batch_idx_in_group;
        num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group);
        if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); }
        int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head;
        int bidh = mh_block / num_m_blocks;
        int block = mh_block - bidh * num_m_blocks;
        if constexpr (Split) {
            int bidh_actual = bidh / num_splits;
            int split_idx = bidh - bidh_actual * num_splits;
            // TODO: idk why this gives wrong answer nondeterministically
            // int bidh_actual, split_idx;
            // split_idx = params.head_divmod.divmod(bidh_actual, bidh);
            // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx
            // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift
            uint32_t bidh_packed = reinterpret_cast<uint32_t&>(bidh_actual) + (reinterpret_cast<uint32_t&>(split_idx) << 16) + (reinterpret_cast<uint32_t&>(num_splits) << 24);
            // if (threadIdx.x == 0) {
            //     printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed);
            // }
            bidh = reinterpret_cast<int&>(bidh_packed);
        }
        // if (blockIdx.x <= 9 && threadIdx.x == 0) {
        //     printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block);
        // }
        return {next_tile_idx, block, bidh, bidb};
    }
    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work(Params const& params) const {
        if constexpr (IsProducerWarp) {
            WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0});
            if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
                *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);
            }
            flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
            return work_info;
        } else {
            return get_next_work<false>(params, {0, 0, 0, 0});
        }
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {
        // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that
    }

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
        if (threadIdx.x % NumProducerThreads == 0) {
            current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
        }
    }

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        if constexpr (IsProducerWarp) {
            // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0
            int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);
            WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb};
            work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info);
            flash::named_barrier_sync(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
            if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
                *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);
            }
            flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
            return work_info;
        } else {
            flash::named_barrier_sync(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
            int4 work_info = *work_info_smem;
            flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
            return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w};
        }
    }

};

} // namespace flash