|
@@ -1,5 +1,6 @@
|
|
|
|
+#include <cudaTypedefs.h>
|
|
|
|
+
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
-#include <cuda_runtime.h>
|
|
|
|
#include <torch/extension.h>
|
|
#include <torch/extension.h>
|
|
|
|
|
|
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
|
|
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
|
|
@@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& b_scales);
|
|
torch::Tensor const& b_scales);
|
|
|
|
|
|
|
|
+#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
|
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
|
|
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& b_scales);
|
|
torch::Tensor const& b_scales);
|
|
|
|
+#endif
|
|
|
|
|
|
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
|
|
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
|
|
torch::Tensor const& b, torch::Tensor const& a_scales,
|
|
torch::Tensor const& b, torch::Tensor const& a_scales,
|
|
@@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
|
|
|
|
|
|
if (version_num >= 90) {
|
|
if (version_num >= 90) {
|
|
// Hopper
|
|
// Hopper
|
|
|
|
+
|
|
|
|
+ // Guard against compilation issues for sm90 kernels
|
|
|
|
+#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
|
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
|
|
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
|
|
|
|
+#else
|
|
|
|
+ cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
|
|
|
|
+#endif
|
|
} else if (version_num == 89) {
|
|
} else if (version_num == 89) {
|
|
// Ada Lovelace
|
|
// Ada Lovelace
|
|
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
|
|
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
|