/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp>  // For cute::elect_one_sync()

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t status_ = call;                                                                       \
        if (status_ != cudaSuccess) {                                                                     \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1);                                                                                      \
        }                                                                                                 \
    } while(0)

#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
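// Usage sketch (added for illustration, not part of the original header): MaxOp/SumOp are
// meant to be plugged into Allreduce, which butterfly-reduces one value per lane across a
// group of 2/4/8/16/32 lanes via __shfl_xor_sync. A minimal warp-wide max, assuming all
// 32 lanes of the warp are active (the helper name is hypothetical):
__forceinline__ __device__ float warp_allreduce_max_example(float x) {
    MaxOp<float> max_op;
    return Allreduce<32>::run(x, max_op);  // every lane ends up with the warp-wide maximum
}

////////////////////////////////////////////////////////////////////////////////////////////////////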
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = acc_layout;
        return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
                           make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
        return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
                           make_layout(get<0, 0>(l), get<2>(l)));
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM90, convert acc_layout from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_transposed_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = acc_layout;
    return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)),
                       make_layout(get<0, 1>(l), get<1>(l)));
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
    using X = Underscore;
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{});  // (2, 2, (2, N / 16))
        return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)),
                           get<1>(acc_layout),
                           make_layout(get<2, 1>(l), get<2>(acc_layout)));
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
        if constexpr (mma_shape_K == 8) {
            return acc_layout;
        } else {
            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs_fp8(Layout acc_layout) {
    using X = Underscore;
    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
    static_assert(decltype(rank(acc_layout))::value == 3);
    static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
    auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _4>{});  // (2, 2, (4, N / 32))
    return make_layout(make_layout(Shape<_4, _2, _2>{}),
                       get<1>(acc_layout),
                       make_layout(get<2, 1>(l), get<2>(acc_layout)));
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Byte permute for fp8 kernel
template <typename Fragment>
CUTLASS_DEVICE void permute_regs_A_to_C(Fragment &accum) {
    auto data = accum.data();
    #pragma unroll
    for (int n = 0; n < size(accum); n += 8) {
        uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
        auto upper = data_32bit[0];
        auto lower = data_32bit[1];
        data_32bit[0] = __byte_perm(upper, lower, 0x5410);
        data_32bit[1] = __byte_perm(upper, lower, 0x7632);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
    // Tensor out = make_tensor_like<To_type>(tensor);
    // cute::copy(make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout()), out);
    // return out;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
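// Usage sketch (added for illustration, not part of the original header): convert_type
// reinterprets a register fragment as a cutlass::Array, so the source must be a contiguous
// rmem tensor. A minimal fp32 -> bf16 conversion of a small owning fragment (the helper
// name is hypothetical):
__forceinline__ __device__ void convert_type_example() {
    Tensor acc = make_tensor<float>(Shape<_2, _4>{});        // contiguous rmem fragment
    cute::fill(acc, 1.0f);
    auto acc_bf16 = convert_type<cutlass::bfloat16_t>(acc);  // same layout, bf16 values
    (void) acc_bf16;
}

////////////////////////////////////////////////////////////////////////////////////////////////////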
template<bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true,
         typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    } else {
        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ElemO, typename TMACopyO, typename LayoutO, typename TileShapeO,
          typename SMemO, typename SeqLenTraits>
__forceinline__ __device__ void write_tma(
        ElemO* O, const TMACopyO& tma_store_O,
        const LayoutO& layout_O, const TileShapeO& tile_shape_O,
        const SMemO& sO, int m_block, int bidh, int bidb, int n_split_idx,
        const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
    Tensor mO = tma_store_O.get_tma_tensor(layout_O.shape());
    Tensor gO = seqlen_traits_o.get_o_local_tile_tensor(
        mO, tile_shape_O, bidh, bidb, n_split_idx
    )(_, _, m_block);  // (M, K)
    auto block_tma_O = tma_store_O.get_slice(_0{});
    Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
    Tensor tOsO = block_tma_O.partition_S(sO);  // (TMA, TMA_M, TMA_K)

    int const lane_predicate = cute::elect_one_sync();
    int const warp_idx = cutlass::canonical_warp_idx_sync();
    if (warp_idx == write_warp_idx && lane_predicate) {
        cute::copy(tma_store_O, tOsO, tOgO);
        tma_store_arrive();
    }
    // Note: no wait here.
    // tma_store_wait<0>();
}
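
// Usage note (added for illustration, not part of the original header): write_tma() only
// issues tma_store_arrive() from one elected thread, so the caller is expected to drain the
// store before the smem tile sO is reused, e.g. with something like
//   write_tma(O, tma_store_O, layout_O, tile_shape_O, sO,
//             m_block, bidh, bidb, n_split_idx, seqlen_traits_o, write_warp_idx);
//   tma_store_wait<0>();   // wait for the outstanding TMA store batch to complete
// (the argument names here are hypothetical placeholders for the caller's state).
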
// Epilogue that copies RMEM -> GMEM directly for GQA enabled.
// Reports as uncoalesced stores by the profiler
template <bool Use_gqa_layout, bool Column_permute_fp8, typename TensorO, typename OutputType,
          typename LayoutO, typename TileShapeO, typename TiledMma, typename SeqLenTraits>
__forceinline__ __device__ void write_rmem_to_gmem(
        TensorO &tOrO, OutputType *O, const LayoutO& layout_O, TileShapeO tile_shape_O,
        int m_block, int h_block, int bidh, int bidh_kv, int bidb, int n_split_idx,
        TiledMma& tiled_mma, const SeqLenTraits& seqlen_traits_o, int thread_idx) {
    static_assert(is_same_v<typename TensorO::value_type, float>, "rmem dtype must be float");
    Tensor mO = make_tensor(make_gmem_ptr(O), layout_O);
    Tensor gO = [&] {
        if constexpr(Use_gqa_layout) {
            return seqlen_traits_o.get_o_local_tile_tensor(
                mO, tile_shape_O, bidh_kv, bidb, n_split_idx
            )(_, _, _, m_block, h_block);  // (bM/bH, bH, K)
        } else {
            return seqlen_traits_o.get_o_local_tile_tensor(
                mO, tile_shape_O, bidh, bidb, n_split_idx
            )(_, _, m_block);  // (bM, bK)
        }
    }();
    auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
    auto tile_shape_mnk = cute::tile_shape(tiled_mma);
    Tensor cO = cute::make_identity_tensor(select<0, 1>(tile_shape_mnk));
    Tensor tOcO = thread_mma.partition_C(cO);
    // tOcO has shape ((2, 2, V), MMA_M, MMA_N); we take only the row indices.
    Tensor tOcO_row = tOcO(make_coord(_0{}, _, _0{}), _, _0{});
    // reshape from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
    Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));

    const int m_bound = seqlen_traits_o.actual_seq_len - m_block * size<0>(gO);
    // hardcoded col_idx to circumvent reg spilling with counting tensor
    // (see the worked example after this function for the resulting column mapping)
    const int col_start_idx = !Column_permute_fp8 ? 2 * (thread_idx % 4) : 4 * (thread_idx % 4);

    if constexpr (Use_gqa_layout) {
        static constexpr int kBlockH = size<1>(gO);
        const int h_bound = shape<1>(layout_O) - h_block * kBlockH;
        #pragma unroll
        for(int nrow = 0; nrow < size<0>(tOrO_rowcol); ++nrow) {
            const int row = int(get<0>(tOcO_row(nrow)));
            const int h_local = row % kBlockH;
            const int m_local = row / kBlockH;
            if(h_local < h_bound && m_local < m_bound) {
                if constexpr(!Column_permute_fp8) {
                    Tensor tOrO_nrow_float2 = recast<float2>(tOrO_rowcol(nrow, _));
                    #pragma unroll
                    for (int ncol = 0; ncol < size<1>(tOrO_rowcol)/2; ++ncol) {
                        *reinterpret_cast<float2 *>(&(gO(m_local, h_local, col_start_idx + 8 * ncol))) =
                            tOrO_nrow_float2(ncol);
                    }
                } else {
                    Tensor tOrO_nrow = tOrO_rowcol(nrow, _);
                    #pragma unroll
                    for (int ncol = 0; ncol < size<1>(tOrO_rowcol); ncol += 4) {
                        gO(m_local, h_local, col_start_idx + 4 * ncol) = tOrO_nrow(ncol);
                        gO(m_local, h_local, col_start_idx + 4 * ncol + 2) = tOrO_nrow(ncol + 1);
                        gO(m_local, h_local, col_start_idx + 4 * ncol + 1) = tOrO_nrow(ncol + 2);
                        gO(m_local, h_local, col_start_idx + 4 * ncol + 3) = tOrO_nrow(ncol + 3);
                    }
                }
            }
        }
    } else {
        #pragma unroll
        for(int nrow = 0; nrow < size<0>(tOrO_rowcol); ++nrow) {
            const int row = int(get<0>(tOcO_row(nrow)));
            if(row < m_bound) {
                if constexpr(!Column_permute_fp8) {
                    Tensor tOrO_nrow_float2 = recast<float2>(tOrO_rowcol(nrow, _));
                    #pragma unroll
                    for (int ncol = 0; ncol < size<1>(tOrO_rowcol)/2; ++ncol) {
                        *reinterpret_cast<float2 *>(&(gO(row, col_start_idx + 8 * ncol))) =
                            tOrO_nrow_float2(ncol);
                    }
                } else {
                    Tensor tOrO_nrow = tOrO_rowcol(nrow, _);
                    #pragma unroll
                    for (int ncol = 0; ncol < size<1>(tOrO_rowcol); ncol += 4) {
                        gO(row, col_start_idx + 4 * ncol) = tOrO_nrow(ncol);
                        gO(row, col_start_idx + 4 * ncol + 2) = tOrO_nrow(ncol + 1);
                        gO(row, col_start_idx + 4 * ncol + 1) = tOrO_nrow(ncol + 2);
                        gO(row, col_start_idx + 4 * ncol + 3) = tOrO_nrow(ncol + 3);
                    }
                }
            }
        }
    }
}
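
// Worked example (added for illustration, not part of the original header): with the SM90
// accumulator layout, the four threads of a row quad each own two adjacent columns of every
// 8-column group, so col_start_idx = 2 * (thread_idx % 4) and each float2 store at column
// col_start_idx + 8 * ncol writes both owned columns of group ncol (e.g. thread_idx % 4 == 1
// writes columns 2-3, 10-11, 18-19, ...). For the fp8 path the accumulator registers arrive
// column-interleaved, so col_start_idx = 4 * (thread_idx % 4) and the (+0, +2, +1, +3) store
// order above appears to undo that interleaving while covering 16 columns per iteration.
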
template <int NumCopyThreads, typename ElemO, typename TiledCopyO, typename LayoutO,
          typename TileShapeO, typename SMemO, typename SeqLenTraits>
__forceinline__ __device__ void write_tiled(
        ElemO* O, const TiledCopyO& tiled_copy_O,
        const LayoutO& layout_O, const TileShapeO& tile_shape_O,
        const SMemO& sO, int m_block, int bidh, int bidb,
        const SeqLenTraits& seqlen_traits_o) {
    Tensor mO = make_tensor(make_gmem_ptr(O), layout_O);
    Tensor gO = seqlen_traits_o.get_local_tile_tensor(
        mO, tile_shape_O, bidh, bidb
    )(_, _, m_block);  // (M, K)

    ThrCopy thr_copy_O = tiled_copy_O.get_slice(threadIdx.x - NumCopyThreads);
    Tensor tOgO = thr_copy_O.partition_D(gO);  // (CPY, CPY_M, CPY_K, k)
    Tensor tOsO = thr_copy_O.partition_S(sO);  // (CPY, CPY_M, CPY_K)

    // Prepare for TiledCopy.
    // Grouping is needed because cute::copy_if() does group_modes<1, R> for src and dst.
    // After grouping, the first dim is number of elements to read together.
    Tensor tOsOFlatten = cute::flatten(tOsO);
    Tensor tOsOGroup = cute::group_modes<1, rank(tOsOFlatten)>(tOsOFlatten);
    Tensor tOgOFlatten = cute::flatten(tOgO);
    Tensor tOgOGroup = cute::group_modes<1, rank(tOgOFlatten)>(tOgOFlatten);

    // Get thread coords to global index mapping.
    Tensor gOCounting = cute::make_identity_tensor(gO.shape());
    Tensor tSgOCounting = thr_copy_O.partition_D(gOCounting);
    Tensor tSgOCountingFlatten = cute::flatten(tSgOCounting);
    Tensor tSgOCountingGrouped = cute::group_modes<1, rank(tSgOCountingFlatten)>(tSgOCountingFlatten);

    // Write out to GMEM.
    const int kNumMsPerTile = get<0>(tile_shape_O);
    int cta_m = std::min(
        seqlen_traits_o.actual_seq_len - m_block * kNumMsPerTile, kNumMsPerTile
    );
    if (cta_m == kNumMsPerTile) {
        copy(tiled_copy_O, tOsOGroup, tOgOGroup);
    } else {
        auto predicate_fn = [&](auto coords) {
            auto s_coords = tSgOCountingGrouped(_0{}, coords);
            return elem_less(get<0>(s_coords), cta_m);
        };
        copy_if(tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool IsTMACopy, bool IsRegToGmem, bool Is_split, int NumCopyThreads, typename ElemO,
          typename TMACopyO, typename TiledCopyO, typename LayoutO, typename TileShapeO,
          typename SMemO, typename SeqLenTraits, typename TiledMma, typename TensorO>
__forceinline__ __device__ void write_O(
        ElemO* O, const TMACopyO& tma_copy_O, const TiledCopyO& tiled_copy_O,
        const LayoutO& layout_O, const TileShapeO& tile_shape_O, const SMemO& sO,
        int m_block, int bidh, int bidb, int n_split_idx,
        const SeqLenTraits& seqlen_traits_o, int write_warp_idx,
        TiledMma & tiledMma1, TensorO & tOrO) {
    if constexpr (IsRegToGmem) {
        static_assert(Is_split, "use write_rmem_to_gmem with split kv kernel only");
        write_rmem_to_gmem(tOrO, O, layout_O, tile_shape_O, m_block, bidh, bidb, n_split_idx,
                           tiledMma1, seqlen_traits_o, threadIdx.x - NumCopyThreads);
    } else if constexpr (IsTMACopy) {
        write_tma(O, tma_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, n_split_idx,
                  seqlen_traits_o, write_warp_idx);
    } else {
        static_assert(!Is_split, "Don't use write_tiled with split kv kernel");
        write_tiled<NumCopyThreads>(O, tiled_copy_O, layout_O, tile_shape_O, sO,
                                    m_block, bidh, bidb, seqlen_traits_o);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
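
// Usage sketch (added for illustration, not part of the original header): the host-side error
// checking macros defined at the top of this file, as they would typically wrap a kernel launch.
//   kernel<<<grid, block, smem_size, stream>>>(params);   // hypothetical launch
//   CHECK_CUDA_KERNEL_LAUNCH();                           // checks cudaGetLastError()
//   CHECK_CUDA(cudaStreamSynchronize(stream));            // surfaces asynchronous errors as well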