|
@@ -465,29 +465,30 @@ void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
|
|
|
ScaledEpilogue>(
|
|
|
out, a, b, a_scales, b_scales);
|
|
|
}
|
|
|
- } else {
|
|
|
- TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
|
-
|
|
|
- return cutlass_gemm_caller<
|
|
|
- cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
|
|
|
- ClusterShape, KernelSchedule, EpilogueSchedule>>(
|
|
|
- out, a, b, a_scales, b_scales);
|
|
|
- }
|
|
|
} else {
|
|
|
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
|
|
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
|
|
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
|
|
|
|
- if (out.dtype() == torch::kBFloat16) {
|
|
|
- return cutlass_gemm_sm90_fp8_dispatch<
|
|
|
- cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(
|
|
|
- out, a, b, a_scales, b_scales);
|
|
|
- } else {
|
|
|
- TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
|
- return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
|
|
- cutlass::half_t, ScaledEpilogue>(
|
|
|
- out, a, b, a_scales, b_scales);
|
|
|
- }
|
|
|
+ return cutlass_gemm_caller<
|
|
|
+ cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
|
|
|
+ ClusterShape, KernelSchedule, EpilogueSchedule>>(
|
|
|
+ out, a, b, a_scales, b_scales);
|
|
|
}
|
|
|
}
|
|
|
+else {
|
|
|
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
|
|
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
|
|
+
|
|
|
+ if (out.dtype() == torch::kBFloat16) {
|
|
|
+ return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
|
|
+ cutlass::bfloat16_t, ScaledEpilogue>(
|
|
|
+ out, a, b, a_scales, b_scales);
|
|
|
+ } else {
|
|
|
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
|
+ return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
|
|
+ cutlass::half_t, ScaledEpilogue>(
|
|
|
+ out, a, b, a_scales, b_scales);
|
|
|
+ }
|
|
|
+}
|
|
|
+}
|
|
|
|
|
|
#endif
|