tile_scheduler.hpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cutlass/fast_math.h"
  6. #include "cutlass/arch/barrier.h"
  7. #include "named_barrier.hpp"
  8. namespace flash {
  9. ///////////////////////////////////////////////////////////////////////////////
  10. // Host side kernel arguments
  11. struct TileSchedulerArguments {
  12. // num_head is num_head_q if not PackGQA, else num_head_k
  13. int const num_blocks, num_head, num_batch, num_splits;
  14. int const qhead_per_khead;
  15. int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr
  16. int const seqlen_k, headdim, element_size; // Used to calculate L2 swizzling
  17. int* const tile_count_semaphore = nullptr;
  18. int* const cu_seqlens = nullptr;
  19. int* const seqused = nullptr;
  20. };
  21. ///////////////////////////////////////////////////////////////////////////////
  22. template<bool Varlen=false, bool Split=false, bool PackGQA=false, int kBlock=128>
  23. class SingleTileScheduler {
  24. public:
  25. using SharedStorage = int;
  26. // Device side kernel params
  27. struct Params {
  28. int const num_blocks, num_head, num_batch, num_splits;
  29. int const qhead_per_khead;
  30. int const seqlen;
  31. cutlass::FastDivmod nsplits_divmod;
  32. int* const cu_seqlens;
  33. int* const seqused;
  34. };
  35. static Params
  36. to_underlying_arguments(TileSchedulerArguments const& args) {
  37. return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits,
  38. args.qhead_per_khead, args.seqlen,
  39. cutlass::FastDivmod(!Split ? 1 : args.num_splits),
  40. !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused};
  41. }
  42. static dim3
  43. get_grid_shape(Params const& params, int num_sm) {
  44. return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)};
  45. }
  46. struct WorkTileInfo {
  47. int block_idx = 0;
  48. int bidh = 0;
  49. int bidb = 0;
  50. bool is_valid_tile = false;
  51. CUTLASS_DEVICE
  52. bool
  53. is_valid(Params const& params) const {
  54. return is_valid_tile;
  55. }
  56. CUTLASS_DEVICE
  57. cute::tuple<int32_t, int32_t, int32_t, int32_t>
  58. get_block_coord(Params const& params) const {
  59. if constexpr (!Split) {
  60. return {block_idx, bidh, bidb, 0 /*split_idx*/};
  61. } else {
  62. int split_idx;
  63. int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh);
  64. return {block_idx, bidh_actual, bidb, split_idx};
  65. }
  66. }
  67. };
  68. CUTLASS_DEVICE
  69. SingleTileScheduler(SharedStorage* const smem_scheduler) { }
  70. template<bool IsProducerWarp=false>
  71. CUTLASS_DEVICE
  72. WorkTileInfo
  73. get_initial_work(Params const& params) const {
  74. WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
  75. if constexpr (Varlen) {
  76. int seqlen = params.seqused
  77. ? params.seqused[work_info.bidb]
  78. : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen);
  79. if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }
  80. work_info.is_valid_tile = work_info.block_idx * kBlock < seqlen;
  81. }
  82. return work_info;
  83. }
  84. CUTLASS_DEVICE
  85. void
  86. init_consumer() const {}
  87. CUTLASS_DEVICE
  88. void
  89. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
  90. template<bool IsProducerWarp=false>
  91. CUTLASS_DEVICE
  92. WorkTileInfo
  93. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  94. return {-1, -1, -1, false};
  95. }
  96. };
  97. ///////////////////////////////////////////////////////////////////////////////
  98. template<bool Split=false>
  99. class StaticPersistentTileScheduler {
  100. public:
  101. using SharedStorage = int;
  102. // Device side kernel params
  103. struct Params {
  104. int total_blocks;
  105. cutlass::FastDivmod m_block_divmod, head_divmod;
  106. cutlass::FastDivmod nsplits_divmod;
  107. };
  108. static Params
  109. to_underlying_arguments(TileSchedulerArguments const& args) {
  110. return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits),
  111. cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)),
  112. cutlass::FastDivmod(!Split ? 1 : args.num_splits)};
  113. }
  114. static dim3
  115. get_grid_shape(Params const& params, int num_sm) {
  116. return {uint32_t(num_sm)};
  117. }
  118. struct WorkTileInfo {
  119. int tile_idx;
  120. CUTLASS_DEVICE
  121. bool
  122. is_valid(Params const& params) const {
  123. return tile_idx < params.total_blocks;
  124. }
  125. CUTLASS_DEVICE
  126. cute::tuple<int32_t, int32_t, int32_t, int32_t>
  127. get_block_coord(Params const& params) const {
  128. int block, bidh, bidb;
  129. bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx));
  130. int split_idx = 0;
  131. if constexpr (Split) {
  132. bidh = params.nsplits_divmod.divmod(split_idx, bidh);
  133. }
  134. return {block, bidh, bidb, split_idx};
  135. }
  136. };
  137. CUTLASS_DEVICE
  138. StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {};
  139. template<bool IsProducerWarp=false>
  140. CUTLASS_DEVICE
  141. WorkTileInfo
  142. get_initial_work(Params const& params) const {
  143. return {int(blockIdx.x)};
  144. }
  145. CUTLASS_DEVICE
  146. void
  147. init_consumer() const {}
  148. CUTLASS_DEVICE
  149. void
  150. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
  151. template<bool IsProducerWarp=false>
  152. CUTLASS_DEVICE
  153. WorkTileInfo
  154. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  155. return {current_work.tile_idx + int(gridDim.x)};
  156. }
  157. };
  158. template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp,
  159. bool Split=false, bool PackGQA=false, bool WarpSpecialized=true>
  160. class DynamicPersistentTileScheduler {
  161. // This scheduler targets the causal (or local) case where each tile takes different
  162. // amount of time. We use longest-processing-time-first scheduling:
  163. // the longest remaining tile is assigned to the first SM that's free.
  164. // SM indicates they are free by incrementing a semaphore.
  165. // However, we have to make sure K & V still fit into L2 cache, so we perform scheduling
  166. // on "sections" of the head & batch dimension, each section consisting of e.g. 8 heads.
  167. // This is the L2 swizzling part. The size of each section is precomputed based on the
  168. // size of K & V and the L2 cache size.
  169. static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads);
  170. static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads;
  171. public:
  172. using SharedStorage = int;
  173. protected:
  174. SharedStorage* const tile_count_smem;
  175. public:
  176. // Device side kernel params
  177. struct Params {
  178. int const total_blocks;
  179. cutlass::FastDivmod const m_block_divmod, head_divmod;
  180. cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod;
  181. cutlass::FastDivmod const l2_minor_residual_divmod;
  182. int const num_hb_quotient;
  183. int* const tile_count_semaphore;
  184. };
  185. static Params
  186. to_underlying_arguments(TileSchedulerArguments const& args) {
  187. int const size_one_kv_head = args.seqlen_k * args.headdim * args.element_size * 2;
  188. int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V
  189. // Swizzle is the size of each "section". Round swizzle to a power of 2
  190. // If not PackGQA already, the size of each section can increase by qhead_per_khead
  191. int const swizzle = (1 << cutlass::find_log2(size_l2 / size_one_kv_head)) * (PackGQA ? 1 : args.qhead_per_khead);
  192. // If we're in the last section (called residual), we don't want to divide by
  193. // swizzle. Instead we want to divide by the remainder.
  194. int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle;
  195. int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits);
  196. // printf("num_split_blocks = %d, num_head = %d, num_batch = %d, swizzle = %d, PackGQA = %d, qhead_per_khead = %d, num_hb_remainder = %d\n", num_split_blocks, args.num_head, args.num_batch, swizzle, int(PackGQA), args.qhead_per_khead, num_hb_remainder);
  197. assert(args.tile_count_semaphore != nullptr);
  198. return {num_split_blocks * args.num_head * args.num_batch,
  199. cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head),
  200. cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks),
  201. // don't divide by 0
  202. cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1),
  203. (args.num_head * args.num_batch) / swizzle,
  204. args.tile_count_semaphore};
  205. }
  206. static dim3
  207. get_grid_shape(Params const& params, int num_sm) {
  208. return {uint32_t(num_sm)};
  209. }
  210. struct WorkTileInfo {
  211. int tile_idx;
  212. CUTLASS_DEVICE
  213. bool
  214. is_valid(Params const& params) const {
  215. return tile_idx < params.total_blocks;
  216. }
  217. CUTLASS_DEVICE
  218. cute::tuple<int32_t, int32_t, int32_t, int32_t>
  219. get_block_coord(Params const& params) const {
  220. int block, bidh, bidb;
  221. int l2_mod, bidhb, bidhb_residual;
  222. bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx);
  223. // If we're in the last section (called residual), we don't want to divide by
  224. // swizzle. Instead we want to divide by the remainder.
  225. if (bidhb < params.num_hb_quotient) {
  226. block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod);
  227. } else {
  228. block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod);
  229. }
  230. bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual);
  231. int split_idx = 0;
  232. if constexpr (Split) {
  233. split_idx = params.m_block_divmod.divmod(block, block);
  234. }
  235. // Longest-processing-time-first
  236. block = params.m_block_divmod.divisor - 1 - block;
  237. return {block, bidh, bidb, split_idx};
  238. }
  239. };
  240. CUTLASS_DEVICE
  241. DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {};
  242. template<bool IsProducerWarp=false>
  243. CUTLASS_DEVICE
  244. WorkTileInfo
  245. get_initial_work(Params const& params) const {
  246. return {int(blockIdx.x)};
  247. }
  248. CUTLASS_DEVICE
  249. void
  250. init_consumer() const {
  251. if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) {
  252. flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  253. }
  254. }
  255. CUTLASS_DEVICE
  256. void
  257. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
  258. if (threadIdx.x % NumProducerThreads == 0) {
  259. current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
  260. }
  261. }
  262. template<bool IsProducerWarp=false>
  263. CUTLASS_DEVICE
  264. WorkTileInfo
  265. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  266. if constexpr (IsProducerWarp) {
  267. // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0
  268. int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);
  269. flash::named_barrier_sync(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  270. if (threadIdx.x % NumProducerThreads == 0) {
  271. *tile_count_smem = current_work.tile_idx;
  272. }
  273. flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  274. return {new_tile_idx};
  275. } else {
  276. flash::named_barrier_sync(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  277. int tile_idx = *tile_count_smem;
  278. flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  279. return {tile_idx};
  280. }
  281. }
  282. };
  283. template<int kBlock, int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp, bool Split=false, bool PackGQA=false, bool WarpSpecialized=true>
  284. class VarlenDynamicPersistentTileScheduler {
  285. static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads);
  286. static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads;
  287. public:
  288. using SharedStorage = int4;
  289. protected:
  290. SharedStorage* const work_info_smem;
  291. public:
  292. // Device side kernel params
  293. struct Params {
  294. int num_head, num_batch;
  295. int const qhead_per_khead;
  296. int const seqlen;
  297. cutlass::FastDivmod nsplits_divmod;
  298. int* const tile_count_semaphore;
  299. int* const cu_seqlens;
  300. int* const seqused;
  301. };
  302. static Params
  303. to_underlying_arguments(TileSchedulerArguments const& args) {
  304. // If Split, for the purpose of scheduling, we pretend that instead there are
  305. // (args.num_splits * args.num_head) number of heads.
  306. assert(args.tile_count_semaphore != nullptr);
  307. return {args.num_head * (!Split ? 1 : args.num_splits), args.num_batch,
  308. args.qhead_per_khead, args.seqlen,
  309. cutlass::FastDivmod(!Split ? 1 : args.num_splits),
  310. args.tile_count_semaphore, args.cu_seqlens, args.seqused};
  311. }
  312. static dim3
  313. get_grid_shape(Params const& params, int num_sm) {
  314. return {uint32_t(num_sm)};
  315. }
  316. struct WorkTileInfo {
  317. int tile_idx, block, bidh, bidb;
  318. CUTLASS_DEVICE
  319. bool
  320. is_valid(Params const& params) const {
  321. // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); }
  322. return bidb < params.num_batch;
  323. }
  324. CUTLASS_DEVICE
  325. cute::tuple<int32_t, int32_t, int32_t, int32_t>
  326. get_block_coord(Params const& params) const {
  327. if constexpr (!Split) {
  328. return {block, bidh, bidb, 0 /*split_idx*/};
  329. } else {
  330. int split_idx;
  331. int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh);
  332. return {block, bidh_actual, bidb, split_idx};
  333. }
  334. }
  335. };
  336. CUTLASS_DEVICE
  337. VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {};
  338. CUTLASS_DEVICE
  339. WorkTileInfo
  340. tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const {
  341. auto prefix_sum = [](int val) {
  342. int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
  343. CUTLASS_PRAGMA_UNROLL
  344. for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) {
  345. int32_t partial_sum = __shfl_up_sync(0xffffffff, val, i);
  346. if (lane >= i) { val += partial_sum; }
  347. }
  348. return val;
  349. };
  350. auto get_num_m_blocks = [&](int bidb_start) {
  351. int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
  352. int seqlen;
  353. if (params.seqused) {
  354. seqlen = lane + bidb_start < params.num_batch ? params.seqused[lane + bidb_start] : 0;
  355. } else if (params.cu_seqlens) {
  356. int cur_cu_seqlen = lane + bidb_start <= params.num_batch ? params.cu_seqlens[lane + bidb_start] : 0;
  357. int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
  358. seqlen = next_cu_seqlen - cur_cu_seqlen;
  359. } else {
  360. seqlen = params.seqlen;
  361. }
  362. if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }
  363. return lane + bidb_start < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
  364. ? cute::ceil_div(seqlen, kBlock) : 0;
  365. };
  366. int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane
  367. // Cumulative number of blocks for the next 31 batches
  368. int num_m_blocks_cumulative = prefix_sum(num_m_blocks);
  369. // Total number of blocks for the next 31 batches
  370. int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);
  371. int group_end_tile = current_work.tile_idx - current_work.block - current_work.bidh * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes
  372. int bidb = current_work.bidb;
  373. // if (blockIdx.x <= 9 && threadIdx.x == 0) {
  374. // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group);
  375. // }
  376. while (group_end_tile <= next_tile_idx) {
  377. bidb += cutlass::NumThreadsPerWarp - 1;
  378. if (bidb >= params.num_batch) {
  379. // if (blockIdx.x <= 9 && threadIdx.x == 0) {
  380. // printf("Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group);
  381. // }
  382. return {next_tile_idx, 0, 0, params.num_batch};
  383. }
  384. num_m_blocks = get_num_m_blocks(bidb);
  385. num_m_blocks_cumulative = prefix_sum(num_m_blocks);
  386. m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);
  387. group_end_tile += m_blocks_in_group * params.num_head;
  388. // if (blockIdx.x <= 9 && threadIdx.x == 0) {
  389. // printf("Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group);
  390. // }
  391. }
  392. int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head;
  393. // The next problem to process is the first one that does not have ending tile position
  394. // that is greater than or equal to tile index.
  395. int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx));
  396. bidb += batch_idx_in_group;
  397. num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group);
  398. int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head;
  399. int bidh = mh_block / num_m_blocks;
  400. int block = mh_block - bidh * num_m_blocks;
  401. // if (blockIdx.x <= 9 && threadIdx.x == 0) {
  402. // printf("blockIdx.x = %d, threadIdx.x = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block);
  403. // }
  404. return {next_tile_idx, block, bidh, bidb};
  405. }
  406. template<bool IsProducerWarp=false>
  407. CUTLASS_DEVICE
  408. WorkTileInfo
  409. get_initial_work(Params const& params) const {
  410. if constexpr (IsProducerWarp) {
  411. WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0});
  412. if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
  413. *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);
  414. }
  415. flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  416. return work_info;
  417. } else {
  418. return get_next_work<false>(params, {0, 0, 0, 0});
  419. }
  420. }
  421. CUTLASS_DEVICE
  422. void
  423. init_consumer() const {
  424. // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that
  425. }
  426. CUTLASS_DEVICE
  427. void
  428. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
  429. if (threadIdx.x % NumProducerThreads == 0) {
  430. current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
  431. }
  432. }
  433. template<bool IsProducerWarp=false>
  434. CUTLASS_DEVICE
  435. WorkTileInfo
  436. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  437. if constexpr (IsProducerWarp) {
  438. // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0
  439. int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);
  440. WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb};
  441. work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info);
  442. flash::named_barrier_sync(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  443. if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
  444. *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);
  445. }
  446. flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  447. return work_info;
  448. } else {
  449. flash::named_barrier_sync(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  450. int4 work_info = *work_info_smem;
  451. flash::named_barrier_arrive(NumThreads, static_cast<uint32_t>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  452. return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w};
  453. }
  454. }
  455. };
  456. } // flash