123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382 |
- #pragma once
- #include "scaled_mm_c2x.cuh"
- #include "cutlass/float8.h"
/**
 * This file defines the GEMM kernel configurations for FP8 (e4m3) on SM89.
 * A configuration is chosen at runtime from the GEMM shape: first by M (rows
 * of `a`), then by N (columns of `out`).
 */
- namespace aphrodite {
- template <typename InType, typename OutType,
- template <typename, typename> typename Epilogue>
- struct sm89_fp8_fallback_gemm {
- // Shared Memory required by this Gemm - 61440 bytes
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
- using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
- using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
- using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
- using Cutlass2xGemm =
- cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
- Epilogue, TileShape, WarpShape, InstructionShape, 5,
- FP8MathOperator>;
- };
- struct sm89_fp8_config_default {
- // M in (256, inf)
- using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
- using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
- using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
- template <typename InType, typename OutType,
- template <typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- static void dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b, EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- using FallbackGemm =
- typename sm89_fp8_fallback_gemm<InType, OutType,
- Epilogue>::Cutlass2xGemm;
- uint32_t const n = out.size(1);
- uint32_t const np2 = next_pow_2(n);
- if (np2 <= 4096) {
- using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 5, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (np2 <= 8192) {
- using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 3, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 5, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
- };
- struct sm89_fp8_config_M256 {
- // M in (128, 256]
- using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
- using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
- using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
- template <typename InType, typename OutType,
- template <typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- static void dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b, EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- using FallbackGemm =
- typename sm89_fp8_fallback_gemm<InType, OutType,
- Epilogue>::Cutlass2xGemm;
- uint32_t const n = out.size(1);
- uint32_t const np2 = next_pow_2(n);
- if (np2 <= 4096) {
- using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 3, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 5, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
- };
- struct sm89_fp8_config_M128 {
- // M in (64, 128]
- using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
- using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
- using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
- template <typename InType, typename OutType,
- template <typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- static void dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b, EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- using FallbackGemm =
- typename sm89_fp8_fallback_gemm<InType, OutType,
- Epilogue>::Cutlass2xGemm;
- uint32_t const n = out.size(1);
- uint32_t const np2 = next_pow_2(n);
- if (np2 <= 8192) {
- using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 3, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (np2 <= 16384) {
- using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 5, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 3, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
- };
- struct sm89_fp8_config_M64 {
- // M in (32, 64]
- using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
- template <typename InType, typename OutType,
- template <typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- static void dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b, EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- using FallbackGemm =
- typename sm89_fp8_fallback_gemm<InType, OutType,
- Epilogue>::Cutlass2xGemm;
- uint32_t const n = out.size(1);
- uint32_t const np2 = next_pow_2(n);
- if (np2 <= 8196) {
- using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
- using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
- using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 5, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (np2 <= 16384) {
- using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
- using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
- using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 3, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
- using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
- using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 5, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
- };
- struct sm89_fp8_config_M32 {
- // M in (16, 32]
- using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
- using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
- template <typename InType, typename OutType,
- template <typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- static void dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b, EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- using FallbackGemm =
- typename sm89_fp8_fallback_gemm<InType, OutType,
- Epilogue>::Cutlass2xGemm;
- uint32_t const n = out.size(1);
- uint32_t const np2 = next_pow_2(n);
- if (np2 <= 8192) {
- using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
- using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 5, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (np2 <= 16384) {
- using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
- using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 4, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
- using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
- aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape,
- InstructionShape, 5, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
- };
- struct sm89_fp8_config_M16 {
- // M in [1, 16]
- using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
- using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
- using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
- static const int32_t MainLoopStages = 5;
- template <typename InType, typename OutType,
- template <typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- static void dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b, EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- using FallbackGemm =
- typename sm89_fp8_fallback_gemm<InType, OutType,
- Epilogue>::Cutlass2xGemm;
- uint32_t const n = out.size(1);
- uint32_t const np2 = next_pow_2(n);
- if (np2 <= 8192) {
- using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<
- cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape, InstructionShape,
- MainLoopStages, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (np2 <= 24576) {
- using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<
- cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape, InstructionShape,
- MainLoopStages, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
- return aphrodite::fallback_cutlass_gemm_caller<
- aphrodite::cutlass_2x_gemm<
- cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
- OutType, Epilogue, TileShape, WarpShape, InstructionShape,
- MainLoopStages, FP8MathOperator>,
- FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
- };
- template <typename InType, typename OutType,
- template <typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
- torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... args) {
- static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
- uint32_t const m = a.size(0);
- uint32_t const mp2 =
- std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
- if (mp2 <= 16) {
- // M in [1, 16]
- return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (mp2 <= 32) {
- // M in (16, 32]
- return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (mp2 <= 64) {
- // M in (32, 64]
- return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (mp2 <= 128) {
- // M in (64, 128]
- return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (mp2 <= 256) {
- // M in (128, 256]
- return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- } else {
- // M in (256, inf)
- return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(args)...);
- }
- }
- } // namespace aphrodite
|