فهرست منبع

feat: update cutlass fp8 configs

AlpinDale 7 ماه پیش
والد
کامیت
f2c6791527
1فایلهای تغییر یافته به همراه90 افزوده شده و 14 حذف شده
  1. 90 14
      kernels/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu

+ 90 - 14
kernels/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu

@@ -51,6 +51,11 @@ using namespace cute;
 
 
 namespace {
 namespace {
 
 
+uint32_t next_pow_2(uint32_t const num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
+
 template <typename ElementAB_, typename ElementD_, typename TileShape,
 template <typename ElementAB_, typename ElementD_, typename TileShape,
           typename ClusterShape, typename KernelSchedule,
           typename ClusterShape, typename KernelSchedule,
           typename EpilogueSchedule>
           typename EpilogueSchedule>
@@ -188,8 +193,89 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
   cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
   CUTLASS_CHECK(status);
   CUTLASS_CHECK(status);
 }
 }
+
+template <typename InType, typename OutType, int32_t M>
+struct sm90_fp8_config {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
+                      EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType>
+struct sm90_fp8_config<InType, OutType, 128> {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
+                      EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType>
+struct sm90_fp8_config<InType, OutType, 64> {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _64, _128>;
+  using ClusterShape = Shape<_1, _8, _1>;
+
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
+                      EpilogueSchedule>;
+};
+
 }  // namespace
 }  // namespace
 
 
+template <typename InType, typename OutType>
+void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
+                                            torch::Tensor const& a,
+                                            torch::Tensor const& b,
+                                            torch::Tensor const& a_scales,
+                                            torch::Tensor const& b_scales) {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  using Cutlass3xGemmDefault =
+      typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm;
+  using Cutlass3xGemmM64 =
+      typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm;
+  using Cutlass3xGemmM128 =
+      typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
+
+  if (mp2 <= 64) {
+    // m in [1, 64]
+    return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>(
+        out, a, b, a_scales, b_scales);
+  } else if (mp2 <= 128) {
+    // m in (64, 128]
+    return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    // m in (128, inf)
+    return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
 void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
 void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& a_scales,
@@ -223,24 +309,14 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
 
-    using TileShape = Shape<_128, _128, _128>;
-    using ClusterShape = Shape<_1, _2, _1>;
-    using KernelSchedule =
-        typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative;
-    using EpilogueSchedule =
-        typename cutlass::epilogue::TmaWarpSpecializedCooperative;
-
     if (out.dtype() == torch::kBFloat16) {
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_3x_gemm<cutlass::float_e4m3_t, cutlass::bfloat16_t, TileShape,
-                          ClusterShape, KernelSchedule, EpilogueSchedule>>(
+      return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
+                                                    cutlass::bfloat16_t>(
           out, a, b, a_scales, b_scales);
           out, a, b, a_scales, b_scales);
     } else {
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_3x_gemm<cutlass::float_e4m3_t, cutlass::half_t, TileShape,
-                          ClusterShape, KernelSchedule, EpilogueSchedule>>(
+      return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
+                                                    cutlass::half_t>(
           out, a, b, a_scales, b_scales);
           out, a, b, a_scales, b_scales);
     }
     }
   }
   }