// fpA_intB_gemm_wrapper.h
  1. #include <torch/extension.h>
  2. #include <vector>
  3. #define SMALL_M_FAST_PATH 4
  4. torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor &input,
  5. torch::Tensor &weight,
  6. torch::Tensor &scale);
  7. torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor &input,
  8. torch::Tensor &weight,
  9. torch::Tensor &scale,
  10. torch::Tensor &output,
  11. const int m,
  12. const int n,
  13. const int k);