exllama_ext.cpp

// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "cuda_func/q4_matrix.cuh"
#include "cuda_func/q4_matmul.cuh"
#include "cuda_func/column_remap.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.

void check_cuda(cudaError_t ret)
{
    switch (ret)
    {
        case cudaSuccess:
            break;

        case cudaUnspecified:  // macro from util.cuh, not a standard cudaError_t value
            printf(" **** Unspecified error\n");
            TORCH_CHECK(false, "CUDA error");
            break;

        default:
            printf(" **** CUDA error\n");
            printf(" **** %s\n", cudaGetErrorString(ret));
            TORCH_CHECK(false, "CUDA error");
            break;
    }
}
// Some decluttering macros

#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)

#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")

#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
    TORCH_CHECK(__index >= 0, "no device index"); \
    TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)

#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
    TORCH_CHECK_DTYPE(__w, kInt); \
    TORCH_CHECK_DTYPE(__w_scales, kHalf); \
    TORCH_CHECK_DTYPE(__w_zeros, kInt); \
    TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
    TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
    TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
    TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
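// Illustration (not part of the original file): a call like
//
//     TORCH_CHECK_DTYPE(qweight, kInt);
//
// expands, roughly, to
//
//     TORCH_CHECK(qweight.dtype() == torch::kInt, "qweight is incorrect datatype, must be kInt");
//
// The _OPT variants additionally accept a tensor on the "meta" device, which is how an optional
// argument (g_idx, seq_g_idx, x_map) is marked as absent by callers of this extension.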
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
    int groupsize = w.size(0) * 8 / w_zeros.size(0);
    TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]");
    return groupsize;
}
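// Illustration (not part of the original file): each int32 element of w packs eight 4-bit weights
// along dim 0, so the unpacked height is w.size(0) * 8. For a 4096-channel layer quantized with
// group size 128, w has shape [512, out_features] and w_zeros has 32 groups, so get_groupsize
// returns 512 * 8 / 32 = 128.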
// Tuning parameters

ExLlamaTuning tuningParams;

void gptq_set_tuning_params
(
    int matmul_recons_thd,
    bool matmul_fused_remap,
    bool matmul_no_half2
)
{
    tuningParams.matmul_recons_thd = matmul_recons_thd;
    tuningParams.matmul_fused_remap = matmul_fused_remap;
    tuningParams.matmul_no_half2 = matmul_no_half2;
}
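// matmul_recons_thd is the batch-size threshold used by gptq_q4_matmul below: batches smaller than
// the threshold use the fused quantized kernel, larger ones reconstruct the weights and go through
// cuBLAS. The other two flags are only consumed by the CUDA kernels, which receive them through
// tuningParams.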
// Release all unmanaged objects allocated by the extension

void gptq_cleanup()
{
    cleanup_buffers_cuda();
    g_q4_free_matrices();
}
// Prepare buffers for forward pass

void gptq_prepare_buffers
(
    torch::Device device,
    torch::Tensor temp_state,
    torch::Tensor temp_dq
)
{
    int device_index = device.index();
    TORCH_CHECK_DEVICE_INDEX(device_index);
    const at::cuda::OptionalCUDAGuard device_guard(device);

    prepare_buffers_cuda
    (
        device_index,
        // buffer size used for sanity checks
        temp_state.numel(),
        (half*) temp_state.data_ptr(),
        (half*) temp_dq.data_ptr()
    );
}
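// Illustrative call sequence (assumed sizes and names, not part of the original file):
//
//     auto opts = torch::TensorOptions().dtype(torch::kHalf).device(device);
//     torch::Tensor temp_state = torch::zeros({max_tokens * hidden_dim}, opts);  // activation scratch (assumed purpose)
//     torch::Tensor temp_dq    = torch::zeros({max_dq_elements}, opts);          // dequantized-weight scratch (assumed purpose)
//     gptq_prepare_buffers(device, temp_state, temp_dq);
//
// max_tokens, hidden_dim and max_dq_elements are placeholders; the sizes actually required depend
// on the kernels behind q4_matmul.cuh and cuda_buffers.cuh.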
// Create Q4Matrix, return handle

uintptr_t gptq_make_q4
(
    torch::Tensor qweight,
    torch::Tensor qzeros,
    torch::Tensor scales,
    torch::Tensor g_idx,
    int device
)
{
    TORCH_CHECK_DTYPE(qweight, kInt);
    TORCH_CHECK_DTYPE(qzeros, kInt);
    TORCH_CHECK_DTYPE(scales, kHalf);
    TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
    TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
    TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
    TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);

    int width = qweight.size(1);
    int height = qweight.size(0) * 8;
    int groups = qzeros.size(0);

    Q4Matrix* m = new Q4Matrix
    (
        height,
        width,
        groups,
        (uint32_t*) qweight.data_ptr(),
        (uint32_t*) qzeros.data_ptr(),
        (half*) scales.data_ptr(),
        g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
        device
    );

    g_q4_keep_matrix(m);
    return reinterpret_cast<uintptr_t> (m);
}
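// Shapes implied by the checks above (illustration, not part of the original file), for a layer
// with in_features inputs, out_features outputs and `groups` quantization groups:
//
//     qweight : [in_features / 8, out_features]   int32  (eight 4-bit weights per element along dim 0)
//     qzeros  : [groups, out_features / 8]        int32
//     scales  : [groups, out_features]            half
//     g_idx   : int32 group index per input channel, or a meta tensor when unused (shape not checked here)
//
// The returned handle is the Q4Matrix pointer cast to an integer; it is owned by the extension and
// released in gptq_cleanup() via g_q4_free_matrices().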
// Matmul half @ quant -> half

void gptq_q4_matmul
(
    torch::Tensor x,
    uintptr_t w,
    torch::Tensor out
)
{
    Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);

    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);
    TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
    TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes");

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    int x_height = x.size(0);
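    // Dispatch: a zero threshold, or a batch smaller than matmul_recons_thd, takes the fused kernel
    // that reads the quantized weights directly; otherwise the weights are reconstructed to fp16 and
    // the product is computed through cuBLAS (q4_matmul_recons_cuda).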
    if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
    {
        q4_matmul_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr()
        );
    }
    else
    {
        q4_matmul_recons_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr(),
            at::cuda::getCurrentCUDABlasHandle()
        );
    }
}
// Remap columns in half tensor

void gptq_column_remap
(
    torch::Tensor x,
    torch::Tensor x_new,
    torch::Tensor x_map
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(x_new, kHalf);
    TORCH_CHECK_DTYPE(x_map, kInt);
    TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);

    int height = x.size(0);
    int width = x.size(1);

    TORCH_CHECK_BUFFER_SIZE(x_new, height * width);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    column_remap_cuda
    (
        (half*) x.data_ptr(),
        (half*) x_new.data_ptr(),
        height,
        width,
        (uint32_t*) x_map.data_ptr()
    );
}
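// Illustrative usage (not part of the original file): x_new must be preallocated with at least
// height * width half elements, e.g.
//
//     torch::Tensor x_new = torch::empty_like(x);
//     gptq_column_remap(x, x_new, x_map);
//
// x_map is expected to be the column permutation built for act-order (g_idx) models; the exact
// mapping is implemented in cuda_func/column_remap.cuh.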