Parcourir la source

do not build cutlass kernels if cuda version is too low

AlpinDale il y a 7 mois
Parent
commit
67084aca5b

+ 0 - 9
aphrodite/modeling/models/mixtral.py

@@ -274,15 +274,6 @@ class MixtralAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
 
-        if isinstance(
-                quant_config,
-                Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
-            print_warning_once(
-                "For Mixtral FP8 quantization, we currently do not quantize "
-                "the attention layers until their FP8 performance is improved."
-            )
-            quant_config = None
-
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
             self.head_dim,

+ 8 - 3
kernels/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu

@@ -1,3 +1,8 @@
+// clang-format will break include orders
+// clang-format off
+#include <cudaTypedefs.h>
+
+#if defined CUDA_VERSION && CUDA_VERSION >= 12000
 #include <torch/extension.h>
 
 #include <ATen/cuda/CUDAContext.h>
@@ -6,8 +11,6 @@
 #include <sstream>
 #include <vector>
 
-// clang-format will break include orders
-// clang-format off
 #include "cutlass/cutlass.h"
 
 #include "cute/tensor.hpp"
@@ -237,4 +240,6 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
           out, a, b, a_scales, b_scales);
     }
   }
-}
+}
+
+#endif

+ 10 - 1
kernels/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu

@@ -1,5 +1,6 @@
+#include <cudaTypedefs.h>
+
 #include <c10/cuda/CUDAGuard.h>
-#include <cuda_runtime.h>
 #include <torch/extension.h>
 
 void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
@@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales);
 
+#if defined CUDA_VERSION && CUDA_VERSION >= 12000
 void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales);
+#endif
 
 void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b, torch::Tensor const& a_scales,
@@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
 
   if (version_num >= 90) {
     // Hopper
+
+    // Guard against compilation issues for sm90 kernels
+#if defined CUDA_VERSION && CUDA_VERSION >= 12000
     cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
+#else
+    cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
+#endif
   } else if (version_num == 89) {
     // Ada Lovelace
     cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);