tile_scheduler.hpp

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cutlass/fast_math.h"
#include "cutlass/arch/barrier.h"

#include "named_barrier.hpp"

namespace flash {

///////////////////////////////////////////////////////////////////////////////
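
// One CTA per output tile: the grid is sized as (num_blocks_m, num_head,
// num_batch) and blockIdx maps directly to the tile coordinates, so there is
// no persistent loop and get_next_work never produces more work.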
struct SingleTileScheduler {

public:

    // Host side kernel arguments
    struct Arguments {
        int const num_blocks_m, num_splits, num_head, num_batch;
        int* const tile_count_semaphore = nullptr;
    };

    // Device side kernel params
    struct Params {};

    static Params
    to_underlying_arguments(Arguments const& args) {
        return {};
    }

    static dim3
    get_grid_dim(Arguments const& args, int num_sm) {
        return {uint32_t(args.num_blocks_m), uint32_t(args.num_head), uint32_t(args.num_batch)};
    }

    struct WorkTileInfo {
        int M_idx = 0;
        int H_idx = 0;
        int B_idx = 0;
        bool is_valid_tile = false;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return is_valid_tile;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            // Coordinate order is (m_block, split_idx, bidh, bidb); the split
            // slot is a fixed placeholder since this scheduler does not split work.
            return {M_idx, 1, H_idx, B_idx};
        }

    };

    CUTLASS_DEVICE
    SingleTileScheduler(int* tile_count_smem_) { }

    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work() const {
        return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {}

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

    CUTLASS_DEVICE
    void
    broadcast_next_work(WorkTileInfo& current_work) const {}

    template<bool IsProducer=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        // Each CTA owns exactly one tile, so there is never a next work tile.
        return {-1, -1, -1, false};
    }

};
  66. ///////////////////////////////////////////////////////////////////////////////
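
// Persistent scheduler with static work assignment: the grid is one CTA per
// SM, and CTA i processes linear tile indices i, i + gridDim.x,
// i + 2 * gridDim.x, ... until total_blocks is exhausted. No atomics or
// synchronization are needed; the tile-count semaphore goes unused.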
template <bool Is_split = false>
class StaticPersistentTileScheduler {

public:

    // Host side kernel arguments
    struct Arguments {
        int const num_blocks_m, num_splits, num_head, num_batch;
        int* const tile_count_semaphore = nullptr;
    };

    // Device side kernel params
    struct Params {
        int const total_blocks;
        cutlass::FastDivmod const m_block_divmod, split_divmod, head_divmod;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        // return {args.num_blocks_m * args.num_head * args.num_batch,
        //         cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head)};
        return {args.num_blocks_m * args.num_splits * args.num_head * args.num_batch,
                cutlass::FastDivmod(args.num_blocks_m),
                cutlass::FastDivmod(args.num_splits),
                cutlass::FastDivmod(args.num_head)};
    }

    static dim3
    get_grid_dim(Arguments const& args, int num_sm) {
        return {uint32_t(num_sm)};
    }

    struct WorkTileInfo {
        int tile_idx;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return tile_idx < params.total_blocks;
        }
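
        // Tiles are linearized with m_block fastest-varying:
        //   tile_idx = m_block + num_blocks_m * (split_idx + num_splits * (bidh + num_head * bidb))
        // (the split term drops out when !Is_split). Each divmod call peels
        // off one coordinate as the remainder and passes the quotient onward.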
        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            int m_block, split_idx, bidh, bidb;
            if constexpr(!Is_split) {
                bidb = params.head_divmod.divmod(bidh,
                        params.m_block_divmod.divmod(m_block, tile_idx));
                return {m_block, 1, bidh, bidb};
            } else {
                bidb = params.head_divmod.divmod(bidh,
                        params.split_divmod.divmod(split_idx,
                            params.m_block_divmod.divmod(m_block, tile_idx)));
                return {m_block, split_idx, bidh, bidb};
            }
        }

    };

    CUTLASS_DEVICE
    StaticPersistentTileScheduler(int* tile_count_smem_) {};

    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work() const {
        return {int(blockIdx.x)};
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {}

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

    CUTLASS_DEVICE
    void
    broadcast_next_work(WorkTileInfo& current_work) const {}

    template<bool IsProducer=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        // Static round-robin: stride by the number of resident CTAs.
        return {current_work.tile_idx + int(gridDim.x)};
    }

};
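
///////////////////////////////////////////////////////////////////////////////

// Persistent scheduler with dynamic work assignment: one producer thread
// reserves the next linear tile index with an atomicAdd on a global
// tile-count semaphore and hands it to the rest of the CTA through shared
// memory, with a pair of named barriers acting as an empty/full handshake.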
template<int NumMmaThreads = 2 * cutlass::NumThreadsPerWarpGroup,
         int NumProducerThreads = cutlass::NumThreadsPerWarp,
         bool Is_split = false>
class DynamicPersistentTileScheduler {

protected:
    int* const tile_count_smem;

public:

    // Host side kernel arguments
    struct Arguments {
        int const num_blocks_m, num_splits, num_head, num_batch;
        int* const tile_count_semaphore;
    };

    // Device side kernel params
    struct Params {
        int const total_blocks;
        cutlass::FastDivmod const m_block_divmod, split_divmod, head_divmod;
        int* const tile_count_semaphore;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        // return {args.num_blocks_m * args.num_head * args.num_batch,
        //         cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head),
        //         args.tile_count_semaphore};
        return {args.num_blocks_m * args.num_splits * args.num_head * args.num_batch,
                cutlass::FastDivmod(args.num_blocks_m),
                cutlass::FastDivmod(args.num_splits),
                cutlass::FastDivmod(args.num_head),
                args.tile_count_semaphore};
    }

    static dim3
    get_grid_dim(Arguments const& args, int num_sm) {
        return {uint32_t(num_sm)};
    }

    struct WorkTileInfo {
        int tile_idx;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return tile_idx < params.total_blocks;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            int m_block, split_idx, bidh, bidb;
            if constexpr(!Is_split) {
                bidb = params.head_divmod.divmod(bidh,
                        params.m_block_divmod.divmod(m_block, tile_idx));
                return {m_block, 1, bidh, bidb};
            } else {
                bidb = params.head_divmod.divmod(bidh,
                        params.split_divmod.divmod(split_idx,
                            params.m_block_divmod.divmod(m_block, tile_idx)));
                return {m_block, split_idx, bidh, bidb};
            }
        }

    };

    CUTLASS_DEVICE
    DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {};

    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work() const {
        return {int(blockIdx.x)};
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
    }
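
    // The next-tile handoff is an empty/full ping-pong over tile_count_smem:
    // prefetch_next_work lets one producer thread reserve the next linear
    // index via atomicAdd; broadcast_next_work waits until the smem slot is
    // free, publishes the index, and signals it full; consumer threads in
    // get_next_work wait for full, read the index, and then signal empty.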
    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
        if (threadIdx.x % NumProducerThreads == 0) {
            // The first gridDim.x tiles were handed out by get_initial_work,
            // so the atomic counter is offset by the grid size.
            current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
        }
    }

    CUTLASS_DEVICE
    void
    broadcast_next_work(WorkTileInfo& current_work) const {
        cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
        if (threadIdx.x % NumProducerThreads == 0) {
            *tile_count_smem = current_work.tile_idx;
        }
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
    }

    template<bool IsProducer=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarp) {
            // Thread 0 already has the right tile_idx; just broadcast it to the rest of the producer threads (warp 0).
            return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)};
        } else if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarpGroup) {
            // TODO: investigate optimal synchronization
            int tile_idx = *tile_count_smem;
            return {tile_idx};
        } else {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
            int tile_idx = *tile_count_smem;
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
            return {tile_idx};
        }
    }

};
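
///////////////////////////////////////////////////////////////////////////////

// A minimal sketch (not part of the original file) of how a persistent kernel
// might drive this interface, shown for StaticPersistentTileScheduler;
// persistent_kernel_sketch and process_tile are hypothetical stand-ins, not
// the names used by the actual flash-attention kernels. With
// DynamicPersistentTileScheduler the producer warp must additionally call
// prefetch_next_work and broadcast_next_work between tiles, and instantiate
// get_next_work with IsProducer=true on its side.
//
//   template <class Scheduler>
//   __global__ void persistent_kernel_sketch(typename Scheduler::Params params) {
//       __shared__ int tile_count_smem;   // only read by the dynamic scheduler
//       Scheduler scheduler(&tile_count_smem);
//       scheduler.init_consumer();
//       for (auto work = scheduler.get_initial_work();
//            work.is_valid(params);
//            work = scheduler.get_next_work(params, work)) {
//           auto [m_block, split_idx, bidh, bidb] = work.get_block_coord(params);
//           process_tile(m_block, split_idx, bidh, bidb);  // hypothetical tile body
//       }
//   }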
} // namespace flash