Browse Source

chore: gpu arch guard for cutlass w8a8 kernels

AlpinDale 7 months ago
parent
commit
e32f506e17

+ 70 - 35
kernels/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu

@@ -48,9 +48,44 @@ using namespace cute;
 
 namespace {
 
-template <typename Arch, typename ElementAB_, typename ElementD_,
-          typename TileShape, typename WarpShape, typename InstructionShape,
-          int32_t MainLoopStages>
+// Wrappers for the GEMM kernel that is used to guard against compilation on
+// architectures that will never use the kernel. The purpose of this is to
+// reduce the size of the compiled binary.
+// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+// into code that will be executed on the device where it is defined.
+template <typename Kernel>
+struct enable_sm75_to_sm80 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Kernel>
+struct enable_sm80_to_sm89 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Kernel>
+struct enable_sm89_to_sm90 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Arch, template <typename> typename ArchGuard,
+          typename ElementAB_, typename ElementD_, typename TileShape,
+          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
 struct cutlass_2x_gemm {
   using ElementAB = ElementAB_;
   using ElementD = ElementD_;
@@ -101,7 +136,7 @@ struct cutlass_2x_gemm {
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
   using KernelType = 
-    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
       ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16, 
       ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16, 
       float, cutlass::layout::RowMajor, 4,
@@ -112,7 +147,7 @@ struct cutlass_2x_gemm {
       cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
       MainLoopStages, Operator,
       1 /* epilogue stages */
-      >::GemmKernel;
+      >::GemmKernel>;
   // clang-format on
 
   using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
@@ -208,16 +243,16 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t,
-                        TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
+        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
+                                                    b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t, TileShape,
-                        WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
-                                                         b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
+        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
+                                                    b_scales);
   }
 }
 
@@ -235,16 +270,16 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
-                        TileShape, WarpShape, InstructionShape, 5>>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
+        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                    b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t, TileShape,
-                        WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                         b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
+        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                    b_scales);
   }
 }
 
@@ -263,16 +298,16 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t,
-                          TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
+          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                      b_scales);
     } else {
       assert(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t,
-                          TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
+          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                      b_scales);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
@@ -280,15 +315,15 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
-          cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::bfloat16_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
+          cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
-          cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::half_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
+          cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     }
   }
 }

+ 17 - 2
kernels/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu

@@ -56,6 +56,21 @@ uint32_t next_pow_2(uint32_t const num) {
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 
+// A wrapper for the GEMM kernel that is used to guard against compilation on
+// architectures that will never use the kernel. The purpose of this is to
+// reduce the size of the compiled binary.
+// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+// into code that will be executed on the device where it is defined.
+template <typename Kernel>
+struct enable_sm90_or_later : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
+    Kernel::operator()(std::forward<Args>(args)...);
+  #endif
+  }
+};
+
 template <typename ElementAB_, typename ElementD_, typename TileShape,
           typename ClusterShape, typename KernelSchedule,
           typename EpilogueSchedule>
@@ -126,9 +141,9 @@ struct cutlass_3x_gemm {
           KernelSchedule>::CollectiveOp;
   // clang-format on
 
-  using KernelType = cutlass::gemm::kernel::GemmUniversal<
+  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>;
+      cutlass::gemm::PersistentScheduler>>;
 
   struct GemmKernel : public KernelType {};
 };