// alt_matmul.cpp
  1. #include <torch/all.h>
  2. #include <torch/python.h>
  3. #include <c10/cuda/CUDAGuard.h>
  4. void vecquant4matmul_cuda(
  5. torch::Tensor vec,
  6. torch::Tensor mat,
  7. torch::Tensor mul,
  8. torch::Tensor scales,
  9. torch::Tensor zeros,
  10. torch::Tensor g_idx
  11. );
  12. void gptq_descact_matmul(
  13. torch::Tensor vec,
  14. torch::Tensor mat,
  15. torch::Tensor mul,
  16. torch::Tensor scales,
  17. torch::Tensor zeros,
  18. torch::Tensor g_idx
  19. )
  20. {
  21. const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  22. vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
  23. }