// quant_ops.cpp — Aphrodite custom quantization operator bindings
  1. #include "quant_ops.h"
  2. #include <torch/extension.h>
  3. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  4. // Aphrodite quantization ops
  5. pybind11::module quant_ops =
  6. m.def_submodule("quant_ops", "Aphrodite custom quant operators");
  7. #ifndef USE_ROCM
  8. // AQLM
  9. quant_ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
  10. quant_ops.def("aqlm_dequant", &aqlm_dequant, "Dequantization for AQLM");
  11. // AWQ
  12. quant_ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  13. quant_ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
  14. quant_ops.def("awq_group_gemm", &awq_group_gemm,
  15. "Grouped Quantized GEMM for AWQ");
  16. // GGUF
  17. quant_ops.def("ggml_dequantize", &ggml_dequantize, "ggml_dequantize");
  18. quant_ops.def("ggml_mul_mat_vec", &ggml_mul_mat_vec, "ggml_mul_mat_vec");
  19. quant_ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8,
  20. "ggml_mul_mat_vec_a8");
  21. quant_ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8, "ggml_mul_mat_a8");
  22. // Marlin
  23. quant_ops.def("marlin_gemm", &marlin_gemm,
  24. "Marlin Optimized Quantized GEMM for GPTQ");
  25. quant_ops.def("marlin_gemm", &marlin_gemm,
  26. "Marlin (Dense) Optimized Quantized GEMM for GPTQ");
  27. quant_ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
  28. "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
  29. quant_ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
  30. "gptq_marlin Optimized Quantized GEMM for GPTQ");
  31. quant_ops.def("gptq_marlin_repack", &gptq_marlin_repack,
  32. "gptq_marlin repack from GPTQ");
  33. // SmoothQuant+
  34. quant_ops.def("autoquant_convert_s4_k_m8", &autoquant_convert_s4_k_m8,
  35. "convert kernel.");
  36. quant_ops.def("autoquant_s4_f16_gemm", &autoquant_s4_f16_gemm,
  37. "weight int4 activation float16 gemm kernel.");
  38. // QuIP#
  39. quant_ops.def("quip_decompress", &decompress_e8p_origorder,
  40. "decompress_packed_e8p");
  41. quant_ops.def("quip_gemv", &e8p_mm_origorder, "e8p_mm_origorder");
  42. // CUTLASS w8a8
  43. quant_ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
  44. "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
  45. "per-row/column quantization.");
  46. #endif
  47. // GPTQ
  48. quant_ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  49. quant_ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  50. quant_ops.def("group_gptq_gemm", &group_gptq_gemm,
  51. "Grouped Quantized GEMM for GPTQ");
  52. quant_ops.def("dequant_gptq", &dequant_gptq,
  53. "Dequantize gptq weight to half");
  54. // SqueezeLLM
  55. quant_ops.def("squeezellm_gemm", &squeezellm_gemm,
  56. "Quantized GEMM for SqueezeLLM");
  57. // INT8
  58. quant_ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
  59. "Compute int8 quantized tensor for given scaling factor");
  60. // ExLlamaV2
  61. quant_ops.def("exl2_make_q_matrix", &make_q_matrix, "preprocess for exl2");
  62. quant_ops.def("exl2_gemm", &exl2_gemm, "exl2 gemm");
  63. // FP8
  64. quant_ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
  65. "Compute FP8 quantized tensor for given scaling factor");
  66. quant_ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
  67. "Compute FP8 quantized tensor and scaling factor");
  68. }