// gguf_kernel.cu

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
#include "mmvq.cuh"
#include "mmq.cuh"

// Q8 gemv
// Quantize rows of fp16 activations to Q8_1: every 32-value block stores one
// int8 quant per element plus a per-block fp16 scale and sum (packed in `ds`).
static __global__ void quantize_q8_1(const half* __restrict__ x,
                                     void* __restrict__ vy, const int kx,
                                     const int kx_padded) {
  const int ix = blockDim.x * blockIdx.x + threadIdx.x;
  if (ix >= kx_padded) {
    return;
  }
  const int iy = blockDim.y * blockIdx.y + threadIdx.y;
  const int i_padded = iy * kx_padded + ix;

  block_q8_1* y = (block_q8_1*)vy;

  const int ib = i_padded / QK8_1;   // block index
  const int iqs = i_padded % QK8_1;  // quant index

  // Elements in the padding region (kx <= ix < kx_padded) quantize as zero.
  const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f;
  float amax = fabsf(xi);
  float sum = xi;

  // Butterfly reduction over the 32 lanes covering one block: every lane ends
  // up with the block's abs-max and the block's sum.
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
    sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
  }

  // Scale so that the block's abs-max maps to 127.
  const float d = amax / 127;
  const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);

  y[ib].qs[iqs] = q;

  if (iqs > 0) {
    return;
  }
  // The first lane of each block writes the per-block scale and sum.
  y[ib].ds.x = __float2half(d);
  y[ib].ds.y = __float2half(sum);
}
static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
                                   const int ky, cudaStream_t stream) {
  // Pad the row length up to a multiple of 512; the padded tail quantizes to
  // zero-filled blocks and matches the quant_X buffers allocated by the
  // callers below.
  const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
  const int block_num_x =
      (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
  const dim3 num_blocks(block_num_x, ky, 1);
  const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
  quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}
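
// Sizing sanity check, assuming block_q8_1 keeps the 36-byte layout from
// ggml-common.h (half2 ds + QK8_1 == 32 int8 quants): one block is 36 bytes,
// i.e. 9 int32 words, which is the arithmetic behind the `padded / 32 * 9`
// kInt32 allocations for quant_X below. The check is compile-time only.
static_assert(sizeof(block_q8_1) == 9 * sizeof(int32_t),
              "quant_X sizing below assumes a 36-byte block_q8_1");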
// Dequantize a GGUF-quantized weight of logical shape m x n back to fp16.
torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
                              int8_t type, int64_t m, int64_t n) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
  auto options =
      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
  at::Tensor DW = torch::empty({m, n}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
  to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream);
  return DW;
}
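
// Usage sketch with hypothetical shapes, assuming type 8 (GGML_TYPE_Q8_0,
// whose blocks pack 32 weights into 34 bytes, so n must be a multiple of 32):
//   auto W  = torch::empty({m, n / 32 * 34},
//                          torch::dtype(torch::kUInt8).device(torch::kCUDA));
//   // ... fill W with Q8_0 blocks ...
//   auto DW = ggml_dequantize(W, /*type=*/8, m, n);  // [m, n] fp16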
// GEMV of a quantized weight against an fp16 activation row: X is quantized to
// Q8_1 on the fly, then the kernel matching the weight's quantization type is
// dispatched. `type` follows the ggml_type enum (2 = Q4_0, 3 = Q4_1, 6 = Q5_0,
// 7 = Q5_1, 8 = Q8_0, 10-14 = Q2_K..Q6_K, 16-23 = IQ types).
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
                                  torch::Tensor X,  // input
                                  int8_t type, int64_t row) {
  int col = X.sizes()[1];
  const int padded = (col + 512 - 1) / 512 * 512;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
  auto options =
      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
  at::Tensor Y = torch::empty({1, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1,
                         stream);
  switch (type) {
    case 2:
      mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                                 (half*)Y.data_ptr(), col, row, stream);
      break;
    case 3:
      mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                                 (half*)Y.data_ptr(), col, row, stream);
      break;
    case 6:
      mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                                 (half*)Y.data_ptr(), col, row, stream);
      break;
    case 7:
      mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                                 (half*)Y.data_ptr(), col, row, stream);
      break;
    case 8:
      mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                                 (half*)Y.data_ptr(), col, row, stream);
      break;
    case 10:
      mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                                 (half*)Y.data_ptr(), col, row, stream);
      break;
    case 11:
      mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                                 (half*)Y.data_ptr(), col, row, stream);
      break;
    case 12:
      mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                                 (half*)Y.data_ptr(), col, row, stream);
      break;
    case 13:
      mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                                 (half*)Y.data_ptr(), col, row, stream);
      break;
    case 14:
      mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                                 (half*)Y.data_ptr(), col, row, stream);
      break;
    case 16:
      mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(),
                                    (void*)quant_X.data_ptr(),
                                    (half*)Y.data_ptr(), col, row, stream);
      break;
    case 17:
      mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(),
                                   (void*)quant_X.data_ptr(),
                                   (half*)Y.data_ptr(), col, row, stream);
      break;
    case 18:
      mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(),
                                    (void*)quant_X.data_ptr(),
                                    (half*)Y.data_ptr(), col, row, stream);
      break;
    case 19:
      mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(),
                                  (void*)quant_X.data_ptr(),
                                  (half*)Y.data_ptr(), col, row, stream);
      break;
    case 20:
      mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(),
                                   (void*)quant_X.data_ptr(),
                                   (half*)Y.data_ptr(), col, row, stream);
      break;
    case 21:
      mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(),
                                  (void*)quant_X.data_ptr(),
                                  (half*)Y.data_ptr(), col, row, stream);
      break;
    case 22:
      mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(),
                                  (void*)quant_X.data_ptr(),
                                  (half*)Y.data_ptr(), col, row, stream);
      break;
    case 23:
      mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(),
                                   (void*)quant_X.data_ptr(),
                                   (half*)Y.data_ptr(), col, row, stream);
      break;
  }
  return Y;
}
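
// Usage sketch with hypothetical shapes, assuming type 2 (GGML_TYPE_Q4_0,
// 18 bytes per 32-weight block; cols must be a multiple of 32):
//   auto W = torch::empty({rows, cols / 32 * 18},
//                         torch::dtype(torch::kUInt8).device(torch::kCUDA));
//   auto X = torch::randn({1, cols},
//                         torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   auto Y = ggml_mul_mat_vec_a8(W, X, /*type=*/2, /*row=*/rows);  // [1, rows]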
// Batched quantized matmul: X is [batch, col] fp16, quantized to Q8_1 row by
// row, then multiplied against the quantized weight with the MMQ kernels
// (legacy and K-quant types only; no IQ types here).
torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
                              torch::Tensor X,  // input
                              int8_t type, int64_t row) {
  int col = X.sizes()[1];
  int padded = (col + 512 - 1) / 512 * 512;
  int batch = X.sizes()[0];
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
  auto options =
      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
  at::Tensor Y = torch::empty({batch, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col,
                         batch, stream);
  switch (type) {
    case 2:
      ggml_mul_mat_q4_0_q8_1_cuda(
          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
          col, row, batch, padded, row, stream);
      break;
    case 3:
      ggml_mul_mat_q4_1_q8_1_cuda(
          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
          col, row, batch, padded, row, stream);
      break;
    case 6:
      ggml_mul_mat_q5_0_q8_1_cuda(
          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
          col, row, batch, padded, row, stream);
      break;
    case 7:
      ggml_mul_mat_q5_1_q8_1_cuda(
          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
          col, row, batch, padded, row, stream);
      break;
    case 8:
      ggml_mul_mat_q8_0_q8_1_cuda(
          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
          col, row, batch, padded, row, stream);
      break;
    case 10:
      ggml_mul_mat_q2_K_q8_1_cuda(
          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
          col, row, batch, padded, row, stream);
      break;
    case 11:
      ggml_mul_mat_q3_K_q8_1_cuda(
          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
          col, row, batch, padded, row, stream);
      break;
    case 12:
      ggml_mul_mat_q4_K_q8_1_cuda(
          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
          col, row, batch, padded, row, stream);
      break;
    case 13:
      ggml_mul_mat_q5_K_q8_1_cuda(
          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
          col, row, batch, padded, row, stream);
      break;
    case 14:
      ggml_mul_mat_q6_K_q8_1_cuda(
          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
          col, row, batch, padded, row, stream);
      break;
  }
  return Y;
}
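
// Usage sketch with hypothetical shapes, assuming type 12 (GGML_TYPE_Q4_K,
// whose super-blocks pack 256 weights into 144 bytes; cols must be a multiple
// of 256):
//   auto W = torch::empty({rows, cols / 256 * 144},
//                         torch::dtype(torch::kUInt8).device(torch::kCUDA));
//   auto X = torch::randn({batch, cols},
//                         torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   auto Y = ggml_mul_mat_a8(W, X, /*type=*/12, /*row=*/rows);  // [batch, rows]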