```cpp
#include <cstdint>

#include <torch/extension.h>

// AWQ quantized GEMM: fp16 activations times packed 4-bit weights,
// dequantized per group with scaling factors and zero points.
torch::Tensor awq_gemm(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int split_k_iters);

// Builds a quantized-matrix object from the GPTQ tensors and returns an
// opaque handle (pointer), consumed by gemm_half_q_half as its `b` argument.
uintptr_t make_q_matrix(
    torch::Tensor q_weight,
    torch::Tensor q_perm,
    torch::Tensor q_invperm,
    torch::Tensor gptq_qzeros,
    torch::Tensor gptq_scales,
    torch::Tensor gptq_g_idx,
    torch::Tensor temp_dq);

// Half-precision GEMM against a quantized matrix: writes a @ b into c,
// where b is a handle produced by make_q_matrix.
void gemm_half_q_half(
    torch::Tensor a,
    uintptr_t b,
    torch::Tensor c,
    bool force_cuda);

// GPTQ quantized matmul for the parallelized desc_act layer.
void gptq_descact_matmul(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros,
    torch::Tensor g_idx);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
  m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
  m.def(
      "gptq_descact_matmul",
      &gptq_descact_matmul,
      "Quantized GEMM for GPTQ for parallelized desc_act layer");
}
```
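For orientation, here is a minimal sketch of how such an extension might be driven from Python once compiled. The module name `quant_cuda`, the source file list, the tensor shapes, and the `split_k_iters` value are all illustrative assumptions, not part of the bindings above; only `torch.utils.cpp_extension.load` and the bound function name come from PyTorch and the listing itself.

```python
# Hypothetical usage sketch: JIT-compile the extension and call the AWQ GEMM.
# Module name, source files, shapes, and packing details are assumptions.
import torch
from torch.utils.cpp_extension import load

quant_cuda = load(
    name="quant_cuda",
    sources=["ext.cpp", "awq_gemm_kernel.cu"],  # hypothetical source list
)

M, K, N, group_size = 8, 4096, 4096, 128  # illustrative problem size
in_feats = torch.randn(M, K, dtype=torch.float16, device="cuda")
# AWQ-style kernels pack eight 4-bit values per int32; the exact layout of
# the packed weights and zero points is defined by the CUDA kernels.
kernel = torch.randint(0, 2**31, (K, N // 8), dtype=torch.int32, device="cuda")
scaling_factors = torch.randn(K // group_size, N, dtype=torch.float16, device="cuda")
zeros = torch.randint(0, 2**31, (K // group_size, N // 8), dtype=torch.int32, device="cuda")

out = quant_cuda.awq_gemm(in_feats, kernel, scaling_factors, zeros, 8)
print(out.shape)  # expected (M, N) fp16 output under these assumptions
```

A setuptools build with `torch.utils.cpp_extension.CUDAExtension` would work equally well; `load` is just the quickest way to exercise the bindings during development.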