#pragma once

#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;

/*
   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/
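
// A minimal sketch of that contract (illustrative only; the concrete
// ScaledEpilogue* structs below are the real implementations):
//
//   template <typename ElementD, typename OutputTileThreadMap>
//   struct MyEpilogue {
//    public:
//     using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT</* ... */>;
//     using ArgumentType = typename EVTCompute::Arguments;
//     static ArgumentType prepare_args(/* epilogue-specific tensors */);
//   };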

namespace aphrodite {

// Wrappers for the GEMM kernel that are used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the
// ifdef into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};
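
// Usage sketch (illustrative): wrapping a kernel type makes invoke() compile
// to an empty body on unsupported architectures. The real use is via the
// ArchGuard template parameter of cutlass_2x_gemm below; SomeDefaultGemm is a
// hypothetical placeholder.
//
//   using GuardedKernel =
//       enable_sm80_to_sm89<typename SomeDefaultGemm::GemmKernel>;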

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  template <typename T>
  using ColOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using RowOrZeroLoad =
      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It handles the row-broadcast and column-broadcast cases,
  // as well as the row/column-or-scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      // This would technically work, but there is no use case, as data_ptr
      // is never nullptr here.
      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in
  // which case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    return Arguments{data_ptr};
  }
};
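
// Example (illustrative): a per-column scale tensor of shape (1, n) maps to
// row-broadcast arguments; the second field of the aggregate records whether
// the pointer refers to a full row (numel != 1) or a single broadcast scalar.
//
//   auto b_args = args_from_tensor<RowOrScalarLoad<float>, float>(b_scales);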

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor
   or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or per-column scales is
   supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor
   or per-row (for A) / per-column (for B).
*/
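
// Broadcast shapes for an m x n output D (following the EVT nesting below):
//   a_scales: (m, 1) per-row    or (1, 1) per-tensor
//   b_scales: (1, n) per-column or (1, 1) per-tensor
//   D[i, j] = a_scales[i] * (b_scales[j] * Accum[i, j])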

template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a
 * bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an
 * azp_adj term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
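
// Why subtracting azp_adj works (a sketch): with zero point azp, the
// quantized activation is A_q = A + azp * J, where J is the all-ones matrix.
// Then
//   A @ B = (A_q - azp * J) @ B = A_q @ B - azp * (J @ B),
// so subtracting the precomputed row vector azp_adj = azp * (J @ B) from the
// int32 accumulator recovers the corrected product before scaling.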

template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled
 * zero point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
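
// Sketch of the rank-1 update: with per-row zero points azp[i] and
// azp_adj = J @ B, the correction applied to the accumulator is
//   Acc[i, j] - azp[i] * azp_adj[j]
// computed on the fly from the (m,1) azp and (1,n) azp_adj vectors, so only
// O(m + n) values are ever stored.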

template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
          typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;

  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using Operator =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
                                cutlass::arch::OpMultiplyAddSaturate,
                                FP8MathOperator>::type;

  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
          >;

  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
  using EVTCompute = typename Epilogue::EVTCompute;

  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  using KernelType =
      ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
          ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
          ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
          float, cutlass::layout::RowMajor, 4,
          ElementAcc, float, cutlass::arch::OpClassTensorOp,
          Arch,
          TileShape, WarpShape, InstructionShape,
          EVTD,
          cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
          MainLoopStages, Operator,
          1 /* epilogue stages */
          >::GemmKernel>;
  // clang-format on

  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
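
// Example instantiation (illustrative; the tile, warp, and instruction shapes
// here are placeholders, real configs are chosen per-architecture elsewhere):
//
//   using MyGemm = cutlass_2x_gemm<
//       cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
//       ScaledEpilogue,
//       cutlass::gemm::GemmShape<128, 128, 64>,  // threadblock tile
//       cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
//       cutlass::gemm::GemmShape<16, 8, 32>,     // tensor-core instruction
//       5>;                                      // mainloop stages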

template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  cutlass::gemm::GemmCoord problem_size{m, n, k};

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

  using Epilogue = typename Gemm::Epilogue;
  auto evt_args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);

  typename Gemm::EVTD::Arguments epilogue_args{
      evt_args,
      d_args,
  };

  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,  // ptr_C (unused; the EVT epilogue stores D)
      nullptr,  // ptr_D (unused for the same reason)
      0,        // batch strides A/B/C/D (single batch)
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}
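
// Typical call (illustrative): the trailing arguments are forwarded to the
// epilogue's prepare_args, e.g. for ScaledEpilogue the two scale tensors:
//
//   cutlass_gemm_caller<MyGemm>(out, a, b, a_scales, b_scales);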

template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
                                         torch::Tensor const& a,
                                         torch::Tensor const& b,
                                         EpilogueArgs&&... args) {
  // In some cases, the GPU isn't able to accommodate the
  // shared memory requirements of the Gemm. In such cases, use
  // the FallbackGemm instead.
  static const int max_shared_mem_per_block_opt_in =
      get_cuda_max_shared_memory_per_block_opt_in(0);

  size_t const gemm_shared_mem_size =
      sizeof(typename Gemm::KernelType::SharedStorage);
  size_t const fallback_gemm_shared_mem_size =
      sizeof(typename FallbackGemm::KernelType::SharedStorage);

  if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
    return cutlass_gemm_caller<Gemm>(out, a, b,
                                     std::forward<EpilogueArgs>(args)...);
  } else {
    TORCH_CHECK(fallback_gemm_shared_mem_size <=
                max_shared_mem_per_block_opt_in);
    return cutlass_gemm_caller<FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
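
// FallbackGemm is expected to be another cutlass_2x_gemm instantiation with a
// smaller shared-memory footprint (e.g. a smaller tile shape or fewer
// mainloop stages), so its SharedStorage fits on every supported GPU.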

}  // namespace aphrodite