machete_mm_launcher.cuh

#pragma once

#include <torch/all.h>
#include <Python.h>

#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"

namespace machete {

// Arguments to a Machete GEMM as they arrive from the PyTorch binding.
struct PyTorchArguments {
  torch::Tensor const& A;                      // activations, (M, K)
  torch::Tensor const& B;                      // quantized weights; N is taken from size(1)
  c10::optional<torch::Tensor> const& scales;  // group-wise dequant scales
  c10::optional<torch::Tensor> const& zeros;   // group-wise zero points
  c10::optional<int64_t> group_size;
  c10::optional<torch::Tensor> const& C;       // optional source tensor for the epilogue
  c10::optional<double> alpha;                 // epilogue scale, defaults to 1
  c10::optional<double> beta;                  // epilogue scale for C, defaults to 0
  c10::optional<std::string> schedule;         // kernel schedule override
};
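
// Reading of the arguments, inferred from run_impl below rather than stated
// as a contract in this header: the launcher computes roughly
//   D = alpha * (A @ dequant(B, scales, zeros, group_size)) + beta * C
// with alpha defaulting to 1 and beta to 0 when not provided.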

template <typename MacheteKernel>
torch::Tensor run_impl(PyTorchArguments args) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));

  auto device = args.A.device();
  auto stream = at::cuda::getCurrentCUDAStream(device.index());

  using EleA = typename MacheteKernel::ElementA;
  using EleB = typename MacheteKernel::ElementB;
  using EleC = typename MacheteKernel::ElementC;
  using EleD = typename MacheteKernel::ElementD;
  using EleScale = typename MacheteKernel::ElementS;
  using EleZero = typename MacheteKernel::ElementZ;

  using StrideA = typename MacheteKernel::StrideA;
  using StrideC = typename MacheteKernel::StrideC;
  using StrideD = typename MacheteKernel::StrideD;
  using StrideS = typename MacheteKernel::StrideS;
  using StrideZ = typename MacheteKernel::StrideZ;

  // Problem shape, taken from the (M, K) activations and the weight tensor.
  int M = args.A.size(0);
  int N = args.B.size(1);
  int K = args.A.size(1);

  // Allocate the output in the torch dtype equivalent to the kernel's
  // element type for D.
  torch::Tensor D = torch::empty({M, N}, torch::TensorOptions()
                                             .dtype(equivalent_scalar_type_v<EleD>)
                                             .device(device));

  auto const& A = args.A, &B = args.B;
  auto const& C = args.C, &scales = args.scales, &zeros = args.zeros;

  // Build CuTe layouts; the maybe_ variants handle the optional tensors.
  auto layout_A = make_cute_layout<StrideA>(A, "A");
  auto layout_D = make_cute_layout<StrideD>(D, "D");
  auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
  auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
  auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");

  // Raw device pointers (nullptr for absent optional tensors).
  auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
  auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
  auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
  auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
  auto S_ptr =
      static_cast<EleScale const*>(scales ? scales->const_data_ptr() : nullptr);
  auto Z_ptr =
      static_cast<EleZero const*>(zeros ? zeros->const_data_ptr() : nullptr);

  auto arguments = MacheteKernel::create_arguments(
      stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
      layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
      args.group_size);

  TORCH_CHECK(MacheteKernel::can_implement(arguments),
              "Machete kernel cannot be run with these arguments");

  // Allocate any kernel scratch space through torch so it lands on the right
  // device and goes through the caching allocator.
  size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
  torch::Tensor workspace = torch::empty(
      workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device));

  MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream);

  return D;
}
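
// Illustrative use only: `MacheteKernelF16U4` stands in for a concrete
// instantiation of the kernel template from machete_mm_kernel.cuh; it is not
// a name declared in this header.
//
//   torch::Tensor D = run_impl<MacheteKernelF16U4>(PyTorchArguments{
//       A, B, scales, zeros, /*group_size=*/128, /*C=*/c10::nullopt,
//       /*alpha=*/c10::nullopt, /*beta=*/c10::nullopt,
//       /*schedule=*/c10::nullopt});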

// Per-type-combination dispatcher: dispatch() routes to a concrete kernel
// instantiation (defined out of line, not in this header), and
// supported_schedules() lists the schedule names accepted via
// PyTorchArguments::schedule.
template <typename ElementA, typename ElementB, typename ElementD = ElementA,
          typename AccumulatorT = float, typename ScaleT = ElementA,
          typename ZeroT = ElementA>
struct GemmDispatcher {
  static torch::Tensor dispatch(PyTorchArguments args);
  static std::vector<std::string> supported_schedules();
};

}  // namespace machete
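
// Sketch of how a host-side binding might use the dispatcher. The element
// types chosen here (cutlass::half_t activations, cutlass::uint4b_t weights)
// and the wrapper name `machete_mm` are assumptions for illustration, not
// declarations from this header:
//
//   torch::Tensor machete_mm(machete::PyTorchArguments args) {
//     return machete::GemmDispatcher<cutlass::half_t,
//                                    cutlass::uint4b_t>::dispatch(args);
//   }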