1
0

tile_scheduler.hpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
#pragma once
#include <cstdint>
#include "cutlass/fast_math.h"
#include "cutlass/arch/barrier.h"
#include "named_barrier.hpp"
  8. namespace flash {
  9. ///////////////////////////////////////////////////////////////////////////////
  10. // Host side kernel arguments
  // Host-side description of the attention problem. Every scheduler below converts it
  // into its own device-side Params via to_underlying_arguments().
  struct TileSchedulerArguments {
    // Number of m-blocks (query tiles) per (head, batch).
    int const num_blocks;
    // num_head_q if not PackGQA, else num_head_k.
    int const num_head;
    int const num_batch;
    int const num_splits;
    int const qhead_per_khead;
    // Only used if Varlen and both cu_seqlens and seqused are nullptr.
    int const seqlen;
    // The next three are only used to size the L2 swizzle sections.
    int const seqlen_k;
    int const headdim;
    int const element_size;
    // Global counter for dynamic (atomic) tile assignment; nullptr when unused.
    int* const tile_count_semaphore = nullptr;
    // Variable-length batch descriptors; nullptr when unused.
    int* const cu_seqlens = nullptr;
    int* const seqused = nullptr;
  };
  21. ///////////////////////////////////////////////////////////////////////////////
  22. template<bool Varlen=false, bool Split=false, bool PackGQA=false, int kBlock=128>
  23. class SingleTileScheduler {
  24. public:
  25. using SharedStorage = int;
  26. // Device side kernel params
  27. struct Params {
  28. int const num_blocks, num_head, num_batch, num_splits;
  29. int const qhead_per_khead;
  30. int const seqlen;
  31. cutlass::FastDivmod nsplits_divmod;
  32. int* const cu_seqlens;
  33. int* const seqused;
  34. };
  35. static Params
  36. to_underlying_arguments(TileSchedulerArguments const& args) {
  37. return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits,
  38. args.qhead_per_khead, args.seqlen,
  39. cutlass::FastDivmod(!Split ? 1 : args.num_splits),
  40. !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused};
  41. }
  42. static dim3
  43. get_grid_shape(Params const& params, int num_sm) {
  44. return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)};
  45. }
  46. struct WorkTileInfo {
  47. int block_idx = 0;
  48. int bidh = 0;
  49. int bidb = 0;
  50. bool is_valid_tile = false;
  51. CUTLASS_DEVICE
  52. bool
  53. is_valid(Params const& params) const {
  54. return is_valid_tile;
  55. }
  56. CUTLASS_DEVICE
  57. cute::tuple<int32_t, int32_t, int32_t, int32_t>
  58. get_block_coord(Params const& params) const {
  59. if constexpr (!Split) {
  60. return {block_idx, bidh, bidb, 0 /*split_idx*/};
  61. } else {
  62. int split_idx;
  63. int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh);
  64. return {block_idx, bidh_actual, bidb, split_idx};
  65. }
  66. }
  67. };
  68. CUTLASS_DEVICE
  69. SingleTileScheduler(SharedStorage* const smem_scheduler) { }
  70. template<bool IsProducerWarp=false>
  71. CUTLASS_DEVICE
  72. WorkTileInfo
  73. get_initial_work(Params const& params) const {
  74. WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
  75. if constexpr (Varlen) {
  76. int seqlen = params.seqused
  77. ? params.seqused[work_info.bidb]
  78. : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen);
  79. if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }
  80. work_info.is_valid_tile = work_info.block_idx * kBlock < seqlen;
  81. }
  82. return work_info;
  83. }
  84. CUTLASS_DEVICE
  85. void
  86. init_consumer() const {}
  87. CUTLASS_DEVICE
  88. void
  89. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
  90. template<bool IsProducerWarp=false>
  91. CUTLASS_DEVICE
  92. WorkTileInfo
  93. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  94. return {-1, -1, -1, false};
  95. }
  96. };
  97. ///////////////////////////////////////////////////////////////////////////////
  98. template<bool Split=false>
  99. class StaticPersistentTileScheduler {
  100. public:
  101. using SharedStorage = int;
  102. // Device side kernel params
  103. struct Params {
  104. int total_blocks;
  105. cutlass::FastDivmod m_block_divmod, head_divmod;
  106. cutlass::FastDivmod nsplits_divmod;
  107. };
  108. static Params
  109. to_underlying_arguments(TileSchedulerArguments const& args) {
  110. return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits),
  111. cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)),
  112. cutlass::FastDivmod(!Split ? 1 : args.num_splits)};
  113. }
  114. static dim3
  115. get_grid_shape(Params const& params, int num_sm) {
  116. return {uint32_t(num_sm)};
  117. }
  118. struct WorkTileInfo {
  119. int tile_idx;
  120. CUTLASS_DEVICE
  121. bool
  122. is_valid(Params const& params) const {
  123. return tile_idx < params.total_blocks;
  124. }
  125. CUTLASS_DEVICE
  126. cute::tuple<int32_t, int32_t, int32_t, int32_t>
  127. get_block_coord(Params const& params) const {
  128. int block, bidh, bidb;
  129. bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx));
  130. int split_idx = 0;
  131. if constexpr (Split) {
  132. bidh = params.nsplits_divmod.divmod(split_idx, bidh);
  133. }
  134. return {block, bidh, bidb, split_idx};
  135. }
  136. };
  137. CUTLASS_DEVICE
  138. StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {};
  139. template<bool IsProducerWarp=false>
  140. CUTLASS_DEVICE
  141. WorkTileInfo
  142. get_initial_work(Params const& params) const {
  143. return {int(blockIdx.x)};
  144. }
  145. CUTLASS_DEVICE
  146. void
  147. init_consumer() const {}
  148. CUTLASS_DEVICE
  149. void
  150. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
  151. template<bool IsProducerWarp=false>
  152. CUTLASS_DEVICE
  153. WorkTileInfo
  154. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  155. return {current_work.tile_idx + int(gridDim.x)};
  156. }
  157. };
  158. template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp,
  159. bool Split=false, bool PackGQA=false>
  160. class DynamicPersistentTileScheduler {
  161. // This scheduler targets the causal (or local) case where each tile takes different
  162. // amount of time. We using longest-processing-time-first scheduling:
  163. // the longest remaining tile is assigned to the first SM that's free.
  164. // SM indicates they are free by incrementing a semaphore.
  165. // However, we have to make sure K & V still fit into L2 cache, so we perform scheduling
  166. // on "sections" of the head & batch dimension, each section consisting of e.g. 8 heads.
  167. // This is the L2 swizzling part. The size of each section is precomputed based on the
  168. // size of K & V and the L2 cache size.
  169. public:
  170. using SharedStorage = int;
  171. protected:
  172. SharedStorage* const tile_count_smem;
  173. public:
  174. // Device side kernel params
  175. struct Params {
  176. int const total_blocks;
  177. cutlass::FastDivmod const m_block_divmod, head_divmod;
  178. cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod;
  179. cutlass::FastDivmod const l2_minor_residual_divmod;
  180. int const num_hb_quotient;
  181. int* const tile_count_semaphore;
  182. };
  183. static Params
  184. to_underlying_arguments(TileSchedulerArguments const& args) {
  185. int const size_one_kv_head = args.seqlen_k * args.headdim * args.element_size * 2;
  186. int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V
  187. // Swizzle is the size of each "section". Round swizzle to a power of 2
  188. // If not PackGQA already, the size of each section can increase by qhead_per_khead
  189. int const swizzle = (1 << cutlass::find_log2(size_l2 / size_one_kv_head)) * (PackGQA ? 1 : args.qhead_per_khead);
  190. // If we're in the last section (called residual), we don't want to divide by
  191. // swizzle. Instead we want to divide by the remainder.
  192. int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle;
  193. int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits);
  194. // printf("num_split_blocks = %d, num_head = %d, num_batch = %d, swizzle = %d, PackGQA = %d, qhead_per_khead = %d, num_hb_remainder = %d\n", num_split_blocks, args.num_head, args.num_batch, swizzle, int(PackGQA), args.qhead_per_khead, num_hb_remainder);
  195. return {num_split_blocks * args.num_head * args.num_batch,
  196. cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head),
  197. cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks),
  198. // don't divide by 0
  199. cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1),
  200. (args.num_head * args.num_batch) / swizzle,
  201. args.tile_count_semaphore};
  202. }
  203. static dim3
  204. get_grid_shape(Params const& params, int num_sm) {
  205. return {uint32_t(num_sm)};
  206. }
  207. struct WorkTileInfo {
  208. int tile_idx;
  209. CUTLASS_DEVICE
  210. bool
  211. is_valid(Params const& params) const {
  212. return tile_idx < params.total_blocks;
  213. }
  214. CUTLASS_DEVICE
  215. cute::tuple<int32_t, int32_t, int32_t, int32_t>
  216. get_block_coord(Params const& params) const {
  217. int block, bidh, bidb;
  218. int l2_mod, bidhb, bidhb_residual;
  219. bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx);
  220. // If we're in the last section (called residual), we don't want to divide by
  221. // swizzle. Instead we want to divide by the remainder.
  222. if (bidhb < params.num_hb_quotient) {
  223. block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod);
  224. } else {
  225. block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod);
  226. }
  227. bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual);
  228. int split_idx = 0;
  229. if constexpr (Split) {
  230. split_idx = params.m_block_divmod.divmod(block, block);
  231. }
  232. // Longest-processing-time-first
  233. block = params.m_block_divmod.divisor - 1 - block;
  234. return {block, bidh, bidb, split_idx};
  235. }
  236. };
  237. CUTLASS_DEVICE
  238. DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {};
  239. template<bool IsProducerWarp=false>
  240. CUTLASS_DEVICE
  241. WorkTileInfo
  242. get_initial_work(Params const& params) const {
  243. return {int(blockIdx.x)};
  244. }
  245. CUTLASS_DEVICE
  246. void
  247. init_consumer() const {
  248. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  249. }
  250. CUTLASS_DEVICE
  251. void
  252. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
  253. if (threadIdx.x % NumProducerThreads == 0) {
  254. current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
  255. }
  256. }
  257. template<bool IsProducerWarp=false>
  258. CUTLASS_DEVICE
  259. WorkTileInfo
  260. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  261. if constexpr (IsProducerWarp) {
  262. // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0
  263. int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);
  264. cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  265. if (threadIdx.x % NumProducerThreads == 0) {
  266. *tile_count_smem = current_work.tile_idx;
  267. }
  268. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  269. return {new_tile_idx};
  270. } else {
  271. cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  272. int tile_idx = *tile_count_smem;
  273. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  274. return {tile_idx};
  275. }
  276. }
  277. };
  278. template<int kBlock, int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp, bool Split=false, bool PackGQA=false>
  279. class VarlenDynamicPersistentTileScheduler {
  280. public:
  281. using SharedStorage = int4;
  282. protected:
  283. SharedStorage* const work_info_smem;
  284. public:
  285. // Device side kernel params
  286. struct Params {
  287. int num_head, num_batch;
  288. int const qhead_per_khead;
  289. int const seqlen;
  290. cutlass::FastDivmod nsplits_divmod;
  291. int* const tile_count_semaphore;
  292. int* const cu_seqlens;
  293. int* const seqused;
  294. };
  295. static Params
  296. to_underlying_arguments(TileSchedulerArguments const& args) {
  297. // If Split, for the purpose of scheduling, we pretend that instead there are
  298. // (args.num_splits * args.num_head) number of heads.
  299. return {args.num_head * (!Split ? 1 : args.num_splits), args.num_batch,
  300. args.qhead_per_khead, args.seqlen,
  301. cutlass::FastDivmod(!Split ? 1 : args.num_splits),
  302. args.tile_count_semaphore, args.cu_seqlens, args.seqused};
  303. }
  304. static dim3
  305. get_grid_shape(Params const& params, int num_sm) {
  306. return {uint32_t(num_sm)};
  307. }
  308. struct WorkTileInfo {
  309. int tile_idx, block, bidh, bidb;
  310. CUTLASS_DEVICE
  311. bool
  312. is_valid(Params const& params) const {
  313. // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); }
  314. return bidb < params.num_batch;
  315. }
  316. CUTLASS_DEVICE
  317. cute::tuple<int32_t, int32_t, int32_t, int32_t>
  318. get_block_coord(Params const& params) const {
  319. if constexpr (!Split) {
  320. return {block, bidh, bidb, 0 /*split_idx*/};
  321. } else {
  322. int split_idx;
  323. int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh);
  324. return {block, bidh_actual, bidb, split_idx};
  325. }
  326. }
  327. };
  328. CUTLASS_DEVICE
  329. VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {};
  330. CUTLASS_DEVICE
  331. WorkTileInfo
  332. tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const {
  333. auto prefix_sum = [](int val) {
  334. int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
  335. CUTLASS_PRAGMA_UNROLL
  336. for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) {
  337. int32_t partial_sum = __shfl_up_sync(0xffffffff, val, i);
  338. if (lane >= i) { val += partial_sum; }
  339. }
  340. return val;
  341. };
  342. auto get_num_m_blocks = [&](int bidb_start) {
  343. int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
  344. int seqlen;
  345. if (params.seqused) {
  346. seqlen = lane + bidb_start < params.num_batch ? params.seqused[lane + bidb_start] : 0;
  347. } else if (params.cu_seqlens) {
  348. int cur_cu_seqlen = lane + bidb_start <= params.num_batch ? params.cu_seqlens[lane + bidb_start] : 0;
  349. int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
  350. seqlen = next_cu_seqlen - cur_cu_seqlen;
  351. } else {
  352. seqlen = params.seqlen;
  353. }
  354. if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }
  355. return lane + bidb_start < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
  356. ? cute::ceil_div(seqlen, kBlock) : 0;
  357. };
  358. int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane
  359. // Cumulative number of blocks for the next 31 batches
  360. int num_m_blocks_cumulative = prefix_sum(num_m_blocks);
  361. // Total number of blocks for the next 31 batches
  362. int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);
  363. int group_end_tile = current_work.tile_idx - current_work.block - current_work.bidh * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes
  364. int bidb = current_work.bidb;
  365. // if (blockIdx.x <= 9 && threadIdx.x == 0) {
  366. // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group);
  367. // }
  368. while (group_end_tile <= next_tile_idx) {
  369. bidb += cutlass::NumThreadsPerWarp - 1;
  370. if (bidb >= params.num_batch) {
  371. // if (blockIdx.x <= 9 && threadIdx.x == 0) {
  372. // printf("Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group);
  373. // }
  374. return {next_tile_idx, 0, 0, params.num_batch};
  375. }
  376. num_m_blocks = get_num_m_blocks(bidb);
  377. num_m_blocks_cumulative = prefix_sum(num_m_blocks);
  378. m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);
  379. group_end_tile += m_blocks_in_group * params.num_head;
  380. // if (blockIdx.x <= 9 && threadIdx.x == 0) {
  381. // printf("Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group);
  382. // }
  383. }
  384. int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head;
  385. // The next problem to process is the first one that does not have ending tile position
  386. // that is greater than or equal to tile index.
  387. int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx));
  388. bidb += batch_idx_in_group;
  389. num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group);
  390. int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head;
  391. int bidh = mh_block / num_m_blocks;
  392. int block = mh_block - bidh * num_m_blocks;
  393. // if (blockIdx.x <= 9 && threadIdx.x == 0) {
  394. // printf("blockIdx.x = %d, threadIdx.x = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block);
  395. // }
  396. return {next_tile_idx, block, bidh, bidb};
  397. }
  398. template<bool IsProducerWarp=false>
  399. CUTLASS_DEVICE
  400. WorkTileInfo
  401. get_initial_work(Params const& params) const {
  402. if constexpr (IsProducerWarp) {
  403. WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0});
  404. if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
  405. *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);
  406. }
  407. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  408. return work_info;
  409. } else {
  410. return get_next_work<false>(params, {0, 0, 0, 0});
  411. }
  412. }
  413. CUTLASS_DEVICE
  414. void
  415. init_consumer() const {
  416. // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that
  417. }
  418. CUTLASS_DEVICE
  419. void
  420. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
  421. if (threadIdx.x % NumProducerThreads == 0) {
  422. current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
  423. }
  424. }
  425. template<bool IsProducerWarp=false>
  426. CUTLASS_DEVICE
  427. WorkTileInfo
  428. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  429. if constexpr (IsProducerWarp) {
  430. // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0
  431. int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);
  432. WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb};
  433. work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info);
  434. cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  435. if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
  436. *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);
  437. }
  438. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  439. return work_info;
  440. } else {
  441. cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  442. int4 work_info = *work_info_smem;
  443. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  444. return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w};
  445. }
  446. }
  447. };
  448. } // flash