// autogptq_cuda_256.cpp (2.3 KB)
  1. #include <torch/all.h>
  2. #include <torch/python.h>
  3. #include <c10/cuda/CUDAGuard.h>
  4. void vecquant2matmul_cuda(
  5. torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  6. torch::Tensor scales, torch::Tensor zeros,
  7. torch::Tensor g_idx
  8. );
  9. void vecquant2matmul(
  10. torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  11. torch::Tensor scales, torch::Tensor zeros,
  12. torch::Tensor g_idx
  13. ) {
  14. const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  15. vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
  16. }
  17. void vecquant3matmul_cuda(
  18. torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  19. torch::Tensor scales, torch::Tensor zeros,
  20. torch::Tensor g_idx
  21. );
  22. void vecquant3matmul(
  23. torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  24. torch::Tensor scales, torch::Tensor zeros,
  25. torch::Tensor g_idx
  26. ) {
  27. const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  28. vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
  29. }
  30. void vecquant4matmul_cuda(
  31. torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  32. torch::Tensor scales, torch::Tensor zeros,
  33. torch::Tensor g_idx
  34. );
  35. void vecquant4matmul(
  36. torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  37. torch::Tensor scales, torch::Tensor zeros,
  38. torch::Tensor g_idx
  39. ) {
  40. const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  41. vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
  42. }
  43. void vecquant8matmul_cuda(
  44. torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  45. torch::Tensor scales, torch::Tensor zeros,
  46. torch::Tensor g_idx
  47. );
  48. void vecquant8matmul(
  49. torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  50. torch::Tensor scales, torch::Tensor zeros,
  51. torch::Tensor g_idx
  52. ) {
  53. const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  54. vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
  55. }
  56. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  57. m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
  58. m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
  59. m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
  60. m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
  61. }