flash_bwd_kernel_sm80.h

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/kernel_hardware_info.h>

#include "utils.h"

namespace flash {

using namespace cute;
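
// SM80 (Ampere) backward attention kernel. Each CTA is driven by the TileScheduler:
// for every work tile it receives, the collective mainloop accumulates dK and dV for
// one block of keys/values in registers, and the collective epilogue writes them out.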
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class FlashAttnBwdSm80 {

public:

    // Type Aliases
    static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
    static constexpr bool Is_local = CollectiveMainloop_::Is_local;
    static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
    static constexpr bool Varlen = CollectiveMainloop_::Varlen;

    // Mainloop derived types
    using CollectiveMainloop = CollectiveMainloop_;
    using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
    using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
    using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
    using ArchTag = typename CollectiveMainloop::ArchTag;
    using MainloopArguments = typename CollectiveMainloop::Arguments;
    using MainloopParams = typename CollectiveMainloop::Params;
    static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;

    // Epilogue derived types
    using CollectiveEpilogue = CollectiveEpilogue_;
    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
    using EpilogueParams = typename CollectiveEpilogue::Params;

    static_assert(ArchTag::kMinComputeCapability >= 80);

    using TileScheduler = TileScheduler_;
    using TileSchedulerArguments = typename flash::TileSchedulerArguments;
    using TileSchedulerParams = typename TileScheduler::Params;

    static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{}));
    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{}));
    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
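    // The CTA size is fixed by the tiled MMA that computes S and dP: every thread
    // participating in TiledMmaSdP is a kernel thread, so NumThreads equals
    // MaxThreadsPerBlock.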

    // Kernel level shared memory storage
    struct SharedStorage {
        struct TensorStorage : cute::aligned_struct<128> {
            union {
                typename CollectiveMainloop::TensorStorage mainloop;
                typename CollectiveEpilogue::TensorStorage epilogue;
            };
        } tensors;
        alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
    };
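    // The union aliases the mainloop's and epilogue's smem tensors: the mainloop's
    // buffers are no longer needed once dK/dV have been accumulated in registers, so
    // the epilogue can reuse the same storage and only the larger footprint counts.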

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    // Device side arguments
    struct Arguments {
        MainloopArguments mainloop{};
        EpilogueArguments epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerArguments scheduler{};
    };

    // Kernel entry point API
    struct Params {
        MainloopParams mainloop{};
        EpilogueParams epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerParams scheduler{};
    };

    //
    // Methods
    //

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static
    Params
    to_underlying_arguments(Arguments const& args) {
        CUTLASS_TRACE_HOST("to_underlying_arguments():");

        // Get SM count if needed, otherwise use user supplied SM count
        int sm_count = args.hw_info.sm_count;
        if (sm_count <= 0) {
            CUTLASS_TRACE_HOST("  WARNING: Arguments do not include a valid SM count.\n"
                "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
        }

        CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);

        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
        return {
            CollectiveMainloop::to_underlying_arguments(args.mainloop),
            CollectiveEpilogue::to_underlying_arguments(args.epilogue),
            hw_info,
            TileScheduler::to_underlying_arguments(args.scheduler)
        };
    }

    // Computes the kernel launch grid shape based on runtime parameters
    static dim3
    get_grid_shape(Params const& params) {
        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
    }
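    // For a persistent scheduler the grid is typically sized to the SM count queried
    // above, with each CTA looping over many work tiles rather than one tile per CTA.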

    static dim3
    get_block_shape() {
        return dim3(MaxThreadsPerBlock, 1, 1);
    }
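
    // A minimal host-side launch sketch (not part of this header), assuming a
    // __global__ trampoline, hypothetically named flash_bwd_kernel, that forwards
    // its Params and dynamic shared memory to operator():
    //
    //   using Kernel = flash::FlashAttnBwdSm80<Mainloop, Epilogue, Scheduler>;
    //   typename Kernel::Params params = Kernel::to_underlying_arguments(args);
    //   dim3 grid  = Kernel::get_grid_shape(params);
    //   dim3 block = Kernel::get_block_shape();
    //   int  smem  = Kernel::SharedStorageSize;
    //   cudaFuncSetAttribute(flash_bwd_kernel<Kernel>,
    //                        cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
    //   flash_bwd_kernel<Kernel><<<grid, block, smem, stream>>>(params);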

    CUTLASS_DEVICE
    void
    operator()(Params const& params, char* smem_buf) {

        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        CollectiveMainloop collective_mainloop;
        CollectiveEpilogue collective_epilogue;

        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));

        // Initialize matmul objects.
        TiledMmadKV tiled_mma_dKV;

        scheduler.init_consumer();

        int warp_idx = cutlass::canonical_warp_idx_sync();
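
        // Warp 0 serves as the scheduler's producer warp: only it fetches and
        // advances work tiles (IsProducerWarp == true below); all warps then
        // process whichever tile the scheduler hands out.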
        CUTLASS_PRAGMA_NO_UNROLL
        for (auto work_tile_info = warp_idx == 0
                 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler)
                 : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
             work_tile_info.is_valid(params.scheduler);
             work_tile_info = warp_idx == 0
                 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)
                 : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {

            auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
            auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
            cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
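
            // The backward pass parallelizes over blocks of keys/values (the N mode
            // of TileShape_MNK), so a tile is identified by (n_block, head, batch);
            // the scheduler's split_idx is unused here and discarded.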

            // dK and dV output accumulator.
            Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB ? 2 : 1>(TileShape_MNK{}));
            Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB ? 2 : 1>(TileShape_MNK{}));
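            // dKV_swapAB exchanges the A/B operands of the dK/dV MMA, which transposes
            // the accumulator: select<...> yields the (N, K) modes of TileShape_MNK
            // normally and the (K, N) modes when swapped.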

            bool tile_valid = collective_mainloop.mma(
                params.mainloop, tdKrdK, tdVrdV, threadIdx.x, block_coord,
                shared_storage);
            scheduler.prefetch_next_work(params.scheduler, work_tile_info);
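
            // tile_valid is false when this K/V block has no contributing queries
            // (e.g., it falls entirely outside a causal or local attention window);
            // the gradients must then be written as zeros rather than left untouched.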
            if (tile_valid) {
                collective_epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
                                          threadIdx.x, block_coord);
            } else {
                collective_epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
            }
        }
    }

};

} // namespace flash