// tile_scheduler.hpp
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cutlass/fast_math.h"
  6. #include "cutlass/arch/barrier.h"
  7. #include "named_barrier.hpp"
  8. namespace flash {
  9. ///////////////////////////////////////////////////////////////////////////////
  10. struct SingleTileScheduler {
  11. public:
  12. // Host side kernel arguments
  13. struct Arguments {
  14. int const num_blocks_m, num_head, num_batch;
  15. int* const tile_count_semaphore = nullptr;
  16. };
  17. // Device side kernel params
  18. struct Params {};
  19. static Params
  20. to_underlying_arguments(Arguments const& args) {
  21. return {};
  22. }
  23. static dim3
  24. get_grid_dim(Arguments const& args, int num_sm) {
  25. return {uint32_t(args.num_blocks_m), uint32_t(args.num_head), uint32_t(args.num_batch)};
  26. }
  27. struct WorkTileInfo {
  28. int M_idx = 0;
  29. int H_idx = 0;
  30. int B_idx = 0;
  31. bool is_valid_tile = false;
  32. CUTLASS_DEVICE
  33. bool
  34. is_valid(Params const& params) const {
  35. return is_valid_tile;
  36. }
  37. CUTLASS_DEVICE
  38. cute::tuple<int32_t, int32_t, int32_t>
  39. get_block_coord(Params const& params) const {
  40. return {M_idx, H_idx, B_idx};
  41. }
  42. };
  43. CUTLASS_DEVICE
  44. SingleTileScheduler(int* tile_count_smem_) { }
  45. CUTLASS_DEVICE
  46. WorkTileInfo
  47. get_initial_work() const {
  48. return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
  49. }
  50. CUTLASS_DEVICE
  51. void
  52. init_consumer() const {}
  53. CUTLASS_DEVICE
  54. void
  55. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
  56. CUTLASS_DEVICE
  57. void
  58. broadcast_next_work(WorkTileInfo& current_work) const {}
  59. template<bool IsProducer=false>
  60. CUTLASS_DEVICE
  61. WorkTileInfo
  62. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  63. return {-1, -1, -1, false};
  64. }
  65. };
  66. ///////////////////////////////////////////////////////////////////////////////
  67. class StaticPersistentTileScheduler {
  68. public:
  69. // Host side kernel arguments
  70. struct Arguments {
  71. int const num_blocks_m, num_head, num_batch;
  72. int* const tile_count_semaphore = nullptr;
  73. };
  74. // Device side kernel params
  75. struct Params {
  76. int total_blocks;
  77. cutlass::FastDivmod m_block_divmod, head_divmod;
  78. };
  79. static Params
  80. to_underlying_arguments(Arguments const& args) {
  81. return {args.num_blocks_m * args.num_head * args.num_batch,
  82. cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head)};
  83. }
  84. static dim3
  85. get_grid_dim(Arguments const& args, int num_sm) {
  86. return {uint32_t(num_sm)};
  87. }
  88. struct WorkTileInfo {
  89. int tile_idx;
  90. CUTLASS_DEVICE
  91. bool
  92. is_valid(Params const& params) const {
  93. return tile_idx < params.total_blocks;
  94. }
  95. CUTLASS_DEVICE
  96. cute::tuple<int32_t, int32_t, int32_t>
  97. get_block_coord(Params const& params) const {
  98. int m_block, bidh, bidb;
  99. bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
  100. return {m_block, bidh, bidb};
  101. }
  102. };
  103. CUTLASS_DEVICE
  104. StaticPersistentTileScheduler(int* tile_count_smem_) {};
  105. CUTLASS_DEVICE
  106. WorkTileInfo
  107. get_initial_work() const {
  108. return {int(blockIdx.x)};
  109. }
  110. CUTLASS_DEVICE
  111. void
  112. init_consumer() const {}
  113. CUTLASS_DEVICE
  114. void
  115. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
  116. CUTLASS_DEVICE
  117. void
  118. broadcast_next_work(WorkTileInfo& current_work) const {}
  119. template<bool IsProducer=false>
  120. CUTLASS_DEVICE
  121. WorkTileInfo
  122. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  123. return {current_work.tile_idx + int(gridDim.x)};
  124. }
  125. };
// Persistent scheduler with dynamic work distribution. One producer warp
// reserves the next linear tile index via atomicAdd on a global semaphore and
// relays it to the NumMmaThreads consumer threads through a single int slot
// (tile_count_smem), guarded by two named barriers:
//   TileCountSmemEmpty -- the slot has been drained and may be overwritten;
//   TileCountSmemFull  -- the slot holds a fresh tile index for consumers.
// Both barriers use NumMmaThreads + NumThreadsPerWarp participants, i.e. the
// consumers plus the producer warp.
template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup>
class DynamicPersistentTileScheduler {

protected:
    // Relay slot for the next tile index (named "smem" -- presumably points
    // into shared memory; the allocation is owned by the caller -- confirm).
    int* const tile_count_smem;

public:

    // Host side kernel arguments
    struct Arguments {
        int const num_blocks_m, num_head, num_batch;
        int* const tile_count_semaphore;  // global counter backing the atomicAdd
    };

    // Device side kernel params
    struct Params {
        int const total_blocks;  // num_blocks_m * num_head * num_batch
        cutlass::FastDivmod const m_block_divmod, head_divmod;
        int* const tile_count_semaphore;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        return {args.num_blocks_m * args.num_head * args.num_batch,
                cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head),
                args.tile_count_semaphore};
    }

    // One persistent CTA per SM.
    static dim3
    get_grid_dim(Arguments const& args, int num_sm) {
        return {uint32_t(num_sm)};
    }

    struct WorkTileInfo {
        int tile_idx;  // linear index into the (m_block, head, batch) tile space

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return tile_idx < params.total_blocks;
        }

        // Decode tile_idx with m_block varying fastest, then head, then batch.
        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            int m_block, bidh, bidb;
            bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
            return {m_block, bidh, bidb};
        }

    };

    CUTLASS_DEVICE
    DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {};

    // Each CTA starts with the tile equal to its own block index, so tiles
    // [0, gridDim.x) are claimed without touching the semaphore.
    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work() const {
        return {int(blockIdx.x)};
    }

    // Consumers pre-arrive at the Empty barrier so the producer's first sync
    // on it in broadcast_next_work() can complete.
    CUTLASS_DEVICE
    void
    init_consumer() const {
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
    }

    // Producer side: lane 0 of the calling warp reserves the next tile index.
    // The + gridDim.x offset skips the tiles already handed out by
    // get_initial_work() (assumes the semaphore starts at 0 -- confirm with
    // the host-side initialization).
    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
            current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
        }
    }

    // Producer side: wait until the slot is drained (Empty), publish the
    // prefetched tile index from lane 0, then flag the slot as Full.
    CUTLASS_DEVICE
    void
    broadcast_next_work(WorkTileInfo& current_work) const {
        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
            *tile_count_smem = current_work.tile_idx;
        }
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
    }

    template<bool IsProducer=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        if constexpr (IsProducer) {
            // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0
            return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)};
        } else {
            // Consumer side: wait until the producer has published (Full),
            // read the slot, then mark it Empty so the next broadcast can
            // proceed.
            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
            int tile_idx = *tile_count_smem;
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
            return {tile_idx};
        }
    }

};
  210. } // flash