@@ -48,9 +48,44 @@ using namespace cute;
 
 namespace {
 
-template <typename Arch, typename ElementAB_, typename ElementD_,
-          typename TileShape, typename WarpShape, typename InstructionShape,
-          int32_t MainLoopStages>
+// Wrappers for the GEMM kernel that is used to guard against compilation on
+// architectures that will never use the kernel. The purpose of this is to
+// reduce the size of the compiled binary.
+// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+// into code that will be executed on the device where it is defined.
+template <typename Kernel>
+struct enable_sm75_to_sm80 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Kernel>
+struct enable_sm80_to_sm89 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Kernel>
+struct enable_sm89_to_sm90 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Arch, template <typename> typename ArchGuard,
+          typename ElementAB_, typename ElementD_, typename TileShape,
+          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
 struct cutlass_2x_gemm {
   using ElementAB = ElementAB_;
   using ElementD = ElementD_;
@@ -101,7 +136,7 @@ struct cutlass_2x_gemm {
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
   using KernelType =
-      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+      ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
           ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
           ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
           float, cutlass::layout::RowMajor, 4,
@@ -112,7 +147,7 @@ struct cutlass_2x_gemm {
           cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
           MainLoopStages, Operator,
           1 /* epilogue stages */
-          >::GemmKernel;
+          >::GemmKernel>;
   // clang-format on
 
   using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
@@ -208,16 +243,16 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t,
-                        TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
+        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
+                                                    b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t, TileShape,
-                        WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
-                                                         b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
+        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
+                                                    b_scales);
   }
 }
 
@@ -235,16 +270,16 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
-                        TileShape, WarpShape, InstructionShape, 5>>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
+        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                    b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t, TileShape,
-                        WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                         b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
+        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                    b_scales);
   }
 }
 
@@ -263,16 +298,16 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t,
-                          TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
+          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                      b_scales);
     } else {
       assert(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t,
-                          TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
+          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                      b_scales);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
@@ -280,15 +315,15 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
-          cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::bfloat16_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
+          cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
-          cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::half_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
+          cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     }
   }
 }
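
A note on the guard pattern this diff introduces: `__CUDA_ARCH__` is defined only while nvcc compiles device code, and separately for each `-gencode` target, so the `#if` inside `invoke` strips the kernel body out of every cubin except the intended one, while the host-side template instantiation still compiles everywhere. The standalone sketch below shows the mechanism in isolation; `PayloadKernel` and `launcher` are hypothetical names, and it is simplified relative to the patch (plain `__device__` and pass-by-value instead of `CUTLASS_DEVICE` and `std::forward`).

```cuda
// guard_demo.cu -- minimal sketch of the __CUDA_ARCH__ guard pattern
// (hypothetical example, not part of the patch). Build for two targets:
//   nvcc -gencode arch=compute_75,code=sm_75 \
//        -gencode arch=compute_80,code=sm_80 guard_demo.cu
#include <cstdio>
#include <cuda_runtime.h>

struct PayloadKernel {  // stand-in for the guarded CUTLASS GemmKernel
  __device__ static void invoke(int x) { printf("payload ran: %d\n", x); }
};

// Same shape as enable_sm80_to_sm89 in the patch: the wrapper inherits
// from the kernel and shadows invoke() with a guarded version.
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Args>
  __device__ static void invoke(Args... args) {
    // Evaluated once per device compilation pass: the body is emitted only
    // into the sm_80 cubin, so the sm_75 cubin carries an empty function.
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
    Kernel::invoke(args...);
#endif
  }
};

template <typename GuardedKernel>
__global__ void launcher(int x) {
  GuardedKernel::invoke(x);
}

int main() {
  // Prints on an sm_8x GPU; compiles to a no-op body in the sm_75 cubin.
  launcher<enable_sm80_to_sm89<PayloadKernel>><<<1, 1>>>(42);
  cudaDeviceSynchronize();
  return 0;
}
```

In the real kernels the payoff is binary size: the fat binary still contains an entry point for every compiled architecture, but only the architectures inside the guarded range carry the large CUTLASS mainloop code.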