// quantization.cpp

  1. #include <cstdint>
  2. #include <torch/extension.h>
  3. torch::Tensor awq_gemm(
  4. torch::Tensor _in_feats,
  5. torch::Tensor _kernel,
  6. torch::Tensor _scaling_factors,
  7. torch::Tensor _zeros,
  8. int split_k_iters);
  9. uintptr_t make_q_matrix(
  10. torch::Tensor q_weight,
  11. torch::Tensor q_perm,
  12. torch::Tensor q_invperm,
  13. torch::Tensor gptq_qzeros,
  14. torch::Tensor gptq_scales,
  15. torch::Tensor gptq_g_idx,
  16. torch::Tensor temp_dq
  17. );
  18. void gemm_half_q_half(
  19. torch::Tensor a,
  20. uintptr_t b,
  21. torch::Tensor c,
  22. bool force_cuda
  23. );
  24. void gptq_descact_matmul(
  25. torch::Tensor vec,
  26. torch::Tensor mat,
  27. torch::Tensor mul,
  28. torch::Tensor scales,
  29. torch::Tensor zeros,
  30. torch::Tensor g_idx);
  31. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  32. m.def(
  33. "awq_gemm",
  34. &awq_gemm,
  35. "Quantized GEMM for AWQ");
  36. m.def(
  37. "make_q_matrix",
  38. &make_q_matrix,
  39. "make_q_matrix");
  40. m.def(
  41. "gemm_half_q_half",
  42. &gemm_half_q_half,
  43. "gemm_half_q_half");
  44. m.def(
  45. "gptq_descact_matmul",
  46. &gptq_descact_matmul,
  47. "Quantized GEMM for GPTQ for parallelized desc_act layer");
  48. }