// flash_fwd_combine_launch_template.h
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"  // For cutlass::arch::Sm80
#include "cutlass/device_kernel.h"  // For device_kernel

#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_combine_kernel.h"

using namespace cute;
// Launch the split-KV combine kernel: merges the per-split partial outputs
// (oaccum) and partial log-sum-exp values (lseaccum) into the final O and LSE
// tensors. Varlen selects packed variable-length layout (total_q rows, batch
// folded out of the shapes and addressed via cu_seqlens_q/seqused_q).
//
// Template parameters:
//   kHeadDim      - head dimension of O (tile K extent)
//   kBlockM       - rows of O processed per CTA
//   kLogMaxSplits - log2 of the maximum split count this instantiation handles
//   IsEvenK       - whether the K extent is known to fill the tile evenly
//   Varlen        - variable-length (packed) sequence mode
//   Element / ElementPartial - dtypes of the final and partial outputs
template <int kHeadDim, int kBlockM, int kLogMaxSplits, bool IsEvenK, bool Varlen, typename Element, typename ElementPartial>
void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream) {
    using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
    using CombineKernel = flash::FlashAttnFwdCombine<TileShape_MK, kLogMaxSplits, 256 /*kNThreads*/, 1 /*AlignmentLSE*/,
                                                     IsEvenK, Varlen, Element, ElementPartial, cutlass::arch::Sm80>;

    // NOTE: the order of these fields is fixed by CombineKernel::Arguments
    // (declared in flash_fwd_combine_kernel.h) — do not reorder.
    // In Varlen mode the batch extent collapses to 1 and the batch strides to
    // 0, since rows of all sequences are packed along the seqlen dimension.
    typename CombineKernel::Arguments args {
        static_cast<ElementPartial const*>(params.oaccum_ptr),
        {!Varlen ? params.seqlen_q : params.total_q, params.d, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_O_partial
        {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0},  // stride_O_partial
        static_cast<float*>(params.softmax_lseaccum_ptr),
        {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_LSE_partial
        {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0},  // stride_LSE_partial
        static_cast<Element*>(params.o_ptr),
        {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0},  // stride_O
        static_cast<float*>(params.softmax_lse_ptr),
        {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0},  // stride_LSE
        params.cu_seqlens_q, params.seqused_q
    };

    typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
    // Flatten (seqlen, head, batch) rows into 1D blocks of kBlockM rows; in
    // Varlen mode the batch instead becomes the second grid dimension.
    int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h * (!Varlen ? params.b : 1), kBlockM);
    dim3 grid_m(num_blocks_m, !Varlen ? 1 : params.b);
    auto kernel = cutlass::device_kernel<CombineKernel>;
    int smem_size = CombineKernel::SharedStorageSize;
    if (smem_size >= 48 * 1024) {
        // Dynamic shared memory above 48 KB requires an explicit opt-in.
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    kernel<<<grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream>>>(kernel_params);
    CHECK_CUDA_KERNEL_LAUNCH();
}
  42. template<typename T, typename Tpartial, int kHeadDim>
  43. void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream) {
  44. // We want kBlockM to be as small as possible to maximize parallelism.
  45. // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
  46. static_assert(kHeadDim % 32 == 0, "kHeadDim must be a multiple of 32");
  47. static constexpr int kBlockM = kHeadDim % 128 == 0 ? 8 : (kHeadDim % 64 == 0 ? 16 : 32);
  48. BOOL_SWITCH(params.seqused_q != nullptr, Varlen, [&] {
  49. if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32.
  50. if (params.num_splits <= 16) {
  51. run_flash_fwd_combine<kHeadDim, kBlockM, 4, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream);
  52. return;
  53. }
  54. }
  55. if (params.num_splits <= 32) {
  56. run_flash_fwd_combine<kHeadDim, kBlockM, 5, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream);
  57. } else if (params.num_splits <= 64) {
  58. run_flash_fwd_combine<kHeadDim, kBlockM, 6, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream);
  59. } else if (params.num_splits <= 128) {
  60. run_flash_fwd_combine<kHeadDim, kBlockM, 7, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream);
  61. } else {
  62. run_flash_fwd_combine<kHeadDim, kBlockM, 8, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream);
  63. }
  64. });
  65. }