|
@@ -51,6 +51,11 @@ using namespace cute;
|
|
|
|
|
|
namespace {
|
|
namespace {
|
|
|
|
|
|
|
|
+uint32_t next_pow_2(uint32_t const num) {
|
|
|
|
+ if (num <= 1) return num;
|
|
|
|
+ return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
|
|
|
+}
|
|
|
|
+
|
|
template <typename ElementAB_, typename ElementD_, typename TileShape,
|
|
template <typename ElementAB_, typename ElementD_, typename TileShape,
|
|
typename ClusterShape, typename KernelSchedule,
|
|
typename ClusterShape, typename KernelSchedule,
|
|
typename EpilogueSchedule>
|
|
typename EpilogueSchedule>
|
|
@@ -188,8 +193,89 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
|
cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
|
|
cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
|
|
CUTLASS_CHECK(status);
|
|
CUTLASS_CHECK(status);
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+template <typename InType, typename OutType, int32_t M>
|
|
|
|
+struct sm90_fp8_config {
|
|
|
|
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
|
|
|
+ using KernelSchedule =
|
|
|
|
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
|
|
|
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
|
|
|
+ using TileShape = Shape<_128, _128, _128>;
|
|
|
|
+ using ClusterShape = Shape<_2, _1, _1>;
|
|
|
|
+
|
|
|
|
+ using Cutlass3xGemm =
|
|
|
|
+ cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
|
|
|
|
+ EpilogueSchedule>;
|
|
|
|
+};
|
|
|
|
+
|
|
|
|
+template <typename InType, typename OutType>
|
|
|
|
+struct sm90_fp8_config<InType, OutType, 128> {
|
|
|
|
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
|
|
|
+ using KernelSchedule =
|
|
|
|
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
|
|
|
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
|
|
|
+ using TileShape = Shape<_64, _128, _128>;
|
|
|
|
+ using ClusterShape = Shape<_2, _1, _1>;
|
|
|
|
+
|
|
|
|
+ using Cutlass3xGemm =
|
|
|
|
+ cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
|
|
|
|
+ EpilogueSchedule>;
|
|
|
|
+};
|
|
|
|
+
|
|
|
|
+template <typename InType, typename OutType>
|
|
|
|
+struct sm90_fp8_config<InType, OutType, 64> {
|
|
|
|
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
|
|
|
+ using KernelSchedule =
|
|
|
|
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
|
|
|
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
|
|
|
+ using TileShape = Shape<_64, _64, _128>;
|
|
|
|
+ using ClusterShape = Shape<_1, _8, _1>;
|
|
|
|
+
|
|
|
|
+ using Cutlass3xGemm =
|
|
|
|
+ cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
|
|
|
|
+ EpilogueSchedule>;
|
|
|
|
+};
|
|
|
|
+
|
|
} // namespace
|
|
} // namespace
|
|
|
|
|
|
|
|
+template <typename InType, typename OutType>
|
|
|
|
+void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
|
|
|
|
+ torch::Tensor const& a,
|
|
|
|
+ torch::Tensor const& b,
|
|
|
|
+ torch::Tensor const& a_scales,
|
|
|
|
+ torch::Tensor const& b_scales) {
|
|
|
|
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
|
|
|
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
|
|
|
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
|
|
|
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
|
|
|
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
|
|
|
+
|
|
|
|
+ using Cutlass3xGemmDefault =
|
|
|
|
+ typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm;
|
|
|
|
+ using Cutlass3xGemmM64 =
|
|
|
|
+ typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm;
|
|
|
|
+ using Cutlass3xGemmM128 =
|
|
|
|
+ typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm;
|
|
|
|
+
|
|
|
|
+ uint32_t const m = a.size(0);
|
|
|
|
+ uint32_t const mp2 =
|
|
|
|
+ std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
|
|
|
+
|
|
|
|
+ if (mp2 <= 64) {
|
|
|
|
+ // m in [1, 64]
|
|
|
|
+ return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>(
|
|
|
|
+ out, a, b, a_scales, b_scales);
|
|
|
|
+ } else if (mp2 <= 128) {
|
|
|
|
+ // m in (64, 128]
|
|
|
|
+ return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>(
|
|
|
|
+ out, a, b, a_scales, b_scales);
|
|
|
|
+ } else {
|
|
|
|
+ // m in (128, inf)
|
|
|
|
+ return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>(
|
|
|
|
+ out, a, b, a_scales, b_scales);
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
|
|
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& a_scales,
|
|
@@ -223,24 +309,14 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
|
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
|
|
|
|
|
- using TileShape = Shape<_128, _128, _128>;
|
|
|
|
- using ClusterShape = Shape<_1, _2, _1>;
|
|
|
|
- using KernelSchedule =
|
|
|
|
- typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative;
|
|
|
|
- using EpilogueSchedule =
|
|
|
|
- typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
|
|
|
-
|
|
|
|
if (out.dtype() == torch::kBFloat16) {
|
|
if (out.dtype() == torch::kBFloat16) {
|
|
- return cutlass_scaled_mm_dq_dispatcher<
|
|
|
|
- cutlass_3x_gemm<cutlass::float_e4m3_t, cutlass::bfloat16_t, TileShape,
|
|
|
|
- ClusterShape, KernelSchedule, EpilogueSchedule>>(
|
|
|
|
|
|
+ return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
|
|
|
+ cutlass::bfloat16_t>(
|
|
out, a, b, a_scales, b_scales);
|
|
out, a, b, a_scales, b_scales);
|
|
} else {
|
|
} else {
|
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
-
|
|
|
|
- return cutlass_scaled_mm_dq_dispatcher<
|
|
|
|
- cutlass_3x_gemm<cutlass::float_e4m3_t, cutlass::half_t, TileShape,
|
|
|
|
- ClusterShape, KernelSchedule, EpilogueSchedule>>(
|
|
|
|
|
|
+ return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
|
|
|
+ cutlass::half_t>(
|
|
out, a, b, a_scales, b_scales);
|
|
out, a, b, a_scales, b_scales);
|
|
}
|
|
}
|
|
}
|
|
}
|