machete_pytorch.cu

#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"

namespace machete {

using namespace aphrodite;

//
//  Utils (type dispatching)
//

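// Map a runtime aphrodite::ScalarType for the quantized weights onto the
// corresponding CUTLASS compile-time type and invoke `fn` with a value of
// that type, so callers can instantiate the matching kernel template.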
template <typename Fn>
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
  if (type == aphrodite::kU4) {
    return fn(cutlass::uint4b_t{});
  } else if (type == aphrodite::kU8) {
    return fn(cutlass::uint8_t{});
  } else if (type == aphrodite::kU4B8) {
    return fn(cutlass::aphrodite_uint4b8_t{});
  } else if (type == aphrodite::kU8B128) {
    return fn(cutlass::aphrodite_uint8b128_t{});
  } else {
    TORCH_CHECK(false, "Unsupported type ", type.str());
  }
}

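// Restrict torch dtype dispatch for the compute (activation) type to the
// reduced floating-point types (fp16/bf16) handled by these kernels.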
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
  AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)

#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                             \
                     AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))

//
//  Interface
//

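// List the kernel schedule names available for a given quantized weight
// type `btype`; only available when compiled with CUDA 12.0 or later.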
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
  return scalar_type_dispatch(*btype, [&](auto BType) {
    return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
  });
#else
  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

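// Mixed-precision GEMM entry point: A holds fp16/bf16 activations and B the
// quantized weights of type `btype` (expected in the prepacked layout, see
// prepack_B below). Optional scales/zeros/group_size describe the
// quantization; C, alpha and beta form an optional epilogue; `schedule`
// pins a specific kernel schedule. Dispatch happens on both the weight type
// and A's dtype.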
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                   ScalarTypeTorchPtr const& btype,
                   c10::optional<torch::Tensor> const& scales,
                   c10::optional<torch::Tensor> const& zeros,
                   c10::optional<int64_t> group_size,
                   c10::optional<torch::Tensor> const& C,
                   c10::optional<double> alpha, c10::optional<double> beta,
                   c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
  auto args = PyTorchArguments{.A = A,
                               .B = B,
                               .scales = scales,
                               .zeros = zeros,
                               .group_size = group_size,
                               .C = C,
                               .alpha = alpha,
                               .beta = beta,
                               .schedule = schedule};

  return scalar_type_dispatch(*btype, [&](auto BType) {
    return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
        A.scalar_type(), "machete_gemm", [&] {
          using ComputeType = equivalent_cutlass_type_t<scalar_t>;
          return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
        });
  });
#else
  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

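// Repack the quantized weight tensor B into the layout expected by the
// Machete GEMM kernels for the given weight type `btype`.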
torch::Tensor prepack_B(torch::Tensor const& B,
                        ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
  return scalar_type_dispatch(*btype, [&](auto BType) {
    return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
  });
#else
  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

};  // namespace machete