@@ -64,8 +64,6 @@ using namespace detail;
 
 // Row vector broadcast
 template<
-  // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
-  // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
   int Stages,
   class CtaTileShapeMNK,
   class Element,
@@ -73,14 +71,12 @@ template<
   int Alignment = 128 / sizeof_bits_v<Element>
 >
 struct Sm90RowOrScalarBroadcast {
-  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
-  static_assert(
-    (cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
-    (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
+  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
+  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
+  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
 
-  // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
-  struct SharedStorage {
-    alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
+  struct SharedStorage {
+    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
   };
 
   // This struct has been modified to have a bool indicating that ptr_row is a
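The rewritten asserts loosen the old exact-type check: modes (M,N) of StrideMNL must equal the static row-broadcast stride (_0,_1), while the batch mode is deliberately left out of the check so it may be static or a runtime int. A minimal sketch of which stride types now pass, assuming CuTe headers (the RowBcast* aliases are illustrative only, not from the patch):

    #include <cute/tensor.hpp>
    using namespace cute;

    using RowBcast        = Stride<_0,_1,_0>;   // unbatched row vector
    using RowBcastBatched = Stride<_0,_1,int>;  // dynamic batch stride also passes

    static_assert(is_static_v<decltype(take<0,2>(RowBcast{}))>);
    static_assert(take<0,2>(RowBcastBatched{}) == Stride<_0,_1>{});
    // A column-vector stride such as Stride<_1,_0,_0> fails the equality assert.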
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
     return args;
   }
 
+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
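The new can_implement hook accepts unconditionally; it exists so the collective epilogue has a uniform validity check to call on every fusion operand. If stricter host-side validation were ever wanted, a sketch could look like the following (hypothetical, not part of this patch; it relies on the surrounding class's Element, Alignment, and Arguments::ptr_row, plus <cstdint>):

    template <class ProblemShape>
    static bool
    can_implement(ProblemShape const& problem_shape, Arguments const& args) {
      // Hypothetical check: the row pointer must meet the vectorized-copy alignment.
      auto const addr = reinterpret_cast<std::uintptr_t>(args.ptr_row);
      return addr % (Alignment * sizeof(Element)) == 0;
    }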
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
 
   CUTLASS_HOST_DEVICE
   Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
-    : params(params),
-      smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
+    : params(params)
+    , smem(const_cast<Element*>(shared_storage.smem.data())) { }
 
   Params params;
-  Element* smem_row;
+  Element *smem = nullptr;
 
   CUTLASS_DEVICE bool
   is_producer_load_needed() const {
-    return true;
+    return false;
   }
 
   CUTLASS_DEVICE bool
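With is_producer_load_needed() now returning false, the producer warp no longer stages the row vector through the TMA load pipeline at all. The next hunk follows through: ProducerLoadCallbacks is deleted, get_producer_load_callbacks degenerates to EmptyProducerLoadCallbacks, and all loading of the row moves into the consumer-side callbacks.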
@@ -139,78 +141,77 @@ struct Sm90RowOrScalarBroadcast {
     return (!params.row_broadcast && *(params.ptr_row) == Element(0));
   }
 
-  template <int EpiTiles, class GTensor, class STensor>
-  struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
-    CUTLASS_DEVICE
-    ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
-      : gRow(cute::forward<GTensor>(gRow)),
-        sRow(cute::forward<STensor>(sRow)),
-        params(params) {}
-
-    GTensor gRow;                 // (CTA_M,CTA_N)
-    STensor sRow;                 // (CTA_M,CTA_N,PIPE)
-    Params const& params;
-
-    CUTLASS_DEVICE void
-    begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
-      if (!params.row_broadcast) {
-        return;
-      }
-
-      if (issue_tma_load) {
-        // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
-        constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
-        cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
-        // Issue the TMA bulk copy
-        auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
-        // Filter so we don't issue redundant copies over stride-0 modes
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
-      }
-    }
-  };
-
   template <class... Args>
   CUTLASS_DEVICE auto
   get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
-
-    auto [M, N, K, L] = args.problem_shape_mnkl;
-    auto [m, n, k, l] = args.tile_coord_mnkl;
-    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
-    Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l));   // (CTA_M,CTA_N)
-    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                   // (CTA_M,CTA_N,PIPE)
-        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-        make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
-      cute::move(gRow), cute::move(sRow), params);
+    return EmptyProducerLoadCallbacks{};
   }
 
-  template <int EpiTiles, class RTensor, class STensor>
+  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
   struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
     CUTLASS_DEVICE
-    ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
-      : tCrRow(cute::forward<RTensor>(tCrRow)),
-        tCsRow(cute::forward<STensor>(tCsRow)),
-        params(params) {}
-
-    RTensor tCrRow;               // (CPY,CPY_M,CPY_N)
-    STensor tCsRow;               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
+    ConsumerStoreCallbacks(
+        GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
+        GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
+        SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
+        CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
+      : tGS_gRow(tGS_gRow_)
+      , tGS_sRow(tGS_sRow_)
+      , tGS_cRow(tGS_cRow_)
+      , tiled_G2S(tiled_g2s_)
+      , tSR_sRow(tSR_sRow_)
+      , tSR_rRow(tSR_rRow_)
+      , tCcRow(tCcRow_)
+      , residue_tCcRow(residue_tCcRow_)
+      , thr_num(thr_num_)
+      , params(params_) {}
+
+    GS_GTensor tGS_gRow;          // (CPY,CPY_M,CPY_N)
+    GS_STensor tGS_sRow;          // (CPY,CPY_M,CPY_N)
+    GS_CTensor tGS_cRow;          // (CPY,CPY_M,CPY_N)
+    Tiled_G2S tiled_G2S;
+
+    SR_STensor tSR_sRow;          // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    SR_RTensor tSR_rRow;          // (CPY,CPY_M,CPY_N)
+
+    CTensor tCcRow;               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    ThrResidue residue_tCcRow;    // (m, n)
+    ThrNum thr_num;
     Params const& params;
 
     CUTLASS_DEVICE void
-    previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
+    begin() {
       if (!params.row_broadcast) {
-        fill(tCrRow, *(params.ptr_row));
+        fill(tSR_rRow, *(params.ptr_row));
         return;
       }
 
+      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
+      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+
+      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
+        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
+          continue; // OOB of the smem row buffer
+        }
+        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
+          tGS_sRow_flt(i) = tGS_gRow_flt(i);
+        }
+        else {
+          tGS_sRow_flt(i) = Element(0); // Zero-fill OOB elements so the LDS can be issued without predication
+        }
+      }
+      synchronize();
+    }
+
+    CUTLASS_DEVICE void
+    begin_loop(int epi_m, int epi_n) {
       if (epi_m == 0) { // Assumes M-major subtile loop
-        // Filter so we don't issue redundant copies over stride-0 modes
-        // (only works if 0-strides are in same location, which is by construction)
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
+        if (!params.row_broadcast) return; // Do not issue LDS when the row is a scalar
+        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+        copy(tSR_sRow_flt, tSR_rRow_flt);
       }
     }
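Both the G2S fill loop and the S2R copy first pass their tensors through filter_zeros, which collapses stride-0 (broadcast) modes so each unique element is visited exactly once. A host-side toy illustration of the idea, assuming only CuTe (the buffer and shapes are hypothetical):

    #include <cute/tensor.hpp>
    using namespace cute;

    void filter_zeros_demo() {
      float buf[4] = {0.f, 1.f, 2.f, 3.f};
      // A (4,2) view whose second mode has stride 0: both columns alias the same data.
      Tensor t     = make_tensor(&buf[0], make_shape(Int<4>{}, Int<2>{}),
                                          make_stride(Int<1>{}, Int<0>{}));
      Tensor t_flt = filter_zeros(t);  // the stride-0 mode collapses to extent 1
      static_assert(decltype(size(t_flt))::value == 4, "one visit per unique element");
      // A copy over the filtered views therefore moves each element exactly once.
    }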
@@ -221,7 +222,7 @@ struct Sm90RowOrScalarBroadcast {
 
       CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < FragmentSize; ++i) {
-        frg_row[i] = tCrRow(epi_v * FragmentSize + i);
+        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
       }
 
       return frg_row;
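The visit hook itself only swaps the source tensor name: each call still reads one FragmentSize-wide slice of the register row, with epi_v selecting the slice. With FragmentSize = 4, for example, the visit for epi_v = 2 reads tSR_rRow(8) through tSR_rRow(11).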
@@ -234,17 +235,41 @@ struct Sm90RowOrScalarBroadcast {
   >
   CUTLASS_DEVICE auto
   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+    auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
+    using ThreadCount = decltype(size(args.tiled_copy));
 
-    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                   // (CTA_M,CTA_N,PIPE)
-        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-        make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-    Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>(                           // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
-      sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
-    Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow));                                 // (CPY,CPY_M,CPY_N)
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
-      cute::move(tCrRow), cute::move(tCsRow), params);
+    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
+    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
+    Tensor sRow = make_tensor(make_smem_ptr(smem),
+        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_stride(_0{}, _1{})); // (CTA_M, CTA_N)
+    //// G2S: Gmem to Smem
+    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                     Layout< Shape<_1, ThreadCount>,
+                                            Stride<_0, _1>>{},
+                                     Layout<_1>{});
+    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
+    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+    Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+
+    //// G2S: Coord
+    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
+    Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+
+    //// S2R: Smem to Reg
+    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));                             // (CPY,CPY_M,CPY_N)
+
+    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
+      tGS_gRow,
+      tGS_sRow,
+      tGS_cRow, tiled_g2s,
+      tSR_sRow,
+      tSR_rRow,
+      args.tCcD,
+      args.residue_cD,
+      ThreadCount{},
+      params);
   }
 };
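The G2S tiled copy pins its thread layout to Shape<_1, ThreadCount> with an M-stride of _0, so every epilogue thread walks the N extent of the row and no thread is assigned a redundant M copy. A standalone sketch of the same construction, assuming 128 threads and float data (both hypothetical here):

    #include <cute/tensor.hpp>
    #include <cute/atom/copy_atom.hpp>
    using namespace cute;

    using Element = float;
    // 1 x 128 thread arrangement, one element per thread per copy step.
    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                                     Layout<Shape<_1, _128>, Stride<_0, _1>>{},
                                     Layout<_1>{});
    auto thr0 = tiled_g2s.get_slice(0);
    // partition_S / partition_D then hand thread 0 every 128th column of the
    // (CTA_M, CTA_N) tile; the _0 M-stride maps the whole M extent to one copy.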
@@ -285,6 +310,12 @@ struct Sm90ColOrScalarBroadcast {
     return args;
   }
 
+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {