- // clang-format will break the include order
- // clang-format off
- #include <cudaTypedefs.h>
- #if defined CUDA_VERSION && CUDA_VERSION >= 12000
- #include <torch/all.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <iostream>
- #include <sstream>
- #include <vector>
- #include "cutlass/cutlass.h"
- #include "cute/tensor.hpp"
- #include "cute/atom/mma_atom.hpp"
- #include "cutlass/numeric_types.h"
- #include "cutlass/gemm/device/gemm_universal_adapter.h"
- #include "cutlass/gemm/kernel/gemm_universal.hpp"
- #include "cutlass/epilogue/collective/collective_builder.hpp"
- #include "cutlass/gemm/collective/collective_builder.hpp"
- #include "broadcast_load_epilogue_c3x.hpp"
- #include "common.hpp"
- // clang-format on
- using namespace cute;
- /*
- This file defines quantized GEMM operations using the CUTLASS 3.x API, for
- NVIDIA GPUs with sm90a (Hopper) or later.
- Epilogue functions can be defined to post-process the output before it is
- written to GPU memory.
- Epilogues must contain a public type named EVTCompute of type Sm90EVT,
- as well as a static prepare_args function that constructs an
- EVTCompute::Arguments struct.
- */
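- // A minimal sketch of that contract (illustrative only; "MyEpilogue" is a
- // hypothetical name and the EVT tree depends on the operation):
- //   template <typename ElementAcc, typename ElementD, typename Descriptor>
- //   struct MyEpilogue {
- //     using EVTCompute = cutlass::epilogue::fusion::Sm90EVT</*...*/>;
- //     using ArgumentType = typename EVTCompute::Arguments;
- //     static ArgumentType prepare_args(/*...*/);
- //   };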
- namespace {
- // A wrapper for the GEMM kernel that is used to guard against compilation on
- // architectures that will never use the kernel. The purpose of this is to
- // reduce the size of the compiled binary.
- // __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
- // into code that will be executed on the device where it is defined.
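- // Concretely: host-side instantiations see no __CUDA_ARCH__ and compile an
- // empty body, while device compilation for sm90+ forwards to the kernel.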
- template <typename Kernel>
- struct enable_sm90_or_later : Kernel {
- template <typename... Args>
- CUTLASS_DEVICE void operator()(Args&&... args) {
- #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
- Kernel::operator()(std::forward<Args>(args)...);
- #endif
- }
- };
- /*
- * This class provides the common load descriptors for the
- * ScaledEpilogue[...] classes
- */
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
- struct ScaledEpilogueBase {
- protected:
- using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
- template <typename T>
- using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
- Stride<Int<1>, Int<0>, Int<0>>>;
- template <typename T>
- using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
- Stride<Int<0>, Int<1>, Int<0>>>;
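- // Stride convention for the broadcasts above (modes are M, N, batch):
- // Stride<Int<1>, Int<0>, Int<0>> is a column vector (one value per row,
- // broadcast across columns), and Stride<Int<0>, Int<1>, Int<0>> is a row
- // vector (one value per column, broadcast across rows).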
- // Don't want to support nullptr by default
- template <typename T, bool EnableNullPtr = false>
- using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
- Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
- // Don't want to support nullptr by default
- template <typename T, bool EnableNullPtr = false>
- using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
- Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
- // This utility function constructs the arguments for the load descriptors
- // from a tensor. It handles both the row and column broadcast cases, as
- // well as the row/column-or-scalar cases.
- template <typename Descriptor, typename T>
- static auto args_from_tensor(torch::Tensor const& tensor) {
- using Arguments = typename Descriptor::Arguments;
- auto* data_ptr = static_cast<T*>(tensor.data_ptr());
- if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
- std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
- return Arguments{data_ptr, tensor.numel() != 1};
- } else {
- static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
- !std::is_same_v<Descriptor, RowLoad<T, true>>);
- return Arguments{data_ptr};
- }
- }
- // This overload handles the case where there might not be a tensor, in which
- // case a nullptr is passed and a constant (0) is used.
- template <typename Descriptor, typename T>
- static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
- using Arguments = typename Descriptor::Arguments;
- auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
- static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
- std::is_same_v<Descriptor, RowLoad<T, true>>);
- return Arguments{data_ptr};
- }
- };
- /*
- This epilogue defines a quantized GEMM operation similar to
- torch._scaled_mm.
- A and B may both be either int8 or fp8_e4m3. A can be quantized
- per-tensor or per-row. B can be quantized per-tensor or per-column.
- Any combination of per-tensor and per-row/per-column scales is supported.
- A and B must have symmetric quantization (zero point == 0).
- The GEMM operation is then D = (a_scales * A) @ (b_scales * B), where the
- scales are applied elementwise with numpy-style broadcasting.
- ScaleA and ScaleB define the epilogue functions that apply the scales to
- the A and B operands respectively; each may be per-tensor, per-row (A), or
- per-column (B).
- */
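- // Concretely, with per-row a_scales of shape (m, 1) and per-column b_scales
- // of shape (1, n), each output element is
- //   D[i, j] = a_scales[i] * (b_scales[j] * Accum[i, j]),
- // and either scale collapses to a scalar when its tensor has one element.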
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
- struct ScaledEpilogue
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
- private:
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
- using Accum = typename SUPER::Accum;
- using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
- using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
- using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiplies, float, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- using EVTCompute0 =
- cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
- using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiplies, ElementD, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- public:
- using EVTCompute =
- cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
- using ArgumentType = typename EVTCompute::Arguments;
- static ArgumentType prepare_args(torch::Tensor const& a_scales,
- torch::Tensor const& b_scales) {
- auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
- auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
- typename EVTCompute0::Arguments evt0_args{b_args};
- return ArgumentType{a_args, evt0_args};
- }
- };
- /*
- * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
- * This bias can also be used in the per-tensor azp case, where the activation
- * zero point (azp) is used to compute an azp correction term,
- * which is folded into the bias.
- *
- * The bias tensor must be per-output channel.
- * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
- */
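- // With the multiply_add node below, each output element is
- //   D[i, j] = a_scales[i] * (b_scales[j] * Accum[i, j]) + bias[j].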
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
- struct ScaledEpilogueBias
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
- private:
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
- using Accum = typename SUPER::Accum;
- using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
- using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
- using Bias = typename SUPER::template RowLoad<ElementD>;
- using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiplies, float, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- using EVTCompute0 =
- cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
- using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiply_add, ElementD, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- public:
- using EVTCompute =
- cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
- using ArgumentType = typename EVTCompute::Arguments;
- static ArgumentType prepare_args(torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- torch::Tensor const& bias) {
- auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
- auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
- auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
- typename EVTCompute0::Arguments evt0_args{b_args};
- return ArgumentType{a_args, evt0_args, bias_args};
- }
- };
- /*
- * This epilogue directly supports per-tensor azp in int32 form.
- * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
- * term, which should already be multiplied with the scalar azp.
- * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B,
- * where J is the all-ones row vector.
- *
- * This epilogue also supports bias, which remains per-channel.
- */
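- // Derivation sketch (standard asymmetric-quantization algebra): with
- // A_q = A / a_scale + azp, the int32 accumulator A_q @ B_q contains a
- // spurious azp * (J @ B_q) term; subtracting the precomputed
- // azp_adj = azp * J @ B_q removes it before the scales are applied.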
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
- struct ScaledEpilogueBiasAzp
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
- private:
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
- using Accum = typename SUPER::Accum;
- using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
- using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
- using Bias = typename SUPER::template RowLoad<ElementD, true>;
- // This is the full AZP term, azp * J @ B, shape (1,n)
- using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
- // Compute float(accum - azp_adj), both operands are int32_t
- using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::minus, float, int32_t,
- cutlass::FloatRoundStyle::round_to_nearest>;
- using EVTComputeAzp =
- cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
- using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiplies, float, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- using EVTComputeScaleB =
- cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
- using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiply_add, ElementD, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- public:
- using EVTCompute =
- cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
- EVTComputeScaleB, Bias>;
- using ArgumentType = typename EVTCompute::Arguments;
- static ArgumentType prepare_args(torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- torch::Tensor const& azp_adj,
- c10::optional<torch::Tensor> const& bias) {
- auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
- auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
- auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
- auto azp_adj_args =
- SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
- typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
- typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
- return ArgumentType{a_args, evt_scale_b_args, bias_args};
- }
- };
- /*
- * This epilogue supports per-token azp by computing and applying
- * the correction term using a rank-1 update. If the term were materialized,
- * it would require O(m*n) space, and this way it only requires O(m+n) space.
- * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
- * point for each row of A.
- * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B, where
- * J is again the all-ones row vector.
- *
- * This epilogue also supports bias, which remains per-channel.
- */
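- // Sketch of the rank-1 update: the correction term is the outer product
- //   correction[i, j] = azp[i] * azp_adj[j],  azp: (m,1), azp_adj: (1,n),
- // which the EVT below evaluates on the fly instead of materializing (m,n).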
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
- struct ScaledEpilogueBiasAzpToken
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
- private:
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
- using Accum = typename SUPER::Accum;
- using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
- using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
- using Bias = typename SUPER::template RowLoad<ElementD, true>;
- // Per-token azp term, shape (m,1)
- using Azp = typename SUPER::template ColLoad<int32_t>;
- // This is the AZP adjustment term, J @ B, shape (1,n)
- using AzpAdj = typename SUPER::template RowLoad<int32_t>;
- // Compute azp * azp_adj
- using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiplies, int32_t, int32_t,
- cutlass::FloatRoundStyle::round_to_nearest>;
- using EVTComputeAzp =
- cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
- // Compute float(accum - azp*azp_adj), all operands are int32_t
- using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::minus, float, int32_t,
- cutlass::FloatRoundStyle::round_to_nearest>;
- using EVTComputeAcc =
- cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
- using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiplies, float, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- using EVTComputeScaleB =
- cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
- using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiply_add, ElementD, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- public:
- using EVTCompute =
- cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
- EVTComputeScaleB, Bias>;
- using ArgumentType = typename EVTCompute::Arguments;
- static ArgumentType prepare_args(torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- torch::Tensor const& azp_adj,
- torch::Tensor const& azp,
- c10::optional<torch::Tensor> const& bias) {
- auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
- auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
- auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
- auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
- auto azp_adj_args =
- SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
- typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
- typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
- typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
- return ArgumentType{a_args, evt_scale_b_args, bias_args};
- }
- };
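- // Assembles a complete CUTLASS 3.x GEMM kernel type from the element types,
- // epilogue, tile/cluster shapes, and schedules. A is row-major, B is
- // column-major; the accumulator is int32 for int8 inputs, float otherwise.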
- template <typename ElementAB_, typename ElementD_,
- template <typename, typename, typename> typename Epilogue_,
- typename TileShape, typename ClusterShape, typename KernelSchedule,
- typename EpilogueSchedule>
- struct cutlass_3x_gemm {
- using ElementAB = ElementAB_;
- using ElementD = ElementD_;
- using ElementAcc =
- typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
- float>::type;
- using EpilogueDescriptor =
- cutlass::epilogue::collective::detail::EpilogueDescriptor<
- TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
- ElementD, EpilogueSchedule>;
- using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
- using StrideD = Stride<int64_t, Int<1>, Int<0>>;
- using ElementC = void;
- using StrideC = StrideD;
- using EVTCompute = typename Epilogue::EVTCompute;
- using CollectiveEpilogue =
- typename cutlass::epilogue::collective::CollectiveBuilder<
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
- ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
- ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
- EpilogueSchedule, EVTCompute>::CollectiveOp;
- static constexpr size_t CEStorageSize =
- sizeof(typename CollectiveEpilogue::SharedStorage);
- using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
- static_cast<int>(CEStorageSize)>;
- // clang-format off
- using CollectiveMainloop =
- typename cutlass::gemm::collective::CollectiveBuilder<
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
- ElementAB, cutlass::layout::RowMajor, 16,
- ElementAB, cutlass::layout::ColumnMajor, 16,
- ElementAcc, TileShape, ClusterShape,
- Stages,
- KernelSchedule>::CollectiveOp;
- // clang-format on
- using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
- cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
- cutlass::gemm::PersistentScheduler>>;
- struct GemmKernel : public KernelType {};
- };
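- // Translates the torch tensors into CUTLASS arguments and launches the GEMM:
- // a is (m, k) row-major, b is (k, n) column-major, out is (m, n) row-major,
- // and the batch count is fixed at 1.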
- template <typename Gemm, typename... EpilogueArgs>
- void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... epilogue_params) {
- using ElementAB = typename Gemm::ElementAB;
- using ElementD = typename Gemm::ElementD;
- int32_t m = a.size(0);
- int32_t n = b.size(1);
- int32_t k = a.size(1);
- int64_t lda = a.stride(0);
- int64_t ldb = b.stride(1);
- int64_t ldc = out.stride(0);
- using StrideA = Stride<int64_t, Int<1>, int64_t>;
- using StrideB = Stride<int64_t, Int<1>, int64_t>;
- using StrideC = typename Gemm::StrideC;
- StrideA a_stride{lda, Int<1>{}, 0};
- StrideB b_stride{ldb, Int<1>{}, 0};
- StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
- using GemmKernel = typename Gemm::GemmKernel;
- typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
- auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
- auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
- typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
- b_stride};
- auto c_ptr = static_cast<ElementD*>(out.data_ptr());
- typename GemmKernel::EpilogueArguments epilogue_args{
- Gemm::Epilogue::prepare_args(
- std::forward<EpilogueArgs>(epilogue_params)...),
- c_ptr, c_stride, c_ptr, c_stride};
- typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
- prob_shape, mainloop_args, epilogue_args};
- // Launch the CUTLASS GEMM kernel.
- using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
- GemmOp gemm_op;
- CUTLASS_CHECK(gemm_op.can_implement(args));
- size_t workspace_size = gemm_op.get_workspace_size(args);
- auto const workspace_options =
- torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
- auto workspace = torch::empty(workspace_size, workspace_options);
- auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
- cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
- CUTLASS_CHECK(status);
- }
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_fp8_config_default {
- // M in (128, inf)
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_fp8_config_M128 {
- // M in (64, 128]
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_fp8_config_M64 {
- // M in [1, 64]
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _128>;
- using ClusterShape = Shape<_1, _8, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_default {
- // For M > 128 and any N
- static_assert(std::is_same<InType, int8_t>());
- using KernelSchedule =
- typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M128 {
- // For M in (64, 128] and any N
- static_assert(std::is_same<InType, int8_t>());
- using KernelSchedule =
- typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M64 {
- // For M in (32, 64] and any N
- static_assert(std::is_same<InType, int8_t>());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _256>;
- using ClusterShape = Shape<_1, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M32_NBig {
- // For M in [1, 32] and N >= 8192
- static_assert(std::is_same<InType, int8_t>());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _256>;
- using ClusterShape = Shape<_1, _4, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M32_NSmall {
- // For M in [1, 32] and N < 8192
- static_assert(std::is_same<InType, int8_t>());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _256>;
- using ClusterShape = Shape<_1, _8, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- } // namespace
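- // Picks a tile configuration from m rounded up to the next power of two:
- // [1, 64] -> M64, (64, 128] -> M128, (128, inf) -> default.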
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
- using Cutlass3xGemmDefault =
- typename sm90_fp8_config_default<InType, OutType,
- Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM64 =
- typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM128 =
- typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
- uint32_t const m = a.size(0);
- uint32_t const mp2 =
- std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2, clamped to >= 64
- if (mp2 <= 64) {
- // m in [1, 64]
- return cutlass_gemm_caller<Cutlass3xGemmM64>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (mp2 <= 128) {
- // m in (64, 128]
- return cutlass_gemm_caller<Cutlass3xGemmM128>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- // m in (128, inf)
- return cutlass_gemm_caller<Cutlass3xGemmDefault>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
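- // Same power-of-two dispatch on m as the fp8 version, except the m <= 32
- // case additionally splits on n (n >= 8192 -> NBig, otherwise NSmall).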
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, int8_t>());
- TORCH_CHECK(a.dtype() == torch::kInt8);
- TORCH_CHECK(b.dtype() == torch::kInt8);
- using Cutlass3xGemmDefault =
- typename sm90_int8_config_default<InType, OutType,
- Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM128 =
- typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM64 =
- typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM32NBig =
- typename sm90_int8_config_M32_NBig<InType, OutType,
- Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM32NSmall =
- typename sm90_int8_config_M32_NSmall<InType, OutType,
- Epilogue>::Cutlass3xGemm;
- uint32_t const n = out.size(1);
- bool const is_small_n = n < 8192;
- uint32_t const m = a.size(0);
- uint32_t const mp2 =
- std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2, clamped to >= 32
- if (mp2 <= 32) {
- // m in [1, 32]
- if (is_small_n) {
- return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- } else if (mp2 <= 64) {
- // m in (32, 64]
- return cutlass_gemm_caller<Cutlass3xGemmM64>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (mp2 <= 128) {
- // m in (64, 128]
- return cutlass_gemm_caller<Cutlass3xGemmM128>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- // m in (128, inf)
- return cutlass_gemm_caller<Cutlass3xGemmDefault>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
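- // Routes to the int8 or fp8 dispatcher based on the input dtype and selects
- // the output element type (fp16 or bf16) from out.dtype().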
- template <template <typename, typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... epilogue_args) {
- if (a.dtype() == torch::kInt8) {
- TORCH_CHECK(b.dtype() == torch::kInt8);
- if (out.dtype() == torch::kBFloat16) {
- return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
- Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- } else {
- TORCH_CHECK(out.dtype() == torch::kFloat16);
- return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- }
- } else {
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
- if (out.dtype() == torch::kBFloat16) {
- return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
- cutlass::bfloat16_t, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- } else {
- TORCH_CHECK(out.dtype() == torch::kFloat16);
- return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
- cutlass::half_t, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- }
- }
- }
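- // Illustrative call shape (tensor names are hypothetical): with
- //   a: (m, k) int8 or fp8_e4m3, b: (k, n) same dtype (column-major),
- //   c: (m, n) fp16/bf16, a_scales/b_scales: fp32 with 1 or m / 1 or n elems,
- // cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias) computes
- //   c = (a_scales * a) @ (b_scales * b) [+ bias].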
- void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
- torch::Tensor const& b,
- torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- c10::optional<torch::Tensor> const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
- if (bias) {
- TORCH_CHECK(bias->dtype() == c.dtype(),
- "currently bias dtype must match output dtype ", c.dtype());
- return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(
- c, a, b, a_scales, b_scales, *bias);
- } else {
- return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
- b_scales);
- }
- }
- void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- torch::Tensor const& azp_adj,
- c10::optional<torch::Tensor> const& azp,
- c10::optional<torch::Tensor> const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
- if (azp) {
- return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
- out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
- } else {
- return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
- out, a, b, a_scales, b_scales, azp_adj, bias);
- }
- }
- #endif