/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

using namespace cute;

template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType,
          class SmemLayoutQ, class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
    union {
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        int tile_count_semaphore;
    };
};

// Use if Oaccum is too large for SharedStorageQKVO
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType,
          class SmemLayoutQ, class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOaccum {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    union {
        struct {
            cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
            cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        };
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        int tile_count_semaphore;
    };
};

// SharedStorage struct with no smem for O
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType,
          class SmemLayoutQ, class SmemLayoutK, class SmemLayoutV>
struct SharedStorageQKV {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
    cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        int tile_count_semaphore;
    };
};

template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType,
          class SmemLayoutQ, class SmemLayoutK, class SmemLayoutV, class SmemLayoutVt, class SmemLayoutO>
struct SharedStorageQKVOVt {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutVt>> smem_v_out;
            cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
        };
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
        int tile_count_semaphore;
        float softmax_scale_qk_log2;
        float descale_v;
        bool seqlen_init_k;
    };
};

// Use if Oaccum is too large for SharedStorageQKVOVt
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType,
          class SmemLayoutQ, class SmemLayoutK, class SmemLayoutV, class SmemLayoutVt, class SmemLayoutO>
struct SharedStorageQKVOVtaccum {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        union {
            struct {
                cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
                cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutVt>> smem_v_out;
            };
            cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
        };
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
        int tile_count_semaphore;
        float softmax_scale_qk_log2;
        float descale_v;
        bool seqlen_init_k;
    };
};
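// Example (illustrative sketch, hypothetical): how a launcher typically sizes dynamic shared
// memory from one of these aggregate structs. The unions above (e.g. smem_v overlapping smem_o)
// are what keep this figure within Hopper's shared-memory budget (~228 KB per SM). The macro
// FLASH_KERNEL_TRAITS_EXAMPLES is an illustrative guard, not defined by any build.
#ifdef FLASH_KERNEL_TRAITS_EXAMPLES
#include <cuda_runtime.h>
template <typename KernelTraits>
inline cudaError_t set_dynamic_smem_example(const void* kernel_func) {
    // Total dynamic shared memory the kernel will request at launch.
    constexpr int smem_size = int(sizeof(typename KernelTraits::SharedStorage));
    if (smem_size >= 48 * 1024) {  // beyond 48 KB a kernel must opt in explicitly
        return cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    return cudaSuccess;
}
#endif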
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType,
          class SmemLayoutQ, class SmemLayoutK, class SmemLayoutV, class SmemLayoutVt>
struct SharedStorageQKVVt {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutVt>> smem_v_out;
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
        int tile_count_semaphore;
        float softmax_scale_qk_log2;
        float descale_v;
        bool seqlen_init_k;
    };
};

template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType,
          class SmemLayoutQ, class SmemLayoutK, class SmemLayoutV, class SmemLayoutVt, class SmemLayoutO>
struct SharedStorageQKVOVt_nounion {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutVt>> smem_v_out;
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
        int tile_count_semaphore;
        float softmax_scale_qk_log2;
        float descale_v;
        bool seqlen_init_k;
    };
};

// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_ = false,
         int kClusterM_ = 1, typename elem_type = cutlass::half_t, bool Is_split_ = false, int kBlockH_ = 1>
struct Flash_fwd_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using FinalOutputType = elem_type;
    using OutputType = std::conditional_t<Is_split_, float, FinalOutputType>;
    using index_t = int64_t;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp;

    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;

    static_assert(kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
    static constexpr bool Is_WS = true;
    static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers");

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kBlockH = kBlockH_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static_assert(kBlockM % kBlockH == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;

    static constexpr int kStages = kStages_;

    static constexpr bool Is_split = Is_split_;
    static constexpr bool No_smem_O = Is_split;

    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    using TiledMma0 = decltype(cute::make_tiled_mma(
        std::conditional_t<
            Is_Q_in_regs,
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
        >{},
        AtomLayoutMNK{}));
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
                                   GMMA::Major::K, GMMA::Major::MN>(),
        AtomLayoutMNK{}));

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

    // for gmem -> smem Q copy
    using FactoringLayoutQ = Layout<Shape<Int<kBlockM / kBlockH>, Int<kBlockH>, Int<kHeadDim>>,
                                    Stride<Int<kBlockH>, _1, Int<kBlockM>>>;
    using TileShapeQCopy = std::conditional_t<(kBlockH > 1),
        decltype(shape(FactoringLayoutQ{})), decltype(select<0, 2>(TileShape_MNK{}))>;
    using SmemLayoutQCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutQ{}, FactoringLayoutQ{})), SmemLayoutQ>;

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), Int<kStages>{})));

    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutVt =
        decltype(composition(SmemLayoutV{},
                             make_ordered_layout(
                                 make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                                 Step<_2, _1, _3>{})));

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    // for smem -> gmem O copy
    using TileShapeOCopy = TileShapeQCopy;
    using SmemLayoutOCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutO{}, FactoringLayoutQ{})), SmemLayoutO>;

    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;

    using SharedStorage = std::conditional_t<!No_smem_O,
        SharedStorageQKVO<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>,
        SharedStorageQKV<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV>>;

    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
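// Example (illustrative, assuming the reconstructed template parameter order
// <kHeadDim_, kBlockM_, kBlockN_, kNWarps_, kStages_, ...>): a typical hdim-128 forward
// instantiation and the constants a launcher reads off it. Hypothetical guard macro as above.
#ifdef FLASH_KERNEL_TRAITS_EXAMPLES
using Fwd_traits_example = Flash_fwd_kernel_traits<128, 128, 64, 12, 2>;
static_assert(Fwd_traits_example::kNThreads == 12 * cutlass::NumThreadsPerWarp);
// With Is_split_ = false (the default), the epilogue writes in the input precision.
static_assert(std::is_same_v<typename Fwd_traits_example::OutputType, cutlass::half_t>);
#endif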
// Traits struct for fp8 kernel with in-kernel transpose
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_ = false,
         int kClusterM_ = 1, typename elem_type = cutlass::float_e4m3_t, bool Is_split_ = false, int kBlockH_ = 1>
struct Flash_fwd_kernel_traits_fp8 {
    using Element = elem_type;
    static_assert(cutlass::sizeof_bits_v<Element> == 8);
    using ElementAccum = float;
    using FinalOutputType = cutlass::bfloat16_t;
    using OutputType = std::conditional_t<Is_split_, float, FinalOutputType>;
    using index_t = int64_t;

    static constexpr bool Is_split = Is_split_;
    static constexpr bool No_smem_O = false;
    // NOTE: not using smem for epilogue degrades perf substantially.
    // static constexpr bool No_smem_O = Is_split;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;

    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;

    static_assert(kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
    static constexpr bool Is_WS = true;
    static_assert(!Is_Q_in_regs, "Warp-specialization does not support Q in registers");

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kBlockH = kBlockH_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static_assert(kBlockM % kBlockH == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;

    static constexpr int kStages = kStages_;
    static_assert(kStages > 1);

    // Use this to save enough smem when writing out in float precision.
    static constexpr bool VO_union_all = Is_split && (kBlockM != 64) && (kHeadDim == 256);

    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    using TiledMma0 = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
        AtomLayoutMNK{}));
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{}))>(),
        AtomLayoutMNK{}));

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

    // for gmem -> smem Q copy
    using FactoringLayoutQ = Layout<Shape<Int<kBlockM / kBlockH>, Int<kBlockH>, Int<kHeadDim>>,
                                    Stride<Int<kBlockH>, _1, Int<kBlockM>>>;
    using TileShapeQCopy = std::conditional_t<(kBlockH > 1),
        decltype(shape(FactoringLayoutQ{})), decltype(select<0, 2>(TileShape_MNK{}))>;
    using SmemLayoutQCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutQ{}, FactoringLayoutQ{})), SmemLayoutQ>;

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using TransposeShapeAtomV = Shape<_64, _64>;
    using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    // for fp8 in-kernel transpose -- src layout
    using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
    using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
    using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
        shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
    using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));

    // For fp8, this is the memory transpose.
    using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutVt =
        decltype(tile_to_shape(SmemLayoutAtomVt{},
                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));

    // for fp8 in-kernel transpose -- dst layout
    using SmemLayoutVtTrans =
        decltype(composition(SmemLayoutVt{},
                             make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
    using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
#ifndef NO_FP8_COLUMN_PERMUTE
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
#else
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
#endif
    using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{},
        shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
    using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    // for smem -> gmem O copy
    using TileShapeOCopy = TileShapeQCopy;
    using SmemLayoutOCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutO{}, FactoringLayoutQ{})), SmemLayoutO>;

    // used for rmem -> smem O copy in fp8 kernel to undo column permutation
    using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM / 16>, _4, _1>,
                                  Stride<_4, _32, _1, _0>>;
    using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim / 16>>,
                                 Stride<_0, _2, Stride<_4, _1>, _8>>;
    using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, OutputType>{},
                                                 ThreadLayoutrO{}, ValueLayoutrO{}));
    using TiledCopyShaperO = Shape<_8, Int<kBlockM / 8>, _16, Int<kHeadDim / 16>>;
    using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));

    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;

    using SharedStorage = std::conditional_t<!No_smem_O,
        std::conditional_t<!VO_union_all,
            SharedStorageQKVOVt<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutVt, SmemLayoutO>,
            SharedStorageQKVOVtaccum<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutVt, SmemLayoutO>>,
        SharedStorageQKVVt<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutVt>>;

    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
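// Example (illustrative, same reconstructed parameter order as above): an fp8 e4m3
// instantiation. V is transposed in-kernel through SmemLayoutTransposeV/Vt, and the
// epilogue emits bf16 unless Is_split_ promotes it to float. Hypothetical guard macro.
#ifdef FLASH_KERNEL_TRAITS_EXAMPLES
using Fwd_fp8_traits_example = Flash_fwd_kernel_traits_fp8<128, 128, 128, 12, 2>;
static_assert(cutlass::sizeof_bits_v<typename Fwd_fp8_traits_example::Element> == 8);
static_assert(std::is_same_v<typename Fwd_fp8_traits_example::FinalOutputType, cutlass::bfloat16_t>);
#endif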
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Has_P_smem, int kStages, class Element,
          class SmemLayoutQ, class SmemLayoutdO, class SmemLayoutK, class SmemLayoutV,
          class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV;

template <int kStages, class Element,
          class SmemLayoutQ, class SmemLayoutdO, class SmemLayoutK, class SmemLayoutV,
          class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<true, kStages, Element, SmemLayoutQ, SmemLayoutdO, SmemLayoutK, SmemLayoutV,
                             SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};

template <int kStages, class Element,
          class SmemLayoutQ, class SmemLayoutdO, class SmemLayoutK, class SmemLayoutV,
          class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<false, kStages, Element, SmemLayoutQ, SmemLayoutdO, SmemLayoutK, SmemLayoutV,
                             SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union {  // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};

template <bool Has_P_smem, int kStages, class Element, class ElementAccum,
          class SmemLayoutQ, class SmemLayoutdO, class SmemLayoutK, class SmemLayoutV,
          class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc, class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS;

template <int kStages, class Element, class ElementAccum,
          class SmemLayoutQ, class SmemLayoutdO, class SmemLayoutK, class SmemLayoutV,
          class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc, class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<true, kStages, Element, ElementAccum, SmemLayoutQ, SmemLayoutdO, SmemLayoutK, SmemLayoutV,
                               SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
        cute::array_aligned<float, 128> smem_lse;
        cute::array_aligned<float, 128> smem_dpsum;
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};

template <int kStages, class Element, class ElementAccum,
          class SmemLayoutQ, class SmemLayoutdO, class SmemLayoutK, class SmemLayoutV,
          class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc, class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<false, kStages, Element, ElementAccum, SmemLayoutQ, SmemLayoutdO, SmemLayoutK, SmemLayoutV,
                               SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union {  // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
        cute::array_aligned<float, 128> smem_lse;
        cute::array_aligned<float, 128> smem_dpsum;
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};

template <bool Has_P_smem, int kStages, class Element,
          class SmemLayoutQ, class SmemLayoutdO, class SmemLayoutK, class SmemLayoutV,
          class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar;

template <int kStages, class Element,
          class SmemLayoutQ, class SmemLayoutdO, class SmemLayoutK, class SmemLayoutV,
          class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<true, kStages, Element, SmemLayoutQ, SmemLayoutdO, SmemLayoutK, SmemLayoutV,
                                    SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterTransactionBarrier barrier_dO;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};

template <int kStages, class Element,
          class SmemLayoutQ, class SmemLayoutdO, class SmemLayoutK, class SmemLayoutV,
          class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<false, kStages, Element, SmemLayoutQ, SmemLayoutdO, SmemLayoutK, SmemLayoutV,
                                    SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        union {  // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterTransactionBarrier barrier_dO;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};

////////////////////////////////////////////////////////////////////////////////////////////////////
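// Example (illustrative sketch): the Has_P_smem = false specializations above overlap smem_p
// with smem_ds, so for the same layouts the false specialization can never be larger than the
// true one. A compile-time check of that invariant, under the hypothetical guard macro:
#ifdef FLASH_KERNEL_TRAITS_EXAMPLES
template <int kStages, class Element, class... Layouts>  // Layouts... = the 8 smem layouts
constexpr bool p_smem_union_saves_space() {
    return sizeof(SharedStorageQKVdOdKV<false, kStages, Element, Layouts...>)
        <= sizeof(SharedStorageQKVdOdKV<true, kStages, Element, Layouts...>);
}
#endif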
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
         int AtomLayoutMSdP = 1, int AtomLayoutNdKV = 2, int AtomLayoutMdQ = 1,
         int kClusterN_ = 1, typename elem_type = cutlass::half_t>
struct Flash_bwd_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using index_t = int64_t;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp;
    // static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup;
    static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup;

    static_assert(kNWarps_ == 8 || kNWarps_ == 12);
    static constexpr bool Is_WS = kNWarps_ >= 12;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kClusterN = kClusterN_;
    using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;

    static constexpr int kStages = 2;

    static constexpr bool SdP_swapAB = SdP_swapAB_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool dQ_swapAB = dQ_swapAB_;
    static_assert(!(SdP_swapAB && dKV_swapAB));  // If SdP_swapAB, then we don't swap for dKV

    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2
        && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS

    using TileShapeAtomSdP = std::conditional_t<
        !SdP_swapAB,
        Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
        Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
    >;
    using AtomLayoutSdP = std::conditional_t<
        !SdP_swapAB,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
    >;
    using TiledMmaSdP = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
        AtomLayoutSdP{}));

    using TileShapeAtomdKV = std::conditional_t<
        !dKV_swapAB,
        Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
    >;
    using AtomLayoutdKV = std::conditional_t<
        !dKV_swapAB,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
    >;
    using TiledMmadKV = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !SdP_swapAB,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
        >{},
        AtomLayoutdKV{}));

    using TileShapeAtomdQ = std::conditional_t<
        !dQ_swapAB,
        Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
        // Shape<Int<kBlockM>, Int<kHeadDim>, Int<kBlockN>>,
        // Shape<Int<kHeadDim>, Int<kBlockM>, Int<kBlockN>>
    >;
    using AtomLayoutdQ = std::conditional_t<
        !dQ_swapAB,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
        // Layout<Shape<Int<AtomLayoutMdQ>, Int<1>, _1>>,
        // Layout<Shape<Int<AtomLayoutMdQ>, Int<1>, _1>>
    >;
    static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
    static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
    using TiledMmadQ = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !dQ_swapAB,
            std::conditional_t<
                Mma_dQ_is_RS,
                decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
                decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, MmadQMajorA, MmadQMajorB>())
            >,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, MmadQMajorA, MmadQMajorB>())
        >{},
        AtomLayoutdQ{}));

    using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
    using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
    using GmemTiledCopydKV = cute::SM90_TMA_STORE;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    static constexpr bool Has_cp_async = true;
#else
    static constexpr bool Has_cp_async = false;
#endif
    // For the dot_do_o preprocessing kernel
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape<Int<kNThreadsNonWS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemLayoutAtomdQ = Layout<Shape<Int<kNThreadsdQ / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                    Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtomdQ{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape<Int<kNThreadsdQ / 8>, _8>,  // Thread layout, 8 threads per row
               Stride<_8, _1>>,
        Layout<Shape<Int<kNThreadsdQ / 16>, _16>,  // Thread layout, 16 threads per row
               Stride<_16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ =
        decltype(tile_to_shape(SmemLayoutAtomQ{},
                 make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutdO = SmemLayoutQ;

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));

    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));

    using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));

    using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));

    // using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
    //     decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));

    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutQt =
        decltype(cute::composition(SmemLayoutQ{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
                                               make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
    using SmemLayoutdOt =
        decltype(cute::composition(SmemLayoutdO{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
                                               make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
    using SmemLayoutKt =
        decltype(cute::composition(SmemLayoutK{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                                               make_stride(Int<kBlockN>{}, _1{}))));
    using SmemLayoutPt =
        decltype(cute::composition(SmemLayoutP{},
                                   make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdSt =
        decltype(cute::composition(SmemLayoutdS{},
                                   make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));

    // using SmemLayoutdQacct =
    //     decltype(cute::composition(SmemLayoutdQacc{},
    //                                make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
    //                                            make_stride(Int<kBlockM>{}, _1{}))));

    using SmemLayoutdK = SmemLayoutK;
    using SmemLayoutdV = SmemLayoutV;
    using SmemLayoutdKt = SmemLayoutKt;
    using SmemLayoutdVt = SmemLayoutKt;

    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    using SmemLayoutAtomdQ = decltype(
        // composition(Swizzle<kSwizzle, 3, 3>{},
        composition(Swizzle<3, 3, 3>{},
                    Layout<Shape<Int<kNThreadsdQ / 32>, Int<32>>,
                           Stride<Int<32>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);

    using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutdQacc = SmemLayoutdQ;
    using SmemLayoutdQacct = SmemLayoutdQt;
    using SmemLayoutdQacc2 = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{}, _2{})));
    // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
    // using SmemLayoutdQacct =
    //     decltype(cute::composition(SmemLayoutdQacc{},
    //                                make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
    //                                            make_stride(Int<kBlockM>{}, _1{}))));
    using RmemTiledCopydQacc = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store

    // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SmemCopyAtomPdS = Copy_Atom<
        std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>, Element>;
    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>, Element>;
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>, Element>;

    using SharedStorage = std::conditional_t<
        !Is_WS,
        SharedStorageQKVdOdKV<!SdP_swapAB, kStages, Element, SmemLayoutQ, SmemLayoutdO,
                              SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV>,
        SharedStorageQKVdOdKVWS<!SdP_swapAB, kStages, Element, ElementAccum, SmemLayoutQ, SmemLayoutdO,
                                SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV>
        // SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV>
    >;

    // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    // using PipelineState = typename cutlass::PipelineState<kStages>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
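// Example (illustrative, assuming the reconstructed parameter order
// <kHeadDim_, kBlockM_, kBlockN_, kNWarps_, SdP_swapAB_, dKV_swapAB_, dQ_swapAB_, ...>):
// a warp-specialized hdim-128 backward instantiation. Hypothetical guard macro as above.
#ifdef FLASH_KERNEL_TRAITS_EXAMPLES
using Bwd_traits_example = Flash_bwd_kernel_traits<128, 64, 128, 12, true, false, false>;
static_assert(Bwd_traits_example::Is_WS);          // 12 warps -> warp-specialized path
static_assert(!Bwd_traits_example::Mma_dQ_is_RS);  // SdP_swapAB rules out the RS dQ MMA
#endif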
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
         int AtomLayoutMSdP = 1, int AtomLayoutNdKV = 2, int AtomLayoutMdQ = 1,
         int kClusterN_ = 1, typename elem_type = cutlass::half_t>
struct Flash_bwd_seqqpar_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using index_t = int64_t;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;

    static_assert(kNWarps_ == 8);

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kClusterN = kClusterN_;
    using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;

    static constexpr int kStages = 2;

    static constexpr bool SdP_swapAB = SdP_swapAB_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool dQ_swapAB = dQ_swapAB_;
    static_assert(!(SdP_swapAB && dKV_swapAB));  // If SdP_swapAB, then we don't swap for dKV

    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2
        && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS

    using TileShapeAtomSdP = std::conditional_t<
        !SdP_swapAB,
        Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
        Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
    >;
    using AtomLayoutSdP = std::conditional_t<
        !SdP_swapAB,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
    >;
    using TiledMmaSdP = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
        AtomLayoutSdP{}));

    using TileShapeAtomdKV = std::conditional_t<
        !dKV_swapAB,
        Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
    >;
    using AtomLayoutdKV = std::conditional_t<
        !dKV_swapAB,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
    >;
    using TiledMmadKV = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !SdP_swapAB,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
        >{},
        AtomLayoutdKV{}));

    using TileShapeAtomdQ = std::conditional_t<
        !dQ_swapAB,
        Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
    >;
    using AtomLayoutdQ = std::conditional_t<
        !dQ_swapAB,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
    >;
    static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
    static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
    using TiledMmadQ = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !dQ_swapAB,
            std::conditional_t<
                Mma_dQ_is_RS,
                decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
                decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, MmadQMajorA, MmadQMajorB>())
            >,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, MmadQMajorA, MmadQMajorB>())
        >{},
        AtomLayoutdQ{}));

    using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
    using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
    using GmemTiledCopydKV = cute::SM90_TMA_STORE;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    static constexpr bool Has_cp_async = true;
#else
    static constexpr bool Has_cp_async = false;
#endif
    // For the dot_do_o preprocessing kernel
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape<_32, _8>,   // Thread layout, 8 threads per row
               Stride<_8, _1>>,
        Layout<Shape<_16, _16>,  // Thread layout, 16 threads per row
               Stride<_16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutdO = SmemLayoutQ;

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));

    using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));

    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutQt =
        decltype(cute::composition(SmemLayoutQ{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdOt =
        decltype(cute::composition(SmemLayoutdO{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutKt =
        decltype(cute::composition(SmemLayoutK{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                                               make_stride(Int<kBlockN>{}, _1{}, Int<kBlockN * kHeadDim>{}))));
    using SmemLayoutPt =
        decltype(cute::composition(SmemLayoutP{},
                                   make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdSt =
        decltype(cute::composition(SmemLayoutdS{},
                                   make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));

    using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutdV = SmemLayoutdK;
    using SmemLayoutdKt = SmemLayoutKt;
    using SmemLayoutdVt = SmemLayoutKt;

    using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{})));

    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    using SmemLayoutAtomdQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);

    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdKV = decltype(tile_to_shape(
        SmemLayoutAtomdKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
    using SmemLayoutdKVt =
        decltype(cute::composition(SmemLayoutdKV{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                                               make_stride(Int<kBlockN>{}, _1{}))));
    static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2;

    // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SmemCopyAtomPdS = Copy_Atom<
        std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>, Element>;
    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>, Element>;
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>, Element>;

    using SharedStorage = SharedStorageQKVdOdKVSeqqPar<!SdP_swapAB, kStages, Element,
        SmemLayoutQ, SmemLayoutdO, SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQTMA>;

    // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    // using PipelineState = typename cutlass::PipelineState<kStages>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
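// Example (illustrative, same reconstructed parameter order as Flash_bwd_kernel_traits):
// the seqq-parallel backward variant is fixed at 8 warps, keeping K/V resident in smem while
// Q/dO stream through the TMA pipelines declared in SharedStorageQKVdOdKVSeqqPar.
// Hypothetical guard macro as above.
#ifdef FLASH_KERNEL_TRAITS_EXAMPLES
using Bwd_seqqpar_traits_example = Flash_bwd_seqqpar_kernel_traits<128, 64, 128, 8, true, false, false>;
static_assert(Bwd_seqqpar_traits_example::kNThreads == 8 * cutlass::NumThreadsPerWarp);
#endif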