flash_fwd_kernel.h 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include <cutlass/cutlass.h>
  7. #include <cutlass/arch/reg_reconfig.h>
  8. #include <cutlass/array.h>
  9. #include <cutlass/numeric_types.h>
  10. #include <cutlass/numeric_conversion.h>
  11. #include "cutlass/pipeline/pipeline.hpp"
  12. #include "flash.h"
  13. #include "utils.h"
  14. #include "softmax.h"
  15. #include "tile_scheduler.hpp"
  16. #include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
  17. #include "epilogue_fwd_sm90_tma.hpp"
  18. namespace flash {
  19. using namespace cute;
  20. template <typename Ktraits, bool Is_causal, typename TileScheduler, typename Seqlen_traits>
  21. __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
  22. compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>::Params const mainloop_params,
  23. CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits>::Params const epilogue_params,
  24. CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
  25. Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k
  26. ) {
  27. using Element = typename Ktraits::Element;
  28. using ElementAccum = typename Ktraits::ElementAccum;
  29. using SoftType = ElementAccum;
  30. using TileShape_MNK = typename Ktraits::TileShape_MNK;
  31. using ClusterShape = typename Ktraits::ClusterShape_MNK;
  32. static_assert(Ktraits::Is_WS);
  33. static constexpr bool Is_WS = Ktraits::Is_WS;
  34. static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
  35. static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
  36. static constexpr int kBlockM = Ktraits::kBlockM;
  37. // static constexpr int kBlockN = Ktraits::kBlockN;
  38. // constexpr int kHeadDim = Ktraits::kHeadDim;
  39. using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>;
  40. using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits>;
  41. using MainloopPipeline = typename Ktraits::MainloopPipeline;
  42. using PipelineParams = typename MainloopPipeline::Params;
  43. using PipelineState = typename MainloopPipeline::PipelineState;
  44. extern __shared__ char shared_memory[];
  45. auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
  46. int const lane_predicate = cute::elect_one_sync();
  47. int const warp_idx = cutlass::canonical_warp_idx_sync();
  48. // Issue Tma Descriptor Prefetch from a single thread
  49. if (warp_idx == 0 && lane_predicate) {
  50. CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
  51. CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
  52. }
  53. // Obtain warp index
  54. int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
  55. PipelineParams pipeline_params;
  56. pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
  57. int warp_group_idx = cutlass::canonical_warp_group_idx();
  58. pipeline_params.role = warp_group_idx == 0
  59. ? MainloopPipeline::ThreadCategory::Producer
  60. : MainloopPipeline::ThreadCategory::Consumer;
  61. pipeline_params.is_leader = warp_group_thread_idx == 0;
  62. pipeline_params.num_consumers = NumMmaThreads;
  63. if (warp_idx == 0 && lane_predicate) {
  64. shared_storage.barrier_Q.init(1 /*numThreads*/);
  65. shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
  66. }
  67. // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
  68. MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
  69. MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
  70. CollectiveMainloop collective_mainloop;
  71. CollectiveEpilogue collective_epilogue;
  72. // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
  73. if constexpr (size(ClusterShape{}) > 1) {
  74. cute::cluster_arrive_relaxed();
  75. cute::cluster_wait();
  76. } else {
  77. __syncthreads();
  78. }
  79. static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
  80. if (warp_group_idx == 0) { // Producer
  81. cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 24 : 32>();
  82. // cutlass::arch::warpgroup_reg_dealloc<56>();
  83. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  84. if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
  85. PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
  86. PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
  87. int work_idx = 0;
  88. TileScheduler scheduler(&shared_storage.tile_count_semaphore);
  89. for (auto work_tile_info = scheduler.get_initial_work();
  90. work_tile_info.is_valid(scheduler_params);
  91. work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
  92. auto block_coord = work_tile_info.get_block_coord(scheduler_params);
  93. auto [m_block, bidh, bidb] = block_coord;
  94. seqlen_traits_q.init(bidb);
  95. seqlen_traits_k.init(bidb);
  96. if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
  97. continue;
  98. }
  99. int n_block_max = collective_mainloop.get_n_block_max(
  100. mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
  101. if (Is_causal && n_block_max <= 0) {
  102. scheduler.prefetch_next_work(scheduler_params, work_tile_info);
  103. scheduler.broadcast_next_work(work_tile_info);
  104. continue;
  105. }
  106. collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
  107. shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
  108. seqlen_traits_q, seqlen_traits_k);
  109. ++work_idx;
  110. }
  111. collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
  112. }
  113. } else { // Consumer
  114. cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 240 : 160>();
  115. // cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 224 : 160>();
  116. TileScheduler scheduler(&shared_storage.tile_count_semaphore);
  117. // Initialize matmul objects.
  118. typename Ktraits::TiledMma1 tiled_mma1;
  119. PipelineState smem_pipe_read_k, smem_pipe_read_v;
  120. // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
  121. // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
  122. collective_mainloop.mma_init();
  123. scheduler.init_consumer();
  124. int work_idx = 0;
  125. CUTLASS_PRAGMA_NO_UNROLL
  126. for (auto work_tile_info = scheduler.get_initial_work();
  127. work_tile_info.is_valid(scheduler_params);
  128. work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
  129. // Attention output (GEMM-II) accumulator.
  130. Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
  131. flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
  132. auto block_coord = work_tile_info.get_block_coord(scheduler_params);
  133. auto [m_block, bidh, bidb] = block_coord;
  134. seqlen_traits_q.init(bidb);
  135. seqlen_traits_k.init(bidb);
  136. if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
  137. continue;
  138. }
  139. int n_block_max = collective_mainloop.get_n_block_max(
  140. mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
  141. if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
  142. collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
  143. continue;
  144. }
  145. collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
  146. tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage,
  147. seqlen_traits_q, seqlen_traits_k);
  148. // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
  149. collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
  150. threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
  151. ++work_idx;
  152. }
  153. collective_epilogue.store_tail();
  154. }
  155. }
  156. } // namespace flash