1
0

quant_ops.cpp 1.9 KB

1234567891011121314151617181920212223242526272829303132333435
  1. #include "quant_ops.h"
  2. #include <torch/extension.h>
  3. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  4. // Aphrodite quantization ops
  5. pybind11::module quant_ops = m.def_submodule("quant_ops", "Aphrodite custom quant operators");
  6. #ifndef USE_ROCM
  7. // AQLM
  8. quant_ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
  9. // AWQ
  10. quant_ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  11. quant_ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
  12. quant_ops.def("awq_group_gemm", &awq_group_gemm, "Grouped Quantized GEMM for AWQ");
  13. // GGUF
  14. quant_ops.def("ggml_dequantize", &ggml_dequantize, "ggml_dequantize");
  15. quant_ops.def("ggml_mul_mat_vec", &ggml_mul_mat_vec, "ggml_mul_mat_vec");
  16. quant_ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8, "ggml_mul_mat_vec_a8");
  17. quant_ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8, "ggml_mul_mat_a8");
  18. // Marlin
  19. quant_ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
  20. quant_ops.def("autoquant_convert_s4_k_m8", &autoquant_convert_s4_k_m8, "convert kernel.");
  21. quant_ops.def("autoquant_s4_f16_gemm", &autoquant_s4_f16_gemm, "weight int4 activation float16 gemm kernel.");
  22. quant_ops.def("quip_decompress", &decompress_e8p_origorder, "decompress_packed_e8p");
  23. quant_ops.def("quip_gemv", &e8p_mm_origorder, "e8p_mm_origorder");
  24. #endif
  25. quant_ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  26. quant_ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  27. quant_ops.def("group_gptq_gemm", &group_gptq_gemm, "Grouped Quantized GEMM for GPTQ");
  28. quant_ops.def("dequant_gptq", &dequant_gptq, "Dequantize gptq weight to half");
  29. quant_ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
  30. quant_ops.def("exl2_make_q_matrix",&make_q_matrix, "preprocess for exl2");
  31. quant_ops.def("exl2_gemm", &exl2_gemm, "exl2 gemm");
  32. quant_ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
  33. }