/*
 * Adapted from https://github.com/turboderp/exllamav2
 * Copyright (c) 2024 turboderp
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

#include "q_matrix.cuh"
#include "matrix_view.cuh"

#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"

namespace aphrodite {
namespace exl2 {

#define BLOCK_KN_SIZE 128

#define THREADS_X 32
#define THREADS_Y 32

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
// Shuffle quantized data on load
__global__ void shuffle_kernel(uint32_t* __restrict__ b_q_weight,
                               const int size_k, const int size_n,
                               const int rows_8, const int rows_6,
                               const int rows_5, const int rows_4,
                               const int rows_3, const int rows_2) {
  // One thread per output column; each thread re-packs its whole column.
  int n = blockIdx.x * THREADS_X + threadIdx.x;
  if (n >= size_n) return;
  int k = 0;
  uint32_t* b_ptr = b_q_weight + n;
  // rows_8 .. rows_2 are cumulative row boundaries, so the loops below walk
  // the bit-width regions of the column back to back.
  while (k < rows_8) {
    shuffle_8bit_4(b_ptr, size_n);
    b_ptr += 1 * size_n;
    k += 4;
  }
  while (k < rows_6) {
    shuffle_6bit_16(b_ptr, size_n);
    b_ptr += 3 * size_n;
    k += 16;
  }
  while (k < rows_5) {
    shuffle_5bit_32(b_ptr, size_n);
    b_ptr += 5 * size_n;
    k += 32;
  }
  while (k < rows_4) {
    shuffle_4bit_8(b_ptr, size_n);
    b_ptr += 1 * size_n;
    k += 8;
  }
  while (k < rows_3) {
    shuffle_3bit_32(b_ptr, size_n);
    b_ptr += 3 * size_n;
    k += 32;
  }
  while (k < rows_2) {
    shuffle_2bit_16(b_ptr, size_n);
    b_ptr += 1 * size_n;
    k += 16;
  }
}
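
// A worked reading of the stride arithmetic above (commentary only, not part
// of the original source): each pass over one column consumes `values` rows of
// the logical weight and `words` rows of packed uint32 storage, and in every
// case bits * values == 32 * words.
//
//   bits   values/pass   uint32 words/pass
//    8          4               1
//    6         16               3
//    5         32               5
//    4          8               1
//    3         32               3
//    2         16               1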
// QMatrix constructor
QMatrix::QMatrix(const int _device, const int _height, const int _width,
                 const int _groups,
                 uint32_t* _q_weight, uint16_t* _q_perm, uint16_t* _q_invperm,
                 uint32_t* _q_scale, half* _q_scale_max, uint16_t* _q_groups,
                 uint16_t* _q_group_map)
    : device(_device), height(_height), width(_width), groups(_groups) {
  cudaSetDevice(device);

  failed = false;

  cuda_q_weight = _q_weight;
  cuda_q_perm = _q_perm;
  cuda_q_invperm = _q_invperm;
  cuda_q_scale = _q_scale;
  cuda_q_scale_max = _q_scale_max;
  cuda_q_groups = _q_groups;
  cuda_q_group_map = _q_group_map;

  // Create group map
  rows_8 = 0;
  rows_6 = 0;
  rows_5 = 0;
  rows_4 = 0;
  rows_3 = 0;
  rows_2 = 0;

  {
    // Each q_groups entry is a (bits, q-row offset) pair; scan it on the host
    // to count how many weight rows are stored at each bit width.
    uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
    cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t),
               cudaMemcpyDeviceToHost);

    int row = 0;
    for (int i = 0; i < groups; i++) {
      int bits = cpu_q_groups[i * 2];

      int rows;
      if (i < groups - 1) {
        int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
        rows = qrows * 32 / bits;
      } else
        rows = height - row;

      if (bits == 8) rows_8 += rows;
      if (bits == 6) rows_6 += rows;
      if (bits == 5) rows_5 += rows;
      if (bits == 4) rows_4 += rows;
      if (bits == 3) rows_3 += rows;
      if (bits == 2) rows_2 += rows;
      row += rows;
    }

    free(cpu_q_groups);

    // Convert per-bit-width row counts into cumulative row boundaries.
    rows_6 += rows_8;
    rows_5 += rows_6;
    rows_4 += rows_5;
    rows_3 += rows_4;
    rows_2 += rows_3;
  }
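
  // Example (assumed numbers, for illustration only): a matrix with 64 8-bit
  // rows followed by 192 4-bit rows ends up with rows_8 = 64,
  // rows_6 = rows_5 = 64 and rows_4 = rows_3 = rows_2 = 256, so the kernels'
  // `while (k < rows_X)` loops cover disjoint, contiguous k ranges.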
  // Shuffle quantized data
  dim3 blockDim, gridDim;
  blockDim.x = THREADS_X;
  blockDim.y = 1;
  gridDim.x = DIVIDE(width, THREADS_X);
  gridDim.y = 1;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width,
                                                   rows_8, rows_6, rows_5,
                                                   rows_4, rows_3, rows_2);
}
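
// Illustrative construction from the host side (a sketch, not part of the
// original source; the tensor names and shapes below are assumptions about
// how the caller lays out the EXL2 tensors):
//
//   QMatrix qm(q_weight.device().index(),
//              q_invperm.size(0),                    // height
//              q_weight.size(1),                     // width
//              q_scale.size(0),                      // groups
//              (uint32_t*)q_weight.data_ptr(),
//              (uint16_t*)q_perm.data_ptr(),
//              (uint16_t*)q_invperm.data_ptr(),
//              (uint32_t*)q_scale.data_ptr(),
//              (half*)q_scale_max.data_ptr(),
//              (uint16_t*)q_groups.data_ptr(),
//              (uint16_t*)q_group_map.data_ptr());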
QMatrix::~QMatrix() {}

// Reconstruct b[k,n]
__global__ void reconstruct_kernel(const uint32_t* __restrict__ b_q_weight,
                                   const uint16_t* __restrict__ b_q_perm,
                                   const uint32_t* __restrict__ b_q_scale,
                                   const half* __restrict__ b_q_scale_max,
                                   const uint16_t* __restrict__ b_q_group_map,
                                   const int size_k, const int size_n,
                                   // const int groupsize,
                                   const int groups, half* __restrict__ b,
                                   const int rows_8, const int rows_6,
                                   const int rows_5, const int rows_4,
                                   const int rows_3, const int rows_2) {
  MatrixView_half_rw b_(b, size_k, size_n);
  MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);

  int offset_k = BLOCK_KN_SIZE * blockIdx.y;
  int offset_n = BLOCK_KN_SIZE * blockIdx.x;

  // Preload remapping table
  int t = threadIdx.x;
  __shared__ uint16_t perm[BLOCK_KN_SIZE];
  if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];

  // Column
  int n = offset_n + t;
  if (n >= size_n) return;

  // Find initial group
  // int group = offset_k / groupsize;
  int group = b_q_group_map[offset_k * 2];

  // Count how many rows of each bit width precede offset_k, then convert that
  // into qk, the number of packed uint32 rows before this tile
  // (32 rows of b-bit data occupy b uint32 rows).
  int pre_rows_8 = min(rows_8, offset_k);
  int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
  int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
  int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
  int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
  int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
  int qk = 0;
  qk += pre_rows_8 / 32 * 8;
  qk += pre_rows_6 / 32 * 6;
  qk += pre_rows_5 / 32 * 5;
  qk += pre_rows_4 / 32 * 4;
  qk += pre_rows_3 / 32 * 3;
  qk += pre_rows_2 / 32 * 2;

  const uint32_t* b_ptr = b_q_weight + qk * size_n + n;

  half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
  half2 qs_h2 = __halves2half2(qs_h, qs_h);
  int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];

  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
  int k = offset_k;
  int lk = 0;

  __syncthreads();
  while (k < rows_8 && k < end_k) {
    if (k == nextgroup) {
      group++;
      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
      nextgroup += b_q_group_map[k * 2 + 1];
      qs_h2 = __halves2half2(qs_h, qs_h);
    }
    for (int p = 0; p < 4; p++) {
      half2 dq[4];
      uint32_t q_0 = *b_ptr;
      b_ptr += size_n;
      uint32_t q_1 = *b_ptr;
      b_ptr += size_n;
      dequant_8bit_8(q_0, q_1, dq, size_n);
      for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
      half* dqh = (half*)dq;
      for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
    }
    k += 32;
  }

  while (k < rows_6 && k < end_k) {
    if (k == nextgroup) {
      group++;
      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
      nextgroup += b_q_group_map[k * 2 + 1];
      qs_h2 = __halves2half2(qs_h, qs_h);
    }
    for (int p = 0; p < 2; p++) {
      half2 dq[8];
      uint32_t q_0 = *b_ptr;
      b_ptr += size_n;
      uint32_t q_1 = *b_ptr;
      b_ptr += size_n;
      uint32_t q_2 = *b_ptr;
      b_ptr += size_n;
      dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
      for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
      half* dqh = (half*)dq;
      for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
    }
    k += 32;
  }

  while (k < rows_5 && k < end_k) {
    if (k == nextgroup) {
      group++;
      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
      nextgroup += b_q_group_map[k * 2 + 1];
      qs_h2 = __halves2half2(qs_h, qs_h);
    }
    for (int p = 0; p < 1; p++) {
      half2 dq[16];
      uint32_t q_0 = *b_ptr;
      b_ptr += size_n;
      uint32_t q_1 = *b_ptr;
      b_ptr += size_n;
      uint32_t q_2 = *b_ptr;
      b_ptr += size_n;
      uint32_t q_3 = *b_ptr;
      b_ptr += size_n;
      uint32_t q_4 = *b_ptr;
      b_ptr += size_n;
      dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
      for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
      half* dqh = (half*)dq;
      for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
    }
    k += 32;
  }

  while (k < rows_4 && k < end_k) {
    if (k == nextgroup) {
      group++;
      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
      nextgroup += b_q_group_map[k * 2 + 1];
      qs_h2 = __halves2half2(qs_h, qs_h);
    }
    for (int p = 0; p < 4; p++) {
      half2 dq[4];
      uint32_t q_0 = *b_ptr;
      b_ptr += size_n;
      dequant_4bit_8(q_0, dq, size_n);
      for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
      half* dqh = (half*)dq;
      for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
    }
    k += 32;
  }

  while (k < rows_3 && k < end_k) {
    if (k == nextgroup) {
      group++;
      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
      nextgroup += b_q_group_map[k * 2 + 1];
      qs_h2 = __halves2half2(qs_h, qs_h);
    }
    for (int p = 0; p < 1; p++) {
      half2 dq[16];
      uint32_t q_0 = *b_ptr;
      b_ptr += size_n;
      uint32_t q_1 = *b_ptr;
      b_ptr += size_n;
      uint32_t q_2 = *b_ptr;
      b_ptr += size_n;
      dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
      for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
      half* dqh = (half*)dq;
      for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
    }
    k += 32;
  }

  while (k < rows_2 && k < end_k) {
    if (k == nextgroup) {
      group++;
      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
      nextgroup += b_q_group_map[k * 2 + 1];
      qs_h2 = __halves2half2(qs_h, qs_h);
    }
    for (int p = 0; p < 1; p++) {
      half2 dq[8];
      uint32_t q_0 = *b_ptr;
      b_ptr += size_n;
      dequant_2bit_16(q_0, dq, size_n);
      for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
      half* dqh = (half*)dq;
      for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
    }
    k += 16;
  }
}
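
// Launch-shape note (commentary only): each thread block reconstructs one
// BLOCK_KN_SIZE x BLOCK_KN_SIZE tile of b, one column per thread. For an
// assumed 4096 x 4096 weight with BLOCK_KN_SIZE = 128, reconstruct() below
// launches a 32 x 32 grid of 128-thread blocks.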
void QMatrix::reconstruct(half* out) {
  dim3 blockDim, gridDim;
  blockDim.x = BLOCK_KN_SIZE;
  blockDim.y = 1;
  gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);

  {
    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>(
        cuda_q_weight, cuda_q_perm, cuda_q_scale, cuda_q_scale_max,
        cuda_q_group_map, height, width,
        // groupsize,
        groups, out, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
  }
}
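
// Illustrative host-side use of reconstruct() (a sketch under assumptions, not
// part of the original source): given a QMatrix qm, materialize the
// dequantized FP16 weight into a torch tensor on the matrix's device.
//
//   torch::Tensor recon = torch::empty(
//       {qm.height, qm.width},
//       torch::dtype(torch::kHalf).device(torch::kCUDA, qm.device));
//   qm.reconstruct((half*)recon.data_ptr());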

}  // namespace exl2
}  // namespace aphrodite