123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557 |
- // clang-format will break include orders
- // clang-format off
- #include <cudaTypedefs.h>
- #if defined CUDA_VERSION && CUDA_VERSION >= 12000
- #include <torch/all.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <iostream>
- #include <sstream>
- #include <vector>
- #include "cutlass/cutlass.h"
- #include "cute/tensor.hpp"
- #include "cute/atom/mma_atom.hpp"
- #include "cutlass/numeric_types.h"
- #include "cutlass/util/device_memory.h"
- #include "cutlass/gemm/device/gemm_universal_adapter.h"
- #include "cutlass/gemm/kernel/gemm_universal.hpp"
- #include "cutlass/epilogue/collective/collective_builder.hpp"
- #include "cutlass/gemm/collective/collective_builder.hpp"
- #include "broadcast_load_epilogue_c3x.hpp"
- #include "common.hpp"
- // clang-format on
- using namespace cute;
- /*
- This file defines quantized GEMM operations using the CUTLASS 3.x API, for
- NVIDIA GPUs with sm90a (Hopper) or later.
- Epilogue functions can be defined to post-process the output before it is
- written to GPU memory.
- Epilogues must contain a public type named EVTCompute of type Sm90EVT,
- as well as a static prepare_args function that constructs an
- EVTCompute::Arguments struct.
- */
- namespace {
- // A wrapper for the GEMM kernel that is used to guard against compilation on
- // architectures that will never use the kernel. The purpose of this is to
- // reduce the size of the compiled binary.
- // __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
- // into code that will be executed on the device where it is defined.
- template <typename Kernel>
- struct enable_sm90_or_later : Kernel {
- template <typename... Args>
- CUTLASS_DEVICE void operator()(Args&&... args) {
- #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
- Kernel::operator()(std::forward<Args>(args)...);
- #endif
- }
- };
- /*
- * This class provides the common ScaleA and ScaleB descriptors for the
- * ScaledEpilogue and ScaledEpilogueBias classes.
- */
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
- struct ScaledEpilogueBase {
- protected:
- using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
- using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
- Stride<Int<1>, Int<0>, Int<0>>>;
- using ScaleBDescriptor =
- cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
- EpilogueDescriptor, float>;
- using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
- ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
- typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
- };
- /*
- This epilogue function defines a quantized GEMM operation similar to
- torch.scaled_mm_.
- A and B may be both either int8 or fp8_e4m3. A can be
- quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
- Any combination of per-tensor and per-row or column is supported.
- A and B must have symmetric quantization (zero point == 0).
- So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
- scales are applied elementwise with numpy-style broadcasting.
- ScaleA and ScaleB define the epilogue functions that apply the scales for
- the A and B operands respectively. These scales may be either per-tensor or
- per row or column.
- */
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
- struct ScaledEpilogue
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
- private:
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
- using Accum = typename SUPER::Accum;
- using ScaleA = typename SUPER::ScaleA;
- using ScaleB = typename SUPER::ScaleB;
- using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiplies, float, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- using EVTCompute0 =
- cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
- using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiplies, ElementD, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- public:
- using EVTCompute =
- cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
- using ArgumentType = typename EVTCompute::Arguments;
- static ArgumentType prepare_args(torch::Tensor const& a_scales,
- torch::Tensor const& b_scales) {
- using ScaleA_Args = typename ScaleA::Arguments;
- using ScaleB_Args = typename ScaleB::Arguments;
- ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
- ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
- return ArgumentType{a_args, {b_args}};
- }
- };
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
- struct ScaledEpilogueBias
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
- private:
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
- using Accum = typename SUPER::Accum;
- using ScaleA = typename SUPER::ScaleA;
- using ScaleB = typename SUPER::ScaleB;
- using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiplies, float, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- using EVTCompute0 =
- cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
- using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
- cutlass::multiply_add, ElementD, float,
- cutlass::FloatRoundStyle::round_to_nearest>;
- using BiasDescriptor =
- cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
- EpilogueDescriptor, ElementD>;
- using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
- BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
- Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
- public:
- using EVTCompute =
- cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
- using ArgumentType = typename EVTCompute::Arguments;
- static ArgumentType prepare_args(torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- torch::Tensor const& bias) {
- using ScaleA_Args = typename ScaleA::Arguments;
- using ScaleB_Args = typename ScaleB::Arguments;
- using Bias_Args = typename Bias::Arguments;
- ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
- ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
- Bias_Args bias_args{static_cast<ElementD*>(bias.data_ptr())};
- return ArgumentType{a_args, {b_args}, bias_args};
- }
- };
- template <typename ElementAB_, typename ElementD_,
- template <typename, typename, typename> typename Epilogue_,
- typename TileShape, typename ClusterShape, typename KernelSchedule,
- typename EpilogueSchedule>
- struct cutlass_3x_gemm {
- using ElementAB = ElementAB_;
- using ElementD = ElementD_;
- using ElementAcc =
- typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
- float>::type;
- using EpilogueDescriptor =
- cutlass::epilogue::collective::detail::EpilogueDescriptor<
- TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
- ElementD, EpilogueSchedule>;
- using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
- using StrideD = Stride<int64_t, Int<1>, Int<0>>;
- using ElementC = void;
- using StrideC = StrideD;
- using EVTCompute = typename Epilogue::EVTCompute;
- using CollectiveEpilogue =
- typename cutlass::epilogue::collective::CollectiveBuilder<
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
- ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
- ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
- EpilogueSchedule, EVTCompute>::CollectiveOp;
- static constexpr size_t CEStorageSize =
- sizeof(typename CollectiveEpilogue::SharedStorage);
- using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
- static_cast<int>(CEStorageSize)>;
- // clang-format off
- using CollectiveMainloop =
- typename cutlass::gemm::collective::CollectiveBuilder<
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
- ElementAB, cutlass::layout::RowMajor, 16,
- ElementAB, cutlass::layout::ColumnMajor, 16,
- ElementAcc, TileShape, ClusterShape,
- Stages,
- KernelSchedule>::CollectiveOp;
- // clang-format on
- using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
- cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
- cutlass::gemm::PersistentScheduler>>;
- struct GemmKernel : public KernelType {};
- };
- template <typename Gemm, typename... EpilogueArgs>
- void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... epilogue_params) {
- using ElementAB = typename Gemm::ElementAB;
- using ElementD = typename Gemm::ElementD;
- int32_t m = a.size(0);
- int32_t n = b.size(1);
- int32_t k = a.size(1);
- int64_t lda = a.stride(0);
- int64_t ldb = b.stride(1);
- int64_t ldc = out.stride(0);
- using StrideA = Stride<int64_t, Int<1>, Int<0>>;
- using StrideB = Stride<int64_t, Int<1>, Int<0>>;
- using StrideC = typename Gemm::StrideC;
- StrideA a_stride{lda, Int<1>{}, Int<0>{}};
- StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
- StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
- using GemmKernel = typename Gemm::GemmKernel;
- typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
- auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
- auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
- typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
- b_stride};
- auto c_ptr = static_cast<ElementD*>(out.data_ptr());
- typename GemmKernel::EpilogueArguments epilogue_args{
- Gemm::Epilogue::prepare_args(
- std::forward<EpilogueArgs>(epilogue_params)...),
- c_ptr, c_stride, c_ptr, c_stride};
- typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
- prob_shape, mainloop_args, epilogue_args};
- // Launch the CUTLASS GEMM kernel.
- using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
- GemmOp gemm_op;
- CUTLASS_CHECK(gemm_op.can_implement(args));
- size_t workspace_size = gemm_op.get_workspace_size(args);
- cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
- auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
- cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
- CUTLASS_CHECK(status);
- }
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_fp8_config_default {
- // M in (128, inf)
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_fp8_config_M128 {
- // M in (64, 128]
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_fp8_config_M64 {
- // M in [1, 64]
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _128>;
- using ClusterShape = Shape<_1, _8, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_default {
- // For M > 128 and any N
- static_assert(std::is_same<InType, int8_t>());
- using KernelSchedule =
- typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M128 {
- // For M in (64, 128] and any N
- static_assert(std::is_same<InType, int8_t>());
- using KernelSchedule =
- typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M64 {
- // For M in (32, 64] and any N
- static_assert(std::is_same<InType, int8_t>());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _256>;
- using ClusterShape = Shape<_1, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M32_NBig {
- // For M in [1, 32] and N >= 8192
- static_assert(std::is_same<InType, int8_t>());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _256>;
- using ClusterShape = Shape<_1, _4, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M32_NSmall {
- // For M in [1, 32] and N < 8192
- static_assert(std::is_same<InType, int8_t>());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _256>;
- using ClusterShape = Shape<_1, _8, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
- KernelSchedule, EpilogueSchedule>;
- };
- } // namespace
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
- using Cutlass3xGemmDefault =
- typename sm90_fp8_config_default<InType, OutType,
- Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM64 =
- typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM128 =
- typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
- uint32_t const m = a.size(0);
- uint32_t const mp2 =
- std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
- if (mp2 <= 64) {
- // m in [1, 64]
- return cutlass_gemm_caller<Cutlass3xGemmM64>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (mp2 <= 128) {
- // m in (64, 128]
- return cutlass_gemm_caller<Cutlass3xGemmM128>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- // m in (128, inf)
- return cutlass_gemm_caller<Cutlass3xGemmDefault>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
- template <typename InType, typename OutType,
- template <typename, typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, int8_t>());
- TORCH_CHECK(a.dtype() == torch::kInt8);
- TORCH_CHECK(b.dtype() == torch::kInt8);
- using Cutlass3xGemmDefault =
- typename sm90_int8_config_default<InType, OutType,
- Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM128 =
- typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM64 =
- typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM32NBig =
- typename sm90_int8_config_M32_NBig<InType, OutType,
- Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM32NSmall =
- typename sm90_int8_config_M32_NSmall<InType, OutType,
- Epilogue>::Cutlass3xGemm;
- uint32_t const n = out.size(1);
- bool const is_small_n = n < 8192;
- uint32_t const m = a.size(0);
- uint32_t const mp2 =
- std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
- if (mp2 <= 32) {
- // m in [1, 32]
- if (is_small_n) {
- return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- } else if (mp2 <= 64) {
- // m in (32, 64]
- return cutlass_gemm_caller<Cutlass3xGemmM64>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (mp2 <= 128) {
- // m in (64, 128]
- return cutlass_gemm_caller<Cutlass3xGemmM128>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- // m in (128, inf)
- return cutlass_gemm_caller<Cutlass3xGemmDefault>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
- template <template <typename, typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... epilogue_args) {
- if (a.dtype() == torch::kInt8) {
- TORCH_CHECK(b.dtype() == torch::kInt8);
- if (out.dtype() == torch::kBFloat16) {
- return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
- Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- } else {
- TORCH_CHECK(out.dtype() == torch::kFloat16);
- return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- }
- } else {
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
- if (out.dtype() == torch::kBFloat16) {
- return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
- cutlass::bfloat16_t, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- } else {
- TORCH_CHECK(out.dtype() == torch::kFloat16);
- return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
- cutlass::half_t, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- }
- }
- }
- void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
- torch::Tensor const& b,
- torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- c10::optional<torch::Tensor> const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
- if (bias) {
- TORCH_CHECK(bias->dtype() == c.dtype(),
- "currently bias dtype must match output dtype ", c.dtype());
- return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(
- c, a, b, a_scales, b_scales, *bias);
- } else {
- return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
- b_scales);
- }
- }
- #endif
|