// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
#include <vector>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/util/device_memory.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "broadcast_load_epilogue_c3x.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper) or later.

   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace {

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

/*
 * This class provides the common ScaleA and ScaleB descriptors for the
 * ScaledEpilogue and ScaledEpilogueBias classes.
 */
template <typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
      Stride<Int<1>, Int<0>, Int<0>>>;

  using ScaleBDescriptor =
      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
          EpilogueDescriptor, float>;

  using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
      typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor
   or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
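// Illustration only: a naive host-side reference for the operation described
// above, shown for the int8 case with per-row scales for A and per-column
// scales for B. The function name and the assumption of contiguous row-major
// A / column-major B (leading dimension K) are ours; nothing in this file
// calls it.
[[maybe_unused]] static void scaled_mm_reference(
    float* D, int8_t const* A, int8_t const* B, float const* a_scales,
    float const* b_scales, int M, int N, int K) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k) {
        // A is [M, K] row-major, B is [K, N] column-major.
        acc += int32_t(A[i * K + k]) * int32_t(B[j * K + k]);
      }
      // Scales broadcast per row of A and per column of B; a per-tensor scale
      // is the same computation with a single repeated value.
      D[i * N + j] = a_scales[i] * b_scales[j] * float(acc);
    }
  }
}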
template <typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::ScaleA;
  using ScaleB = typename SUPER::ScaleB;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    using ScaleA_Args = typename ScaleA::Arguments;
    using ScaleB_Args = typename ScaleB::Arguments;

    ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
    ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};

    return ArgumentType{a_args, {b_args}};
  }
};

template <typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::ScaleA;
  using ScaleB = typename SUPER::ScaleB;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using BiasDescriptor =
      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
          EpilogueDescriptor, ElementD>;

  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;

 public:
  using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA,
                                                        EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    using ScaleA_Args = typename ScaleA::Arguments;
    using ScaleB_Args = typename ScaleB::Arguments;
    using Bias_Args = typename Bias::Arguments;

    ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
    ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
    Bias_Args bias_args{static_cast<ElementD*>(bias.data_ptr())};

    return ArgumentType{a_args, {b_args}, bias_args};
  }
};

template <typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Epilogue = Epilogue_<ElementD, EpilogueDescriptor>;

  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  using ElementC = void;
  using StrideC = StrideD;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
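  // StageCountAutoCarveout subtracts the epilogue's shared-memory footprint
  // (CEStorageSize) before computing how many mainloop stages fit in shared
  // memory.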
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
          ElementAB, cutlass::layout::RowMajor, 16,
          ElementAB, cutlass::layout::ColumnMajor, 16,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, Int<0>>;
  using StrideB = Stride<int64_t, Int<1>, Int<0>>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, Int<0>{}};
  StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};
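  // ElementC is void in cutlass_3x_gemm, so no separate C source is read;
  // both pointer/stride pairs passed above refer to the output tensor D.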
  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
  CUTLASS_CHECK(status);
}

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm90_int8_config_default {
  // For M > 128 and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
  // For M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
  // For M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
  // For M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
  // For M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

}  // namespace
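// Example (illustrative, kept as a comment so it adds nothing to the build):
// for an fp8 input with large M and the plain scaled epilogue, the dispatch
// function below selects a GEMM type equivalent to
//
//   using ExampleFp8Gemm =
//       typename sm90_fp8_config_default<cutlass::float_e4m3_t,
//                                        cutlass::bfloat16_t,  // assumed OutType
//                                        ScaledEpilogue>::Cutlass3xGemm;
//
// and launches it via cutlass_gemm_caller<ExampleFp8Gemm>(out, a, b, ...).
// The alias name and the bf16 output type are assumptions for this sketch.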
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& b,
                                    EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
                                     torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass3xGemmDefault =
      typename sm90_int8_config_default<InType, OutType,
                                        Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.size(1);
  bool const is_small_n = n < 8192;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

template