// tile_scheduler.hpp
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cutlass/fast_math.h"
  6. #include "cutlass/arch/barrier.h"
  7. #include "named_barrier.hpp"
  8. namespace flash {
  9. ///////////////////////////////////////////////////////////////////////////////
  10. struct SingleTileScheduler {
  11. public:
  12. // Host side kernel arguments
  13. struct Arguments {
  14. int const num_blocks_m, num_head, num_batch, max_num_docs_per_batch;
  15. int* const tile_count_semaphore = nullptr;
  16. };
  17. // Device side kernel params
  18. struct Params {};
  19. static Params
  20. to_underlying_arguments(Arguments const& args) {
  21. return {};
  22. }
  23. static dim3
  24. get_grid_dim(Arguments const& args, int num_sm) {
  25. return {uint32_t(args.num_blocks_m), uint32_t(args.num_head), uint32_t(args.num_batch)};
  26. }
  27. struct WorkTileInfo {
  28. int M_idx = 0;
  29. int H_idx = 0;
  30. int B_idx = 0;
  31. bool is_valid_tile = false;
  32. CUTLASS_DEVICE
  33. bool
  34. is_valid(Params const& params) const {
  35. return is_valid_tile;
  36. }
  37. CUTLASS_DEVICE
  38. cute::tuple<int32_t, int32_t, int32_t>
  39. get_block_coord(Params const& params) const {
  40. return {M_idx, H_idx, B_idx};
  41. }
  42. CUTLASS_DEVICE
  43. void
  44. move_to_next_batch() {}
  45. };
  46. CUTLASS_DEVICE
  47. SingleTileScheduler(int* tile_count_smem_) { }
  48. CUTLASS_DEVICE
  49. WorkTileInfo
  50. get_initial_work(Params const& params) const {
  51. return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
  52. }
  53. CUTLASS_DEVICE
  54. void
  55. init_consumer() const {}
  56. CUTLASS_DEVICE
  57. void
  58. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
  59. CUTLASS_DEVICE
  60. void
  61. broadcast_next_work(WorkTileInfo& current_work) const {}
  62. template<bool IsProducer=false>
  63. CUTLASS_DEVICE
  64. WorkTileInfo
  65. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  66. return {-1, -1, -1, false};
  67. }
  68. };
  69. ///////////////////////////////////////////////////////////////////////////////
  70. class StaticPersistentTileScheduler {
  71. public:
  72. // Host side kernel arguments
  73. struct Arguments {
  74. int const num_blocks_m, num_head, num_batch, max_num_docs_per_batch;
  75. int* const tile_count_semaphore = nullptr;
  76. };
  77. // Device side kernel params
  78. struct Params {
  79. int total_blocks, max_num_docs_per_batch;
  80. cutlass::FastDivmod m_block_divmod, head_divmod;
  81. };
  82. static Params
  83. to_underlying_arguments(Arguments const& args) {
  84. return {args.num_blocks_m * args.num_head * args.num_batch,
  85. args.max_num_docs_per_batch,
  86. cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head)};
  87. }
  88. static dim3
  89. get_grid_dim(Arguments const& args, int num_sm) {
  90. return {uint32_t(num_sm)};
  91. }
  92. struct WorkTileInfo {
  93. int tile_idx;
  94. CUTLASS_DEVICE
  95. bool
  96. is_valid(Params const& params) const {
  97. return tile_idx < params.total_blocks;
  98. }
  99. CUTLASS_DEVICE
  100. cute::tuple<int32_t, int32_t, int32_t>
  101. get_block_coord(Params const& params) const {
  102. int m_block, bidh, bidb;
  103. bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
  104. return {m_block, bidh, bidb};
  105. }
  106. CUTLASS_DEVICE
  107. void
  108. move_to_next_batch() {}
  109. };
  110. CUTLASS_DEVICE
  111. StaticPersistentTileScheduler(int* tile_count_smem_) {};
  112. CUTLASS_DEVICE
  113. WorkTileInfo
  114. get_initial_work(Params const& params) const {
  115. return {int(blockIdx.x)};
  116. }
  117. CUTLASS_DEVICE
  118. void
  119. init_consumer() const {}
  120. CUTLASS_DEVICE
  121. void
  122. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
  123. CUTLASS_DEVICE
  124. void
  125. broadcast_next_work(WorkTileInfo& current_work) const {}
  126. template<bool IsProducer=false>
  127. CUTLASS_DEVICE
  128. WorkTileInfo
  129. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  130. return {current_work.tile_idx + int(gridDim.x)};
  131. }
  132. };
  133. class DocMaskingStaticPersistentTileScheduler {
  134. public:
  135. // Host side kernel arguments
  136. struct Arguments {
  137. int const num_blocks_m, num_head, num_batch, max_num_docs_per_batch;
  138. int* const tile_count_semaphore = nullptr;
  139. };
  140. // Device side kernel params
  141. struct Params {
  142. int total_blocks, block_size;
  143. cutlass::FastDivmod m_block_divmod, head_divmod, doc_divmod;
  144. };
  145. static Params
  146. to_underlying_arguments(Arguments const& args) {
  147. return {args.num_blocks_m * args.num_head * args.num_batch,
  148. args.max_num_docs_per_batch * args.num_head * args.num_blocks_m,
  149. cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head),
  150. cutlass::FastDivmod(args.max_num_docs_per_batch)};
  151. }
  152. static dim3
  153. get_grid_dim(Arguments const& args, int num_sm) {
  154. return {uint32_t(num_sm)};
  155. }
  156. struct WorkTileInfo {
  157. int tile_idx;
  158. int m_block;
  159. int bidh;
  160. int bidb;
  161. bool next_batch = false;
  162. CUTLASS_DEVICE
  163. WorkTileInfo(int tile_idx_, Params const& params) : tile_idx(tile_idx_) {
  164. bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
  165. }
  166. CUTLASS_DEVICE
  167. bool
  168. is_valid(Params const& params) const {
  169. return tile_idx < params.total_blocks;
  170. }
  171. CUTLASS_DEVICE
  172. cute::tuple<int32_t, int32_t, int32_t>
  173. get_block_coord(Params const& params) const {
  174. return {m_block, bidh, bidb};
  175. }
  176. CUTLASS_DEVICE
  177. void
  178. move_to_next_batch() {
  179. next_batch = true;
  180. }
  181. };
  182. CUTLASS_DEVICE
  183. DocMaskingStaticPersistentTileScheduler(int* tile_count_smem_) {};
  184. CUTLASS_DEVICE
  185. WorkTileInfo
  186. get_initial_work(Params const& params) const {
  187. return {int(blockIdx.x), params};
  188. }
  189. CUTLASS_DEVICE
  190. void
  191. init_consumer() const {}
  192. CUTLASS_DEVICE
  193. void
  194. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
  195. CUTLASS_DEVICE
  196. void
  197. broadcast_next_work(WorkTileInfo& current_work) const {}
  198. template<bool IsProducer=false>
  199. CUTLASS_DEVICE
  200. WorkTileInfo
  201. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  202. if (current_work.next_batch) {
  203. int min_next_tile_idx = (params.doc_divmod.div(current_work.bidb) + 1) * params.block_size;
  204. int next_tile_idx = current_work.tile_idx +
  205. (min_next_tile_idx - current_work.tile_idx + int(gridDim.x) - 1) /
  206. int(gridDim.x) * int(gridDim.x);
  207. return {next_tile_idx, params};
  208. } else {
  209. return {current_work.tile_idx + int(gridDim.x), params};
  210. }
  211. }
  212. };
  213. template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads = cutlass::NumThreadsPerWarp>
  214. class DynamicPersistentTileScheduler {
  215. protected:
  216. int* const tile_count_smem;
  217. public:
  218. // Host side kernel arguments
  219. struct Arguments {
  220. int const num_blocks_m, num_head, num_batch, max_num_docs_per_batch;
  221. int* const tile_count_semaphore;
  222. };
  223. // Device side kernel params
  224. struct Params {
  225. int const total_blocks;
  226. cutlass::FastDivmod const m_block_divmod, head_divmod;
  227. int* const tile_count_semaphore;
  228. };
  229. static Params
  230. to_underlying_arguments(Arguments const& args) {
  231. return {args.num_blocks_m * args.num_head * args.num_batch,
  232. cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head),
  233. args.tile_count_semaphore};
  234. }
  235. static dim3
  236. get_grid_dim(Arguments const& args, int num_sm) {
  237. return {uint32_t(num_sm)};
  238. }
  239. struct WorkTileInfo {
  240. int tile_idx;
  241. CUTLASS_DEVICE
  242. bool
  243. is_valid(Params const& params) const {
  244. return tile_idx < params.total_blocks;
  245. }
  246. CUTLASS_DEVICE
  247. cute::tuple<int32_t, int32_t, int32_t>
  248. get_block_coord(Params const& params) const {
  249. int m_block, bidh, bidb;
  250. bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
  251. return {m_block, bidh, bidb};
  252. }
  253. CUTLASS_DEVICE
  254. void
  255. move_to_next_batch() {}
  256. };
  257. CUTLASS_DEVICE
  258. DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {};
  259. CUTLASS_DEVICE
  260. WorkTileInfo
  261. get_initial_work(Params const& params) const {
  262. return {int(blockIdx.x)};
  263. }
  264. CUTLASS_DEVICE
  265. void
  266. init_consumer() const {
  267. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  268. }
  269. CUTLASS_DEVICE
  270. void
  271. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
  272. if (threadIdx.x % NumProducerThreads == 0) {
  273. current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
  274. }
  275. }
  276. CUTLASS_DEVICE
  277. void
  278. broadcast_next_work(WorkTileInfo& current_work) const {
  279. cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  280. if (threadIdx.x % NumProducerThreads == 0) {
  281. *tile_count_smem = current_work.tile_idx;
  282. }
  283. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  284. }
  285. template<bool IsProducer=false>
  286. CUTLASS_DEVICE
  287. WorkTileInfo
  288. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  289. if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarp) {
  290. // thread 0 already has the right tile_idx, just need to broadcast to the rest of the producer threads (warp 0)
  291. return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)};
  292. } else if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarpGroup) {
  293. // TODO: investigate optimal synchronize
  294. int tile_idx = *tile_count_smem;
  295. return {tile_idx};
  296. } else {
  297. cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
  298. int tile_idx = *tile_count_smem;
  299. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
  300. return {tile_idx};
  301. }
  302. }
  303. };
  304. } // flash