/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/device_kernel.h"  // For device_kernel
#include "cutlass/kernel_launch.h"  // For kernel_launch
#include "cutlass/cluster_launch.hpp"  // For ClusterLauncher

#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_postprocess_kernel.h"
#include "tile_scheduler.hpp"
#include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
#include "mainloop_bwd_sm80.hpp"
#include "epilogue_bwd.hpp"
#include "flash_bwd_kernel_sm90.h"
#include "flash_bwd_kernel_sm80.h"

using namespace cute;

template <int Arch, typename Element, int kHeadDim, int kBlockM, int kBlockN,
          bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
          int Stages_dO, int Stages_dS_or_QSm80,
          bool SdP_swapAB, bool dKV_swapAB, bool dQ_swapAB,
          int NumMmaWarpGroups, int AtomLayoutMSdP, int AtomLayoutNdKV, int AtomLayoutMdQ, bool V_in_regs>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
    using ElementAccum = float;
    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;

    int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);
    int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);
    bool const is_varlen_q = params.cu_seqlens_q;
    bool const is_varlen_k = params.cu_seqlens_k;
    int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
    int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;
    int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;
    int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;
    int batch_q = !is_varlen_q ? params.b : 1;
    int batch_k = !is_varlen_k ? params.b : 1;

    using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
    using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, ArchTag, true /*Clear_dQaccum*/, Varlen>;
    typename PreprocessKernel::Arguments preprocess_args {
        static_cast<Element const*>(params.o_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_O
        {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0},  // stride_O
        static_cast<Element const*>(params.do_ptr),
        {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0},  // stride_dO
        static_cast<float*>(params.dsoftmax_sum),
        {seqlen_q_rounded, params.h, batch_q},  // shape_dPsum
        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_dPsum
        static_cast<float*>(params.softmax_lse_ptr),
        {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0},  // stride_LSE
        static_cast<float*>(params.softmax_lse_log2_ptr),
        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2
        static_cast<float*>(params.dq_accum_ptr),
        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ?
            params.d_rounded * seqlen_q_rounded * params.h : 0},  // stride_dQaccum
        params.b,
        params.dq_semaphore,
        params.cu_seqlens_q,
        params.seqused_q
    };
    typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
    int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
    dim3 grid_m(num_m_block, params.h, params.b);
    cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/);
    CHECK_CUDA_KERNEL_LAUNCH();

    using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    using ClusterShape = cute::Shape<_1, Int<1>, _1>;  // Currently doesn't support cluster
    // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80
    static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80;
    static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1;
    using CollectiveMainloop = std::conditional_t<
        Arch >= 90,
        flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
            Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
            SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
        flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
            Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
            SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
    >;
    using CollectiveEpilogue = std::conditional_t<
        !GQA,
        flash::CollectiveEpilogueBwd<TileShape_MNK, Element, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB,
                                     NumMmaWarpGroups * (Arch >= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>,
        flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>
    >;
    using Scheduler = flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>;
    using AttnKernel = std::conditional_t<
        Arch >= 90,
        flash::enable_sm90_or_later<flash::FlashAttnBwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
        flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
    >;

    typename CollectiveMainloop::Arguments mainloop_args {
        static_cast<Element const*>(params.q_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_Q
        {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0},  // stride_Q
        static_cast<Element const*>(params.k_ptr),
        {seqlen_k, params.d, params.h_k, batch_k},  // shape_K
        {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0},  // stride_K
        static_cast<Element const*>(params.v_ptr),
        {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0},  // stride_V
        static_cast<Element const*>(params.do_ptr),
        {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0},  // stride_dO
        static_cast<ElementAccum*>(params.dq_accum_ptr),
        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQaccum
        static_cast<float*>(params.softmax_lse_log2_ptr),
        {seqlen_q_rounded, params.h, batch_q},  // shape_LSE
        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2
        static_cast<float*>(params.dsoftmax_sum),
        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_dPsum
        params.scale_softmax,
        params.window_size_left, params.window_size_right, params.sink_token_length,
        params.softcap,
        params.b,
        params.dq_semaphore,
        params.cu_seqlens_q, params.cu_seqlens_k,
        params.seqused_q, params.seqused_k
    };
    // The case handling for GQA is ugly, but I don't know how to fix it.
    typename CollectiveEpilogue::Arguments epilogue_args {
        static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),
        [&] {
            if constexpr (!GQA) {
                return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k};  // shape_dK
            } else {
                return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k};  // shape_dKaccum
            }
        }(),
        [&] {
            if constexpr (!GQA) {
                return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ?
                    params.dk_batch_stride : 0};  // stride_dK
            } else {
                return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0};  // stride_dKaccum
            }
        }(),
        static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
        [&] {
            if constexpr (!GQA) {
                return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0};  // stride_dV
            } else {
                return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0};  // stride_dVaccum
            }
        }(),
        params.h,
        params.dk_semaphore,
        params.dv_semaphore,
        params.cu_seqlens_k,
        params.seqused_k,
    };

    int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
    num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
    typename flash::TileSchedulerArguments scheduler_args {
        num_blocks_n, params.h, params.b, 1 /*num_splits*/,
        params.h / params.h_k,
        params.seqlen_k,
        params.seqlen_q, params.d, sizeof(Element),
        params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k
    };

    int device;
    cudaGetDevice(&device);
    typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
        mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
    });

    dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
    dim3 block_dims = AttnKernel::get_block_shape();
    int smem_size = AttnKernel::SharedStorageSize;
    // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
    // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));
    // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));
    // int smem_size_dqacc = [&] {
    //     if constexpr (Arch >= 90) {
    //         return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));
    //     } else {
    //         return 0;
    //     }
    // }();
    // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
    // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
    // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
    // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
    // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
    if constexpr (size(ClusterShape{}) > 1) {
        void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
        cutlass::ClusterLauncher::launch(
            grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/);
    } else {
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/);
    }
    CHECK_CUDA_KERNEL_LAUNCH();

    using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, ArchTag,
        AttnKernel::CollectiveMainloop::kNThreadsdQ,
        typename AttnKernel::CollectiveMainloop::SmemLayoutdQaccumTMA,
        typename AttnKernel::CollectiveMainloop::TiledMmadQ,
        AttnKernel::CollectiveMainloop::dQ_swapAB>;
    typename PostprocessKernel::Arguments postprocess_args {
        static_cast<ElementAccum const*>(params.dq_accum_ptr),
        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQaccum
        static_cast<Element*>(params.dq_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_dQ
        {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride},  // stride_dQ
        params.scale_softmax,
        params.cu_seqlens_q,
        params.seqused_q
    };
    typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
    int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
    dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
    int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
    if (smem_size_postprocess >= 48 * 1024) {
        CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
    }
    cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/);
    CHECK_CUDA_KERNEL_LAUNCH();

    if constexpr (GQA) {
        using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;
        using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ;
        typename PostprocessKerneldKV::Arguments postprocess_dK_args {
            static_cast<ElementAccum const*>(params.dk_accum_ptr),
            {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k},  // shape_dKaccum
            {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0},  // stride_dKaccum
            static_cast<Element*>(params.dk_ptr),
            {seqlen_k, params.d, params.h_k, batch_k},  // shape_dK
            {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride},  // stride_dK
            1.f,
            params.cu_seqlens_k,
            params.seqused_k
        };
        typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
        typename PostprocessKerneldKV::Arguments postprocess_dV_args {
            static_cast<ElementAccum const*>(params.dv_accum_ptr),
            {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k},  // shape_dVaccum
            {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ?
                params.d_rounded * params.seqlen_k_rounded * params.h_k : 0},  // stride_dVaccum
            static_cast<Element*>(params.dv_ptr),
            {seqlen_k, params.d, params.h_k, batch_k},  // shape_dV
            {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride},  // stride_dV
            1.f,
            params.cu_seqlens_k,
            params.seqused_k
        };
        typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);
        int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));
        dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);
        int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;
        if (smem_size_postprocess >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
        }
        cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/);
        CHECK_CUDA_KERNEL_LAUNCH();
        cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/);
        CHECK_CUDA_KERNEL_LAUNCH();
    }
}

template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
         int Stages_dO, int Stages_dS_or_QSm80, bool SdP_swapAB, bool dKV_swapAB, bool dQ_swapAB,
         int NumMmaWarpGroups, int AtomLayoutMSdP, int AtomLayoutNdKV, int AtomLayoutMdQ, bool V_in_regs=false>
void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
    VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
        BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
        // BOOL_SWITCH(params.deterministic, Deterministic, [&] {
            // run_flash_bwd(params, stream);
            run_flash_bwd<Arch, T, kHeadDim, kBlockM, kBlockN, Is_causal, Is_local, Has_softcap, Varlen, false /*Deterministic*/, GQA,
                          Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB,
                          NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
        // });
        });
    });
}

template<int Arch, typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            if constexpr (Arch >= 90) {
                if constexpr (Is_causal && Has_softcap) {
                    // register spill with 128 x 128
                    run_mha_bwd_dispatch(params, stream);
                } else {
                    // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
                    run_mha_bwd_dispatch(params, stream);
                }
            } else {
                run_mha_bwd_dispatch(params, stream);  // Sm86
                // run_mha_bwd_dispatch(params, stream);
                // run_mha_bwd_dispatch(params, stream);
                // run_mha_bwd_dispatch(params, stream);
                // run_mha_bwd_dispatch(params, stream);
            }
        });
    });
}

template<int Arch, typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            if constexpr (Arch >= 90) {
                run_mha_bwd_dispatch(params, stream);
            } else {
                run_mha_bwd_dispatch(params, stream);  // Sm86
                // run_mha_bwd_dispatch(params, stream);
            }
        });
    });
}

template<int Arch, typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            if constexpr (Arch >= 90) {
                if constexpr (Is_causal || Is_local || Has_softcap) {
                    run_mha_bwd_dispatch(params, stream);
                } else {
                    run_mha_bwd_dispatch(params, stream);
                }
            } else {
                run_mha_bwd_dispatch(params, stream);  // Sm86
                // run_mha_bwd_dispatch(params, stream);
            }
        });
    });
}

template<int Arch, typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            if constexpr (Arch >= 90) {
                run_mha_bwd_dispatch(params, stream);
            } else {
                run_mha_bwd_dispatch(params, stream);  // Sm86
                // run_mha_bwd_dispatch(params, stream);
            }
        });
    });
}

template<int Arch, typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            if constexpr (Arch >= 90) {
                run_mha_bwd_dispatch(params, stream);
            } else {
                run_mha_bwd_dispatch(params, stream);  // Sm86
                // run_mha_bwd_dispatch(params, stream);
                // run_mha_bwd_dispatch(params, stream);
            }
        });
    });
}
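
// Worked example (illustrative numbers, not taken from the library) of the varlen padding used at
// the top of run_flash_bwd: with params.total_q = 1000, params.b = 3, and kBlockM = 128,
// total_q_padded_rounded = round_up(1000 + 3 * 128, 128) = round_up(1384, 128) = 1408.
// Adding params.b * kBlockM before rounding guarantees that each of the b sequences in the packed
// varlen batch can be padded up to its own kBlockM boundary without running past the allocation.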
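
// Illustrative sketch (an assumption, not part of the original header): how a caller could select
// one of the head-dimension-specific launchers above from params.d. The function name
// run_mha_bwd_example, the fixed Arch of 90 (SM90 / Hopper), and the cutlass::half_t element type
// are hypothetical choices for illustration only; in the library, the per-(arch, dtype, hdim)
// instantiations live in separate translation units.
//
// inline void run_mha_bwd_example(Flash_bwd_params &params, cudaStream_t stream) {
//     static constexpr int Arch = 90;       // assumed target architecture
//     using Element = cutlass::half_t;      // assumed element type (fp16)
//     if (params.d <= 64) {
//         run_mha_bwd_hdim64<Arch, Element>(params, stream);
//     } else if (params.d <= 96) {
//         run_mha_bwd_hdim96<Arch, Element>(params, stream);
//     } else if (params.d <= 128) {
//         run_mha_bwd_hdim128<Arch, Element>(params, stream);
//     } else if (params.d <= 192) {
//         run_mha_bwd_hdim192<Arch, Element>(params, stream);
//     } else {
//         run_mha_bwd_hdim256<Arch, Element>(params, stream);
//     }
// }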