#pragma once

#include <stddef.h>

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;

/*
   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace aphrodite {

// Wrappers for the GEMM kernel that are used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};
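// Usage sketch (illustrative only): wrapping a kernel type in one of the
// guards above makes its device-side invoke() a no-op outside the targeted
// architecture range, e.g.
//
//   using GuardedKernel = enable_sm80_to_sm89<SomeGemmKernel>;
//
// where SomeGemmKernel is a placeholder name. cutlass_2x_gemm below applies
// these guards through its ArchGuard template template parameter.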
/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  template <typename T>
  using ColOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using RowOrZeroLoad =
      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      // it would technically work but no use case as data_ptr is never nullptr
      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    return Arguments{data_ptr};
  }
};

/*
 This epilogue function defines a quantized GEMM operation similar to
 torch._scaled_mm.

 A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
 per-row. B can be quantized per-tensor or per-column.
 Any combination of per-tensor and per-row or column is supported.
 A and B must have symmetric quantization (zero point == 0).

 So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
 scales are applied elementwise with numpy-style broadcasting.

 ScaleA and ScaleB define the epilogue functions that apply the scales for
 the A and B operands respectively. These scales may be either per-tensor or
 per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};
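// Usage sketch (illustrative only, host side): for a per-row-quantized A of
// shape (m, k) and a per-column-quantized B of shape (k, n), a_scales and
// b_scales would be float tensors of shape (m, 1) and (1, n) respectively
// (or single-element tensors for per-tensor quantization), e.g.
//
//   auto evt_args =
//       ScaledEpilogue<ElementD, OutputTileThreadMap>::prepare_args(a_scales,
//                                                                   b_scales);
//
// cutlass_gemm_caller below forwards its trailing arguments into
// prepare_args in exactly this way.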
/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
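// Host-side sketch (illustrative only): with J the (1,k) matrix of ones, the
// per-tensor correction that the caller folds into azp_adj above is
//
//   azp_adj[j] = azp * sum_k B[k][j]   // shape (1,n), int32
//
// so that (accum[i][j] - azp_adj[j]) equals the product of the zero-point
// corrected A with B.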
/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
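// Host-side sketch (illustrative only): the rank-1 correction applied above is
//
//   azp_adj[j] = sum_k B[k][j]            // shape (1,n), int32
//   correction[i][j] = azp[i] * azp_adj[j]  // formed inside the epilogue
//
// so the full (m,n) correction matrix is never materialized; only the (m,1)
// azp tensor and the (1,n) azp_adj tensor are stored.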
template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
          typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;

  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using Operator =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
                                cutlass::arch::OpMultiplyAddSaturate,
                                FP8MathOperator>::type;

  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
          >;

  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
  using EVTCompute = typename Epilogue::EVTCompute;

  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  using KernelType =
    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
      float, cutlass::layout::RowMajor, 4,
      ElementAcc, float, cutlass::arch::OpClassTensorOp,
      Arch,
      TileShape, WarpShape, InstructionShape,
      EVTD, // Epilogue
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
      MainLoopStages, Operator,
      1 /* epilogue stages */
      >::GemmKernel>;
  // clang-format on

  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};

template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  cutlass::gemm::GemmCoord problem_size{m, n, k};

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

  using Epilogue = typename Gemm::Epilogue;
  auto evt_args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);

  typename Gemm::EVTD::Arguments epilogue_args{
      evt_args,
      d_args,
  };

  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,
      nullptr,
      0,
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
                                         torch::Tensor const& a,
                                         torch::Tensor const& b,
                                         EpilogueArgs&&... args) {
  // In some cases, the GPU isn't able to accommodate the
  // shared memory requirements of the Gemm. In such cases, use
  // the FallbackGemm instead.
  static const int max_shared_mem_per_block_opt_in =
      get_cuda_max_shared_memory_per_block_opt_in(0);

  size_t const gemm_shared_mem_size =
      sizeof(typename Gemm::KernelType::SharedStorage);
  size_t const fallback_gemm_shared_mem_size =
      sizeof(typename FallbackGemm::KernelType::SharedStorage);

  if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
    return cutlass_gemm_caller<Gemm>(out, a, b,
                                     std::forward<EpilogueArgs>(args)...);
  } else {
    TORCH_CHECK(fallback_gemm_shared_mem_size <=
                max_shared_mem_per_block_opt_in);
    return cutlass_gemm_caller<FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

}  // namespace aphrodite
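// Usage sketch (illustrative only; the tile, warp, and instruction shapes and
// the stage count below are assumptions, not configurations defined in this
// file): an int8 kernel targeting SM80-SM89 with the plain scaled epilogue
// could be declared and launched roughly as
//
//   using ExampleGemm = aphrodite::cutlass_2x_gemm<
//       cutlass::arch::Sm80, aphrodite::enable_sm80_to_sm89, int8_t,
//       cutlass::bfloat16_t, aphrodite::ScaledEpilogue,
//       cutlass::gemm::GemmShape<128, 128, 64>,
//       cutlass::gemm::GemmShape<64, 64, 64>,
//       cutlass::gemm::GemmShape<16, 8, 32>, 5>;
//
//   aphrodite::cutlass_gemm_caller<ExampleGemm>(out, a, b, a_scales, b_scales);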