/*
 * Adapted from https://github.com/turboderp/exllamav2
 * Copyright (c) 2024 turboderp
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

#include "q_matrix.cuh"
#include "matrix_view.cuh"

#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"

namespace aphrodite {
namespace exl2 {

#define BLOCK_KN_SIZE 128
#define THREADS_X 32
#define THREADS_Y 32

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

// Shuffle quantized data on load
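//
// One thread per column of b_q_weight. rows_8 .. rows_2 are cumulative k-row
// boundaries set up in the QMatrix constructor, so each loop below shuffles one
// bit-width region of the column in place, advancing by the number of packed
// uint32 rows each shuffle call consumes.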

__global__ void shuffle_kernel
(
    uint32_t* __restrict__ b_q_weight,
    const int size_k,
    const int size_n,
    const int rows_8,
    const int rows_6,
    const int rows_5,
    const int rows_4,
    const int rows_3,
    const int rows_2
)
{
    int n = blockIdx.x * THREADS_X + threadIdx.x;
    if (n >= size_n) return;
    int k = 0;
    uint32_t* b_ptr = b_q_weight + n;

    while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
    while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
    while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
    while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
    while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
    while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}

// QMatrix constructor

QMatrix::QMatrix
(
    const int _device,
    const int _height,
    const int _width,
    const int _groups,
    uint32_t* _q_weight,
    uint16_t* _q_perm,
    uint16_t* _q_invperm,
    uint32_t* _q_scale,
    half* _q_scale_max,
    uint16_t* _q_groups,
    uint16_t* _q_group_map
):
    device(_device),
    height(_height),
    width(_width),
    groups(_groups)
{
    cudaSetDevice(device);

    failed = false;

    cuda_q_weight = _q_weight;
    cuda_q_perm = _q_perm;
    cuda_q_invperm = _q_invperm;
    cuda_q_scale = _q_scale;
    cuda_q_scale_max = _q_scale_max;
    cuda_q_groups = _q_groups;
    cuda_q_group_map = _q_group_map;

    // Create group map

    rows_8 = 0;
    rows_6 = 0;
    rows_5 = 0;
    rows_4 = 0;
    rows_3 = 0;
    rows_2 = 0;

    {
        uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
        cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);

        int row = 0;
        for (int i = 0; i < groups; i++)
        {
            int bits = cpu_q_groups[i * 2];

            int rows;
            if (i < groups - 1)
            {
                int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
                rows = qrows * 32 / bits;
            }
            else rows = height - row;

            if (bits == 8) rows_8 += rows;
            if (bits == 6) rows_6 += rows;
            if (bits == 5) rows_5 += rows;
            if (bits == 4) rows_4 += rows;
            if (bits == 3) rows_3 += rows;
            if (bits == 2) rows_2 += rows;
            row += rows;
        }

        free(cpu_q_groups);
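
        // Fold the per-width row counts into cumulative k-row boundaries:
        // the 8-bit rows come first, followed by the 6-, 5-, 4-, 3- and
        // 2-bit rows, matching the order walked by the kernels above/below.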
        rows_6 += rows_8;
        rows_5 += rows_6;
        rows_4 += rows_5;
        rows_3 += rows_4;
        rows_2 += rows_3;
    }

    // Shuffle quantized data

    dim3 blockDim, gridDim;
    blockDim.x = THREADS_X;
    blockDim.y = 1;
    gridDim.x = DIVIDE(width, THREADS_X);
    gridDim.y = 1;

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(
        cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}

QMatrix::~QMatrix()
{
}

// Reconstruct b[k,n]
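//
// Dequantizes the packed weights back into a half matrix b[size_k, size_n];
// each stored row k is written to output row b_q_perm[k].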

__global__ void reconstruct_kernel
(
    const uint32_t* __restrict__ b_q_weight,
    const uint16_t* __restrict__ b_q_perm,
    const uint32_t* __restrict__ b_q_scale,
    const half* __restrict__ b_q_scale_max,
    const uint16_t* __restrict__ b_q_group_map,
    const int size_k,
    const int size_n,
    //const int groupsize,
    const int groups,
    half* __restrict__ b,
    const int rows_8,
    const int rows_6,
    const int rows_5,
    const int rows_4,
    const int rows_3,
    const int rows_2
)
{
    MatrixView_half_rw b_(b, size_k, size_n);
    MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);

    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
    int offset_n = BLOCK_KN_SIZE * blockIdx.x;

    // Preload remapping table

    int t = threadIdx.x;
    __shared__ uint16_t perm[BLOCK_KN_SIZE];
    if (offset_k + t < size_k)
        perm[t] = b_q_perm[offset_k + t];

    // Column

    int n = offset_n + t;
    if (n >= size_n) return;

    // Find initial group
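    // b_q_group_map holds one (group index, rows until next group boundary)
    // pair per k-row.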
    // int group = offset_k / groupsize;
    int group = b_q_group_map[offset_k * 2];

    int pre_rows_8 = min(rows_8, offset_k);
    int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
    int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
    int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
    int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
    int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
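
    // Convert the k-row offset into the corresponding packed uint32 row:
    // every 32 rows of b-bit weights occupy b uint32 rows per column.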
    int qk = 0;
    qk += pre_rows_8 / 32 * 8;
    qk += pre_rows_6 / 32 * 6;
    qk += pre_rows_5 / 32 * 5;
    qk += pre_rows_4 / 32 * 4;
    qk += pre_rows_3 / 32 * 3;
    qk += pre_rows_2 / 32 * 2;

    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;

    half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
    half2 qs_h2 = __halves2half2(qs_h, qs_h);
    int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
    int k = offset_k;
    int lk = 0;

    __syncthreads();
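
    // Walk down this column, dequantizing the 8-, 6-, 5-, 4-, 3- and 2-bit
    // regions in turn; the group scale is refreshed at each group boundary.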
    while (k < rows_8 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 4; p++)
        {
            half2 dq[4];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            uint32_t q_1 = *b_ptr; b_ptr += size_n;
            dequant_8bit_8(q_0, q_1, dq, size_n);
            for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 32;
    }

    while (k < rows_6 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 2; p++)
        {
            half2 dq[8];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            uint32_t q_1 = *b_ptr; b_ptr += size_n;
            uint32_t q_2 = *b_ptr; b_ptr += size_n;
            dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
            for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 32;
    }

    while (k < rows_5 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 1; p++)
        {
            half2 dq[16];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            uint32_t q_1 = *b_ptr; b_ptr += size_n;
            uint32_t q_2 = *b_ptr; b_ptr += size_n;
            uint32_t q_3 = *b_ptr; b_ptr += size_n;
            uint32_t q_4 = *b_ptr; b_ptr += size_n;
            dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
            for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 32;
    }

    while (k < rows_4 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 4; p++)
        {
            half2 dq[4];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            dequant_4bit_8(q_0, dq, size_n);
            for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 32;
    }

    while (k < rows_3 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 1; p++)
        {
            half2 dq[16];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            uint32_t q_1 = *b_ptr; b_ptr += size_n;
            uint32_t q_2 = *b_ptr; b_ptr += size_n;
            dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
            for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 32;
    }

    while (k < rows_2 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 1; p++)
        {
            half2 dq[8];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            dequant_2bit_16(q_0, dq, size_n);
            for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 16;
    }
}
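
// Launch reconstruct_kernel over BLOCK_KN_SIZE x BLOCK_KN_SIZE tiles of the
// output matrix, one thread per column within each tile.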
void QMatrix::reconstruct(half* out)
{
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);

    {
        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
        (
            cuda_q_weight,
            cuda_q_perm,
            cuda_q_scale,
            cuda_q_scale_max,
            cuda_q_group_map,
            height,
            width,
            //groupsize,
            groups,
            out,
            rows_8,
            rows_6,
            rows_5,
            rows_4,
            rows_3,
            rows_2
        );
    }
}

} // namespace exl2
} // namespace aphrodite