// quantization.cpp — Python bindings for AWQ / GPTQ quantized-GEMM kernels.
  1. #include <cstdint>
  2. #include <torch/extension.h>
  3. torch::Tensor awq_gemm(
  4. torch::Tensor _in_feats,
  5. torch::Tensor _kernel,
  6. torch::Tensor _scaling_factors,
  7. torch::Tensor _zeros,
  8. int split_k_iters);
  9. void gptq_set_tuning_params(
  10. int matmul_recons_thd,
  11. bool matmul_fused_remap,
  12. bool matmul_no_half2);
  13. void gptq_prepare_buffers(
  14. torch::Device device,
  15. torch::Tensor temp_state,
  16. torch::Tensor temp_dq);
  17. uintptr_t gptq_make_q4(
  18. torch::Tensor qweight,
  19. torch::Tensor qzeros,
  20. torch::Tensor scales,
  21. torch::Tensor g_idx,
  22. int device);
  23. void gptq_q4_matmul(
  24. torch::Tensor x,
  25. uintptr_t w,
  26. torch::Tensor out);
  27. void gptq_descact_matmul(
  28. torch::Tensor vec,
  29. torch::Tensor mat,
  30. torch::Tensor mul,
  31. torch::Tensor scales,
  32. torch::Tensor zeros,
  33. torch::Tensor g_idx);
  34. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  35. m.def(
  36. "awq_gemm",
  37. &awq_gemm,
  38. "Quantized GEMM for AWQ");
  39. m.def(
  40. "gptq_set_tuning_params",
  41. &gptq_set_tuning_params,
  42. "Set tuning params for GPTQ");
  43. m.def(
  44. "gptq_prepare_buffers",
  45. &gptq_prepare_buffers,
  46. "Prepare buffers for GPTQ.");
  47. m.def(
  48. "gptq_make_q4",
  49. &gptq_make_q4,
  50. "Preprocess weight for GPTQ.");
  51. m.def(
  52. "gptq_q4_matmul",
  53. &gptq_q4_matmul,
  54. "Quantized GEMM for GPTQ.");
  55. m.def(
  56. "gptq_descact_matmul",
  57. &gptq_descact_matmul,
  58. "Quantized GEMM for GPTQ for parallelized desc_act layer.");
  59. }