// q4_matmul.cuh (861 B)
  1. // Adapted from turboderp exllama: https://github.com/turboderp/exllama
  2. #ifndef _q4_matmul_cuh
  3. #define _q4_matmul_cuh
  4. #include <cuda_runtime.h>
  5. #include <cuda_fp16.h>
  6. #include <cstdint>
  7. #include <cstdio>
  8. #include <ATen/cuda/CUDAContext.h>
  9. #include "q4_matrix.cuh"
  10. #include "../tuning.h"
  11. // Workaround for hipify_python using rocblas instead of hipblas.
  12. #if defined(USE_ROCM)
  13. #include <hipblas/hipblas.h>
  14. #define rocblas_handle hipblasHandle_t
  15. #endif
  // Host-side declaration of the quantized (Q4) matmul entry point. The
  // definition is not in this header (presumably in the matching q4_matmul.cu
  // — confirm).
  //
  //   tuningParams  ExLlamaTuning settings (type from "../tuning.h").
  //   x             fp16 input data; x_height is presumably the number of rows
  //                 in x — TODO confirm against the definition.
  //   w             Q4Matrix (declared in "q4_matrix.cuh"); the pointee is
  //                 const through this pointer.
  //   out           fp16 output buffer (non-const; presumably receives the
  //                 result).
  //   no_zero       Optional flag, defaults to false. Its effect is defined in
  //                 the implementation — NOTE(review): document it there.
  //   alt_stream    Optional CUDA stream, defaults to a null stream handle;
  //                 presumably selects where work is enqueued — confirm.
  void q4_matmul_cuda(ExLlamaTuning* tuningParams,
                      const half* x,
                      const int x_height,
                      const Q4Matrix* w,
                      half* out,
                      bool no_zero = false,
                      cudaStream_t alt_stream = nullptr);
  // Host-side declaration of the Q4 matmul variant that takes a cuBLAS handle.
  // The definition lives outside this header. The "recons" name suggests w is
  // reconstructed to fp16 and multiplied through cuBLAS — TODO confirm in the
  // implementation.
  //
  //   tuningParams  ExLlamaTuning settings (type from "../tuning.h").
  //   x, x_height   fp16 input and (presumably) its row count.
  //   w             Non-const Q4Matrix*, unlike the const pointer taken by
  //                 q4_matmul_cuda. The implementation may modify w — verify.
  //   out           fp16 output buffer.
  //   handle        cuBLAS handle. Under ROCm, hipify presumably rewrites this
  //                 type to rocblas_handle, which the USE_ROCM block near the
  //                 top of this header #defines to hipblasHandle_t.
  //   no_zero       Optional flag, defaults to false.
  // Unlike q4_matmul_cuda, this entry point takes no alt_stream argument.
  void q4_matmul_recons_cuda(ExLlamaTuning* tuningParams,
                             const half* x,
                             const int x_height,
                             Q4Matrix* w,
                             half* out,
                             const cublasHandle_t handle,
                             bool no_zero = false);
  36. #endif