// flash_fwd_kernel_sm80.h
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/kernel_hardware_info.h>

#include "seqlen.h"
#include "utils.h"
#include "softmax.h"
  13. namespace flash {
  14. using namespace cute;
  15. template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
  16. class FlashAttnFwdSm80 {
  17. public:
  18. // Type Aliases
  19. using CollectiveMainloop = CollectiveMainloop_;
  20. using CollectiveEpilogue = CollectiveEpilogue_;
  21. static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
  22. static constexpr bool Is_local = CollectiveMainloop::Is_local;
  23. static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
  24. static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
  25. static constexpr bool Varlen = CollectiveMainloop::Varlen;
  26. static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
  27. static constexpr bool Split = CollectiveMainloop::Split;
  28. static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
  29. static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
  30. static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
  31. static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
  32. static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
  33. using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
  34. // Mainloop derived types
  35. using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
  36. using TiledMma = typename CollectiveMainloop::TiledMma;
  37. using ArchTag = typename CollectiveMainloop::ArchTag;
  38. using MainloopArguments = typename CollectiveMainloop::Arguments;
  39. using MainloopParams = typename CollectiveMainloop::Params;
  40. // Epilogue derived types
  41. using EpilogueArguments = typename CollectiveEpilogue::Arguments;
  42. using EpilogueParams = typename CollectiveEpilogue::Params;
  43. static_assert(ArchTag::kMinComputeCapability >= 80);
  44. using TileScheduler = TileScheduler_;
  45. using TileSchedulerArguments = typename flash::TileSchedulerArguments;
  46. using TileSchedulerParams = typename TileScheduler::Params;
  47. static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));
  48. static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
  49. static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;
  50. // Kernel level shared memory storage
  51. // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k and not smem_q
  52. // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k).
  53. static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))
  54. - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))
  55. - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));
  56. static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
  57. struct SharedStorage {
  58. struct TensorStorage : cute::aligned_struct<128> {
  59. union {
  60. struct {
  61. cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
  62. typename CollectiveMainloop::TensorStorage mainloop;
  63. };
  64. // We want smem_o to line up with the start of smem_v
  65. typename CollectiveEpilogue::TensorStorage epilogue;
  66. };
  67. } tensors;
  68. alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
  69. };
  70. static constexpr int SharedStorageSize = sizeof(SharedStorage);
  71. // Device side arguments
  72. struct Arguments {
  73. MainloopArguments mainloop{};
  74. EpilogueArguments epilogue{};
  75. cutlass::KernelHardwareInfo hw_info{};
  76. TileSchedulerArguments scheduler{};
  77. };
  78. // Kernel entry point API
  79. struct Params {
  80. MainloopParams mainloop{};
  81. EpilogueParams epilogue{};
  82. cutlass::KernelHardwareInfo hw_info{};
  83. TileSchedulerParams scheduler{};
  84. };
  85. //
  86. // Methods
  87. //
  88. // Convert to underlying arguments. In this case, a simple copy for the aliased type.
  89. static
  90. Params
  91. to_underlying_arguments(Arguments const& args) {
  92. CUTLASS_TRACE_HOST("to_underlying_arguments():");
  93. // Get SM count if needed, otherwise use user supplied SM count
  94. int sm_count = args.hw_info.sm_count;
  95. if (sm_count <= 0) {
  96. CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
  97. " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
  98. sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
  99. }
  100. CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
  101. cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
  102. return {
  103. CollectiveMainloop::to_underlying_arguments(args.mainloop),
  104. CollectiveEpilogue::to_underlying_arguments(args.epilogue),
  105. hw_info,
  106. TileScheduler::to_underlying_arguments(args.scheduler)
  107. };
  108. }
  109. // Computes the kernel launch grid shape based on runtime parameters
  110. static dim3
  111. get_grid_shape(Params const& params) {
  112. return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor);
  113. }
  114. static dim3
  115. get_block_shape() {
  116. return dim3(MaxThreadsPerBlock, 1, 1);
  117. }
  118. CUTLASS_DEVICE
  119. void
  120. operator()(Params const& params, char* smem_buf) {
  121. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  122. SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
  123. CollectiveMainloop collective_mainloop;
  124. CollectiveEpilogue collective_epilogue;
  125. TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
  126. // Initialize matmul objects.
  127. TiledMma tiled_mma;
  128. scheduler.init_consumer();
  129. int warp_idx = cutlass::canonical_warp_idx_sync();
  130. CUTLASS_PRAGMA_NO_UNROLL
  131. for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
  132. work_tile_info.is_valid(params.scheduler);
  133. work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
  134. // Attention output (GEMM-II) accumulator.
  135. Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));
  136. float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
  137. // If there's tanh softcap, the scaling will be done before tanh.
  138. auto block_coord = work_tile_info.get_block_coord(params.scheduler);
  139. int const bidb = get<2>(block_coord);
  140. if constexpr (Is_FP8 && !Has_softcap) {
  141. int const bidh = get<1>(block_coord);
  142. int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
  143. float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
  144. float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
  145. softmax_scale_log2 *= q_descale * k_descale;
  146. }
  147. flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
  148. SeqlenInfo_t seqlen_info{
  149. bidb,
  150. get<0>(params.mainloop.shape_Q),
  151. !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
  152. get<0>(params.mainloop.shape_K_new),
  153. params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
  154. params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
  155. };
  156. if constexpr (AppendKV) {
  157. bool tile_new_valid = collective_mainloop.store_kv_new(
  158. params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);
  159. if (tile_new_valid) { __syncthreads(); }
  160. }
  161. bool tile_valid = collective_mainloop.mma(
  162. params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,
  163. shared_storage);
  164. scheduler.prefetch_next_work(params.scheduler, work_tile_info);
  165. if (tile_valid) {
  166. // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
  167. collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,
  168. threadIdx.x, block_coord);
  169. } else {
  170. // Write 0 to gO and -inf to gLSE.
  171. // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will
  172. // not use the value of O if LSE is -inf.
  173. collective_epilogue.template store_zero<!Split /*Clear_O*/>(params.epilogue, threadIdx.x, block_coord);
  174. }
  175. }
  176. }
  177. };
  178. } // namespace flash