1
0

machete_pytorch.cu 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. #include "machete_mm_launcher.cuh"
  2. #include "machete_prepack_launcher.cuh"
  3. #include "core/scalar_type.hpp"
  4. namespace machete {
  5. using namespace aphrodite;
  6. //
  7. // Utils (type dispatching)
  8. //
  9. template <typename Fn>
  10. static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
  11. if (type == aphrodite::kU4) {
  12. return fn(cutlass::uint4b_t{});
  13. } else if (type == aphrodite::kU8) {
  14. return fn(cutlass::uint8_t{});
  15. } else if (type == aphrodite::kU4B8) {
  16. return fn(cutlass::aphrodite_uint4b8_t{});
  17. } else if (type == aphrodite::kU8B128) {
  18. return fn(cutlass::aphrodite_uint8b128_t{});
  19. } else {
  20. TORCH_CHECK(false, "Unsupported type ", type.str());
  21. }
  22. }
  23. #define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
  24. AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
  25. #define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
  26. AT_DISPATCH_SWITCH(TYPE, NAME, \
  27. AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
  28. //
  29. // Interface
  30. //
  31. std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
  32. return scalar_type_dispatch(*btype, [&](auto BType) {
  33. return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
  34. });
  35. }
  36. torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
  37. ScalarTypeTorchPtr const& btype,
  38. c10::optional<torch::Tensor> const& scales,
  39. c10::optional<torch::Tensor> const& zeros,
  40. c10::optional<int64_t> group_size,
  41. c10::optional<torch::Tensor> const& C,
  42. c10::optional<double> alpha, c10::optional<double> beta,
  43. c10::optional<std::string> schedule) {
  44. auto args = PyTorchArguments{.A = A,
  45. .B = B,
  46. .scales = scales,
  47. .zeros = zeros,
  48. .group_size = group_size,
  49. .C = C,
  50. .alpha = alpha,
  51. .beta = beta,
  52. .schedule = schedule};
  53. return scalar_type_dispatch(*btype, [&](auto BType) {
  54. return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
  55. A.scalar_type(), "machete_gemm", [&] {
  56. using ComputeType = equivalent_cutlass_type_t<scalar_t>;
  57. return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
  58. });
  59. });
  60. }
  61. torch::Tensor prepack_B(torch::Tensor const& B,
  62. ScalarTypeTorchPtr const& btype) {
  63. return scalar_type_dispatch(*btype, [&](auto BType) {
  64. return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
  65. });
  66. }
  67. }; // namespace machete