123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- #include <stddef.h>
- #include <torch/all.h>
- #include "cutlass/cutlass.h"
- #include "scaled_mm_c2x.cuh"
- #include "scaled_mm_c2x_sm75_dispatch.cuh"
- #include "scaled_mm_c2x_sm80_dispatch.cuh"
- #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
- #include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
- /*
- This file defines quantized GEMM operations using the CUTLASS 2.x API, for
- NVIDIA GPUs with SM versions prior to sm90 (Hopper).
- */
- template <template <typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... epilogue_args) {
- TORCH_CHECK(a.dtype() == torch::kInt8);
- TORCH_CHECK(b.dtype() == torch::kInt8);
- if (out.dtype() == torch::kBFloat16) {
- return aphrodite::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
- Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- } else {
- TORCH_CHECK(out.dtype() == torch::kFloat16);
- return aphrodite::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t,
- Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- }
- }
- void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- c10::optional<torch::Tensor> const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
- if (bias) {
- TORCH_CHECK(bias->dtype() == out.dtype(),
- "currently bias dtype must match output dtype ", out.dtype());
- return cutlass_scaled_mm_sm75_epilogue<aphrodite::ScaledEpilogueBias>(
- out, a, b, a_scales, b_scales, *bias);
- } else {
- return cutlass_scaled_mm_sm75_epilogue<aphrodite::ScaledEpilogue>(
- out, a, b, a_scales, b_scales);
- }
- }
- void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- torch::Tensor const& azp_adj,
- c10::optional<torch::Tensor> const& azp,
- c10::optional<torch::Tensor> const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
- if (azp) {
- return cutlass_scaled_mm_sm75_epilogue<
- aphrodite::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales,
- azp_adj, *azp, bias);
- } else {
- return cutlass_scaled_mm_sm75_epilogue<aphrodite::ScaledEpilogueBiasAzp>(
- out, a, b, a_scales, b_scales, azp_adj, bias);
- }
- }
- template <template <typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... epilogue_args) {
- TORCH_CHECK(a.dtype() == torch::kInt8);
- TORCH_CHECK(b.dtype() == torch::kInt8);
- if (out.dtype() == torch::kBFloat16) {
- return aphrodite::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
- Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- } else {
- TORCH_CHECK(out.dtype() == torch::kFloat16);
- return aphrodite::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t,
- Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- }
- }
- void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- c10::optional<torch::Tensor> const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
- if (bias) {
- TORCH_CHECK(bias->dtype() == out.dtype(),
- "currently bias dtype must match output dtype ", out.dtype());
- return cutlass_scaled_mm_sm80_epilogue<aphrodite::ScaledEpilogueBias>(
- out, a, b, a_scales, b_scales, *bias);
- } else {
- return cutlass_scaled_mm_sm80_epilogue<aphrodite::ScaledEpilogue>(
- out, a, b, a_scales, b_scales);
- }
- }
- void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- torch::Tensor const& azp_adj,
- c10::optional<torch::Tensor> const& azp,
- c10::optional<torch::Tensor> const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
- if (azp) {
- return cutlass_scaled_mm_sm80_epilogue<
- aphrodite::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales,
- azp_adj, *azp, bias);
- } else {
- return cutlass_scaled_mm_sm80_epilogue<aphrodite::ScaledEpilogueBiasAzp>(
- out, a, b, a_scales, b_scales, azp_adj, bias);
- }
- }
- template <template <typename, typename> typename Epilogue,
- typename... EpilogueArgs>
- void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... epilogue_args) {
- if (a.dtype() == torch::kInt8) {
- TORCH_CHECK(b.dtype() == torch::kInt8);
- if (out.dtype() == torch::kBFloat16) {
- return aphrodite::cutlass_gemm_sm89_int8_dispatch<
- int8_t, cutlass::bfloat16_t, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- } else {
- assert(out.dtype() == torch::kFloat16);
- return aphrodite::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
- Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- }
- } else {
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
- if (out.dtype() == torch::kBFloat16) {
- return aphrodite::cutlass_gemm_sm89_fp8_dispatch<
- cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- } else {
- TORCH_CHECK(out.dtype() == torch::kFloat16);
- return aphrodite::cutlass_gemm_sm89_fp8_dispatch<
- cutlass::float_e4m3_t, cutlass::half_t, Epilogue>(
- out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
- }
- }
- }
- void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- c10::optional<torch::Tensor> const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
- if (bias) {
- TORCH_CHECK(bias->dtype() == out.dtype(),
- "currently bias dtype must match output dtype ", out.dtype());
- return cutlass_scaled_mm_sm89_epilogue<aphrodite::ScaledEpilogueBias>(
- out, a, b, a_scales, b_scales, *bias);
- } else {
- return cutlass_scaled_mm_sm89_epilogue<aphrodite::ScaledEpilogue>(
- out, a, b, a_scales, b_scales);
- }
- }
- void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- torch::Tensor const& a_scales,
- torch::Tensor const& b_scales,
- torch::Tensor const& azp_adj,
- c10::optional<torch::Tensor> const& azp,
- c10::optional<torch::Tensor> const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
- if (azp) {
- return cutlass_scaled_mm_sm89_epilogue<
- aphrodite::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales,
- azp_adj, *azp, bias);
- } else {
- return cutlass_scaled_mm_sm89_epilogue<aphrodite::ScaledEpilogueBiasAzp>(
- out, a, b, a_scales, b_scales, azp_adj, bias);
- }
- }
|