@@ -17,20 +17,15 @@ namespace flash {
 using namespace cute;
 
 // template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
-template <typename Ktraits>
+template <typename Ktraits, typename Seqlen_traits>
 struct CollectiveEpilogueFwd {
 
     using Element = typename Ktraits::Element;
     static constexpr int kBlockM = Ktraits::kBlockM;
     static constexpr int kBlockN = Ktraits::kBlockN;
     static constexpr int kHeadDim = Ktraits::kHeadDim;
-    // using Element = Element_;
-    // static constexpr int kBlockM = kBlockM_;
-    // static constexpr int kBlockN = kBlockN_;
-    // static constexpr int kHeadDim = kHeadDim_;
     using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
 
-    // static constexpr int kNWarps = kNWarps_;
     static constexpr int kNWarps = Ktraits::kNWarps;
     static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
     static constexpr bool Is_WS = kNWarps >= 12;
@@ -38,20 +33,6 @@ struct CollectiveEpilogueFwd {
     static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
     static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;
 
-    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
-
-    // These are for storing the output tensor without TMA (e.g., for setting output to zero)
-    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
-    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
-    static constexpr int kGmemThreadsPerRow = kHeadDim / kGmemElemsPerLoad;
-    static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow");
-    using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
-                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
-    using GmemTiledCopyO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
-                        GmemLayoutAtom{},
-                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per store
-
     using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
         decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
     using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
@@ -59,52 +40,72 @@ struct CollectiveEpilogueFwd {
     using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
     using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;
 
-    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)
-    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
-    using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen_q, head, batch)
-
+    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
     using TMA_O = decltype(make_tma_copy(
         GmemTiledCopyOTMA{},
-        make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), repeat_like(StrideO{}, int32_t(0)), StrideO{}),
+        make_tensor(
+            make_gmem_ptr(static_cast<Element*>(nullptr)),
+            typename Seqlen_traits::ShapeT{},
+            typename Seqlen_traits::StrideT{}
+        ),
         SmemLayoutO{},
         select<0, 2>(TileShape_MNK{}),
         _1{}));  // no mcast for O
 
+    // These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len)
+    static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
+    static_assert(kHeadDim % kNumVecElem == 0);
+    static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
+    static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
+    static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
+    using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
+    using TiledCopyOThrLayout = decltype(cute::make_layout(
+        cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
+        LayoutRight{}));
+    using TiledCopyOValLayout = decltype(cute::make_layout(
+        cute::make_shape(_1{}, Int<kNumVecElem>{}),
+        LayoutRight{}));
+    using TiledCopyO = decltype(make_tiled_copy(
+        TiledCopyOAtom{},
+        TiledCopyOThrLayout{},  // Thr layout
+        TiledCopyOValLayout{}   // Val layout
+    ));
+
     // Host side kernel arguments
     struct Arguments {
         Element* ptr_O;
-        ShapeO const shape_O;
-        StrideO const stride_O;
+        typename Seqlen_traits::LayoutT const layout_O;
         float* ptr_LSE;
-        StrideLSE const stride_LSE;
+        typename Seqlen_traits::LayoutLseT const layout_LSE;
     };
 
     // Device side kernel params
     struct Params {
         Element* ptr_O;
-        ShapeO const shape_O;
-        StrideO const stride_O;
+        typename Seqlen_traits::LayoutT const layout_O;
         float* ptr_LSE;
-        StrideLSE const stride_LSE;
+        typename Seqlen_traits::LayoutLseT const layout_LSE;
         TMA_O tma_store_O;
     };
 
     static Params
     to_underlying_arguments(Arguments const& args) {
-        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
+        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.layout_O);
         TMA_O tma_store_O = make_tma_copy(
             GmemTiledCopyOTMA{},
             mO,
             SmemLayoutO{},
             select<0, 2>(TileShape_MNK{}),
             _1{});  // no mcast for O
-        return {args.ptr_O, args.shape_O, args.stride_O, args.ptr_LSE, args.stride_LSE, tma_store_O};
+        return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
     CUTLASS_DEVICE
     static void prefetch_tma_descriptors(Params const& epilogue_params) {
-        cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
+        if constexpr (!Seqlen_traits::kUseVarSeqLen) {
+            cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
+        }
     }
 
     template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
@@ -115,7 +116,8 @@ struct CollectiveEpilogueFwd {
           SharedStorage& shared_storage,
           TiledMma tiled_mma,
           int thread_idx,
-          cute::tuple<int32_t, int32_t, int32_t> const& block_coord
+          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
+          const Seqlen_traits& seqlen_traits_q
           ) {
 
         auto [m_block, bidh, bidb] = block_coord;
@@ -134,16 +136,9 @@ struct CollectiveEpilogueFwd {
         cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
                                             cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
 
-        Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.shape_O);
-        Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
-        auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{});
-        Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
-        Tensor tOsO = block_tma_O.partition_S(sO);  // (TMA, TMA_M, TMA_K)
-
-        auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
-        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
-        Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));
-
+        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
+        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
+            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
         Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
         auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
         Tensor taccOcO = thread_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
@@ -156,19 +151,23 @@ struct CollectiveEpilogueFwd {
             #pragma unroll
             for (int mi = 0; mi < size(lse); ++mi) {
                 const int row = get<0>(taccOcO_row(mi));
-                if (row < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(row) = lse(mi); }
+                if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); }
             }
         }
 
-        if (cutlass::canonical_warp_idx_sync() == kNWarps - 1) {
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp,
-                                              cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
-            int const lane_predicate = cute::elect_one_sync();
-            if (lane_predicate) {
-                cute::copy(epilogue_params.tma_store_O, tOsO, tOgO);
-                tma_store_arrive();
-            }
+        int write_warp_idx = kNWarps - 1;
+        if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
+            cutlass::arch::NamedBarrier::sync(
+                NumMmaThreads + cutlass::NumThreadsPerWarp,
+                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
+            );
         }
+        TiledCopyO gmem_tiled_copy_O;
+        flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
+            epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O,
+            epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO,
+            m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
+        );
     }
 
     CUTLASS_DEVICE void
@@ -177,20 +176,25 @@ struct CollectiveEpilogueFwd {
     }
 
     // Write 0 to output and -inf to LSE
+    template<typename SharedStorage>
     CUTLASS_DEVICE void
     store_zero(
-        Params const& epilogue_params,
-        int thread_idx,
-        cute::tuple<int32_t, int32_t, int32_t> const& block_coord
-        ) {
+        Params const& epilogue_params,
+        SharedStorage& shared_storage,
+        int thread_idx,
+        cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
+        const Seqlen_traits& seqlen_traits_q
+        ) {
         auto [m_block, bidh, bidb] = block_coord;
-        Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.shape_O, epilogue_params.stride_O);
-        Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
-        auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
-        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
-        Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));
-
-        GmemTiledCopyO gmem_tiled_copy_O;
+        Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O);
+        Tensor gO = seqlen_traits_q.get_local_tile_tensor(
+            mO, select<0, 2>(TileShape_MNK{}), bidh, bidb
+        )(_, _, m_block);  // (M, K)
+        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
+        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
+            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
+
+        TiledCopyO gmem_tiled_copy_O;
         auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
         Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
         Tensor tOrO = make_fragment_like(tOgO);
@@ -201,13 +205,13 @@ struct CollectiveEpilogueFwd {
         Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
         Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
         #pragma unroll
-        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.shape_O); }
+        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
         flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.shape_O) - m_block * kBlockM
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
         );
         static_assert(kBlockM <= NumMmaThreads);
-        if (thread_idx < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
+        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
     }
 
 };
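
Note on the new template parameter: the epilogue above delegates all O/LSE shape, stride, and tile-slicing decisions to a caller-supplied Seqlen_traits type. The sketch below is not part of this patch; it only illustrates, for the fixed-length (non-varlen) case, the interface the diff assumes — ShapeT/StrideT/LayoutT/LayoutLseT, kUseVarSeqLen, actual_seq_len, and the two tile-view helpers. Names and signatures here are inferred from the usages above and may differ from the repository's real seqlen-traits header.

// Hypothetical sketch of the Seqlen_traits interface assumed by CollectiveEpilogueFwd
// (fixed-length case only; inferred from usage, not taken from the patch).
#include <cutlass/cutlass.h>
#include <cute/tensor.hpp>

namespace flash {

using namespace cute;

struct SeqlenTraitsFixedSketch {
    static constexpr bool kUseVarSeqLen = false;

    // Global O tensor: (seqlen_q, d, head, batch), contiguous in d.
    using ShapeT  = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
    using StrideT = cute::Stride<int64_t, _1, int64_t, int64_t>;
    using LayoutT = cute::Layout<ShapeT, StrideT>;

    // Global LSE tensor: (seqlen_q, head, batch).
    using ShapeLseT  = cute::Shape<int32_t, int32_t, int32_t>;
    using StrideLseT = cute::Stride<_1, int64_t, int64_t>;
    using LayoutLseT = cute::Layout<ShapeLseT, StrideLseT>;

    // Number of valid rows for this sequence; equals seqlen_q when fixed-length.
    int actual_seq_len;

    // (kBlockM, kHeadDim, num_m_blocks) view of O for one (head, batch);
    // the epilogue indexes it with (_, _, m_block).
    template <typename MTensor, typename Shape>
    CUTLASS_DEVICE auto get_local_tile_tensor(
            const MTensor& m_tensor, const Shape& tile_shape, int bidh, int bidb) const {
        return local_tile(m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{}));
    }

    // (kBlockM, num_m_blocks) view of LSE for one (head, batch);
    // the epilogue indexes it with (_, m_block).
    template <typename MTensor, typename Shape>
    CUTLASS_DEVICE auto get_lse_local_tile_tensor(
            const MTensor& m_tensor, const Shape& tile_shape, int bidh, int bidb) const {
        return local_tile(m_tensor(_, bidh, bidb), tile_shape, make_coord(_));
    }
};

}  // namespace flash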