Преглед изворни кода

chore: update cutlass to 3.5.1

AlpinDale пре 6 месеци
родитељ
комит
9d98f29b3a

+ 3 - 3
CMakeLists.txt

@@ -193,8 +193,8 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
   FetchContent_Declare(
         cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-        # CUTLASS 3.5.0
-        GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
+        # CUTLASS 3.5.1
+        GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 
   )
   FetchContent_MakeAvailable(cutlass)
 
@@ -238,7 +238,7 @@ define_gpu_extension_target(
   SOURCES ${APHRODITE_EXT_SRC}
   COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
   ARCHITECTURES ${APHRODITE_GPU_ARCHES}
-  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
   USE_SABI 3
   WITH_SOABI)
 

+ 111 - 81
kernels/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp

@@ -64,8 +64,6 @@ using namespace detail;
 
 // Row vector broadcast
 template<
-  // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
-  // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
   int Stages,
   class CtaTileShapeMNK,
   class Element,
@@ -73,14 +71,12 @@ template<
   int Alignment = 128 / sizeof_bits_v<Element>
 >
 struct Sm90RowOrScalarBroadcast {
-  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
-  static_assert(
-    (cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
-    (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>));  // batched row vector broadcast
+  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
+  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
+  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
 
-  // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
-  struct SharedStorage {
-    alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
+  struct SharedStorage { 
+    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
   };
 
   // This struct has been modified to have a bool indicating that ptr_row is a 
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
     return args;
   }
 
+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
 
   CUTLASS_HOST_DEVICE
   Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
-      : params(params),
-        smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
+      : params(params)
+      , smem(const_cast<Element*>(shared_storage.smem.data())) { }
 
   Params params;
-  Element* smem_row;
+  Element *smem = nullptr;
 
   CUTLASS_DEVICE bool
   is_producer_load_needed() const {
-    return true;
+    return false;
   }
 
   CUTLASS_DEVICE bool
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
     return (!params.row_broadcast && *(params.ptr_row) == Element(0));
   }
 
-  template <int EpiTiles, class GTensor, class STensor>
-  struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
-    CUTLASS_DEVICE
-    ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
-      : gRow(cute::forward<GTensor>(gRow)),
-        sRow(cute::forward<STensor>(sRow)),
-        params(params) {}
-
-    GTensor gRow;                                                                                 // (CTA_M,CTA_N)
-    STensor sRow;                                                                                 // (CTA_M,CTA_N,PIPE)
-    Params const& params;
-
-    CUTLASS_DEVICE void
-    begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
-      if (!params.row_broadcast) {
-        return;
-      }
-
-      if (issue_tma_load) {
-        // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
-        constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
-        cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
-        // Issue the TMA bulk copy
-        auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
-        // Filter so we don't issue redundant copies over stride-0 modes
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
-      }
-    }
-  };
-
   template <class... Args>
   CUTLASS_DEVICE auto
   get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
-
-    auto [M, N, K, L] = args.problem_shape_mnkl;
-    auto [m, n, k, l] = args.tile_coord_mnkl;
-    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
-    Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l));            // (CTA_M,CTA_N)
-    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                            // (CTA_M,CTA_N,PIPE)
-                    make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-                    make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
-      cute::move(gRow), cute::move(sRow), params);
+    return EmptyProducerLoadCallbacks{};
   }
 
-  template <int EpiTiles, class RTensor, class STensor>
+  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
   struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
     CUTLASS_DEVICE
-    ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
-      : tCrRow(cute::forward<RTensor>(tCrRow)),
-        tCsRow(cute::forward<STensor>(tCsRow)),
-        params(params) {}
-
-    RTensor tCrRow;                                                               // (CPY,CPY_M,CPY_N)
-    STensor tCsRow;                                                               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
+    ConsumerStoreCallbacks(
+        GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, 
+        GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, 
+        SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
+        CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
+      : tGS_gRow(tGS_gRow_)
+      , tGS_sRow(tGS_sRow_)
+      , tGS_cRow(tGS_cRow_)
+      , tiled_G2S(tiled_g2s_)
+      , tSR_sRow(tSR_sRow_)
+      , tSR_rRow(tSR_rRow_)
+      , tCcRow(tCcRow_)
+      , residue_tCcRow(residue_tCcRow_)
+      , params(params_) {}
+
+    GS_GTensor tGS_gRow;                                                         // (CPY,CPY_M,CPY_N)
+    GS_STensor tGS_sRow;                                                         // (CPY,CPY_M,CPY_N)
+    GS_CTensor tGS_cRow;                                                         // (CPY,CPY_M,CPY_N)
+    Tiled_G2S tiled_G2S;
+
+    SR_STensor tSR_sRow;                                                         // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    SR_RTensor tSR_rRow;                                                         // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) 
+  
+    CTensor tCcRow;                                                              // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    ThrResidue residue_tCcRow;                                                   // (m, n)
+    ThrNum thr_num;
     Params const& params;
 
     CUTLASS_DEVICE void
-    previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
+    begin() {
       if (!params.row_broadcast) {
-        fill(tCrRow, *(params.ptr_row));
+        fill(tSR_rRow, *(params.ptr_row));
         return;
       }
 
+      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
+      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+
+      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
+        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
+          continue; // OOB of SMEM, 
+        }
+        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
+          tGS_sRow_flt(i) = tGS_gRow_flt(i);
+        }
+        else {
+          tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
+        }
+      }
+      synchronize();
+    }
+
+    CUTLASS_DEVICE void
+    begin_loop(int epi_m, int epi_n) {
       if (epi_m == 0) { // Assumes M-major subtile loop
-        // Filter so we don't issue redundant copies over stride-0 modes
-        // (only works if 0-strides are in same location, which is by construction)
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
+        if (!params.row_broadcast) return; // Do not issue LDS when row is scalar 
+        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+        copy(tSR_sRow_flt, tSR_rRow_flt);
       }
     }
 
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
 
       CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < FragmentSize; ++i) {
-        frg_row[i] = tCrRow(epi_v * FragmentSize + i);
+        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
       }
 
       return frg_row;
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
   >
   CUTLASS_DEVICE auto
   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+    auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
+    using ThreadCount = decltype(size(args.tiled_copy));
 
-    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                            // (CTA_M,CTA_N,PIPE)
-                    make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-                    make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-    Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>(                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
-                      sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
-    Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow));                                           // (CPY,CPY_M,CPY_N)
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
-      cute::move(tCrRow), cute::move(tCsRow), params);
+    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
+    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));          // (CTA_M, CTA_N)
+    Tensor sRow = make_tensor(make_smem_ptr(smem), 
+        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{}));  // (CTA_M, CTA_N)
+    //// G2S: Gmem to Smem
+    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                     Layout< Shape<_1, ThreadCount>, 
+                                            Stride<_0,          _1>>{}, 
+                                     Layout<_1>{});   
+    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
+    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+    Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+
+    //// G2S: Coord 
+    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
+    Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+
+    //// S2R: Smem to Reg
+    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));                                           // (CPY,CPY_M,CPY_N)
+
+    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
+      tGS_gRow, 
+      tGS_sRow, 
+      tGS_cRow, tiled_g2s, 
+      tSR_sRow, 
+      tSR_rRow, 
+      args.tCcD, 
+      args.residue_cD,
+      ThreadCount{}, 
+      params);
   }
 };
 
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
     return args;
   }
 
+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {

+ 7 - 4
kernels/quantization/cutlass_w8a8/scaled_mm_c2x.cuh

@@ -10,8 +10,6 @@
 #include "cute/atom/mma_atom.hpp"
 #include "cutlass/numeric_types.h"
 
-#include "cutlass/util/device_memory.h"
-
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm_coord.h"
 #include "cutlass/arch/mma_sm75.h"
@@ -93,12 +91,15 @@ struct ScaledEpilogueBase {
 /*
  This epilogue function defines a quantized GEMM operation similar to
  torch._scaled_mm.
+
  A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
  per-row. B can be quantized per-tensor or per-column.
  Any combination of per-tensor and per-row or column is supported.
  A and B must have symmetric quantization (zero point == 0).
+
  So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
  scales are applied elementwise with numpy-style broadcasting.
+
  ScaleA and ScaleB define the epilogue functions that apply the scales for
  the A and B operands respectively. These scales may be either per-tensor or
  per row or column.
@@ -298,12 +299,14 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   // Launch the CUTLASS GEMM kernel.
   typename Gemm::Op gemm_op;
   size_t workspace_size = gemm_op.get_workspace_size(args);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
 
   auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
 
   CUTLASS_CHECK(gemm_op.can_implement(args));
-  cutlass::Status status = gemm_op(args, workspace.get(), stream);
+  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
   CUTLASS_CHECK(status);
 }
 

+ 11 - 19
kernels/quantization/cutlass_w8a8/scaled_mm_c3x.cu

@@ -18,8 +18,6 @@
 #include "cute/atom/mma_atom.hpp"
 #include "cutlass/numeric_types.h"
 
-#include "cutlass/util/device_memory.h"
-
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
 #include "cutlass/epilogue/collective/collective_builder.hpp"
@@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
       0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
       Stride<Int<1>, Int<0>, Int<0>>>;
 
-  using ScaleBDescriptor =
-      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
-          EpilogueDescriptor, float>;
-
   using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
-      ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
-      typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
+      Stride<Int<0>, Int<1>, Int<0>>>;
 };
 
 /*
@@ -154,12 +148,8 @@ struct ScaledEpilogueBias
       cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
-  using BiasDescriptor =
-      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
-          EpilogueDescriptor, ElementD>;
-
   using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
       Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
 
  public:
@@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   int64_t ldb = b.stride(1);
   int64_t ldc = out.stride(0);
 
-  using StrideA = Stride<int64_t, Int<1>, Int<0>>;
-  using StrideB = Stride<int64_t, Int<1>, Int<0>>;
+  using StrideA = Stride<int64_t, Int<1>, int64_t>;
+  using StrideB = Stride<int64_t, Int<1>, int64_t>;
   using StrideC = typename Gemm::StrideC;
 
-  StrideA a_stride{lda, Int<1>{}, Int<0>{}};
-  StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
+  StrideA a_stride{lda, Int<1>{}, 0};
+  StrideB b_stride{ldb, Int<1>{}, 0};
   StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
 
   using GemmKernel = typename Gemm::GemmKernel;
@@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   CUTLASS_CHECK(gemm_op.can_implement(args));
 
   size_t workspace_size = gemm_op.get_workspace_size(args);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
 
   auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
 
-  cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
+  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
   CUTLASS_CHECK(status);
 }