q_gemm_kernel.cuh
/*
 * Adapted from https://github.com/turboderp/exllamav2
 * Copyright (c) 2024 turboderp
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "compat.cuh"

namespace aphrodite {
namespace exl2 {

#define MAX_Q_GEMM_WEIGHTS 4
#define EXL2_BLOCK_KN_SIZE 64
#define EXL2_BLOCK_M_SIZE_MAX 8
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
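
// dot22_{8,16,32} compute the dot product of 8/16/32 dequantized weight values
// (packed as half2 pairs in dq) with the activations at a_ptr, scale the partial
// sum by the group scale qs and add it to the running result g_result. The _f and
// _h variants return the accumulated result in float and half, respectively.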
__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
  half2 result = {};
  const half2* a2_ptr = (const half2*)a_ptr;
  #pragma unroll
  for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}

__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
{
  half2 result = {};
  const half2* a2_ptr = (const half2*)a_ptr;
  #pragma unroll
  for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}

__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
{
  half2 result = {};
  const half2* a2_ptr = (const half2*)a_ptr;
  #pragma unroll
  for (int i = 0; i < 16; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}

__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
{
  half2 result = {};
  const half2* a2_ptr = (const half2*)a_ptr;
  #pragma unroll
  for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
  return fma(result_f, qs_f, g_result);
}

__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
{
  half2 result = {};
  const half2* a2_ptr = (const half2*)a_ptr;
  #pragma unroll
  for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
  return fma(result_f, qs_f, g_result);
}

__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
{
  half2 result = {};
  const half2* a2_ptr = (const half2*)a_ptr;
  #pragma unroll
  for (int i = 0; i < 16; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
  return fma(result_f, qs_f, g_result);
}

__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
{
  // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
  float result = {};
  #pragma unroll
  for (int i = 0; i < 4; i++)
  {
    half2 w01 = dq[i];
    float w0 = __low2float(w01);
    float w1 = __high2float(w01);
    float x0 = __half2float(*a_ptr++);
    float x1 = __half2float(*a_ptr++);
    result = fma(w0, x0, result);
    result = fma(w1, x1, result);
  }
  float qs = __half2float(qs_h);
  result *= qs;
  half result_h = __float2half_rn(result);
  return __hadd(result_h, g_result);
}

__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
{
  half2 result = {};
  const half2* a2_ptr = (const half2*)a_ptr;
  #pragma unroll
  for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  half result_h = __hadd(__low2half(result), __high2half(result));
  return __hfma(result_h, qs_h, g_result);
}

__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
{
  half2 result = {};
  const half2* a2_ptr = (const half2*)a_ptr;
  #pragma unroll
  for (int i = 0; i < 16; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  half result_h = __hadd(__low2half(result), __high2half(result));
  return __hfma(result_h, qs_h, g_result);
}
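
// Common signature shared by all gemm_half_q_half_kernel<m_count> instantiations,
// used by the launcher to select a kernel at runtime.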
typedef void (*fp_gemm_half_q_half_kernel)
(
  const half*,
  const uint32_t*,
  const uint32_t*,
  const half*,
  half*,
  const int,
  const int,
  const int,
  const int,
  const int,
  const uint16_t*,
  const uint16_t*,
  const int,
  const int,
  const int,
  const int,
  const int,
  const int,
  const bool
);
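
// Computes a partial GEMM c += a * dequant(b) for an m_count x (EXL2_BLOCK_KN_SIZE * 4)
// output tile over one EXL2_BLOCK_KN_SIZE slice of the k dimension. Each thread produces
// four adjacent output columns; blockIdx.z indexes the k slice, and partial results from
// different slices are combined with atomicAdd at the end.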
template <int m_count>
__global__ void gemm_half_q_half_kernel
(
  const half* __restrict__ a,
  const uint32_t* __restrict__ b_q_weight,
  const uint32_t* __restrict__ b_q_scale,
  const half* __restrict__ b_q_scale_max,
  half* __restrict__ c,
  const int size_m,
  const int size_n,
  const int size_k,
  const int height,
  const int groups,
  const uint16_t* __restrict__ b_q_group_map,
  const uint16_t* __restrict__ b_q_perm,
  const int rows_8,
  const int rows_6,
  const int rows_5,
  const int rows_4,
  const int rows_3,
  const int rows_2,
  const bool clear
)
{
  MatrixView_half a_(a, size_m, size_k);
  MatrixView_half_rw c_(c, size_m, size_n);
  MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);

  int t = threadIdx.x;

  // Block
  int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;
  int offset_m = blockIdx.y * m_count;
  int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;

  int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);
  int end_m = min(offset_m + m_count, size_m);
  int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, height);
  int n = offset_n + t * 4;

  // Read weights
  half_uint16 weights[MAX_Q_GEMM_WEIGHTS];

  // Preload block_a
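  // Each thread stages one activation element per output row m into shared memory,
  // gathering through b_q_perm so the activations line up with the permuted row
  // order of the quantized weight matrix.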
  __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];

  if (offset_k + t < end_k)
  {
    for (int m = 0; m < m_count; ++m)
    {
      const half* a_ptr = a_.item_ptr(offset_m + m, 0);
      half* block_a_ptr = block_a[m];
      half a0 = a_ptr[b_q_perm[offset_k + t]];
      // half a0 = a_ptr[offset_k + t];
      block_a_ptr[t] = a0;
    }
  }

  // Clear
  if (n >= size_n) return;

  if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
  {
    for (int m = 0; m < m_count; m++)
      *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
  }

  __syncthreads();

  // Find initial group
  //int group = offset_k / groupsize;
  int group = b_q_group_map[offset_k * 2];

  // if (offset_m == 0 && t == 0)
  //     DBGI2(offset_k, group);

  // Preload scales
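  // b_q_group_map[k * 2] holds the group index of row k and b_q_group_map[k * 2 + 1]
  // the number of rows remaining in that group. Each group stores a 4-bit scale value
  // q per column; the effective scale is (q + 1)^2 * b_q_scale_max[group].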
  half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];

  //int groups_in_block = DIVIDE((end_k - offset_k), groupsize);

  int temp_k = offset_k;
  for (int g = 0; temp_k < end_k; g++)
  {
    int qscales[4];
    b_q_scale_.item4(qscales, group + g, n);
    qscales[0]++;
    qscales[1]++;
    qscales[2]++;
    qscales[3]++;
    half maxscale = b_q_scale_max[group + g];
    scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);
    scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);
    scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);
    scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);
    temp_k += b_q_group_map[temp_k * 2 + 1];
  }
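
  // The k dimension is stored in descending bit width: rows [0, rows_8) use 8 bits,
  // [rows_8, rows_6) use 6 bits, and so on down to 2 bits. Count the packed 32-bit
  // words per column that precede offset_k to locate the starting row in b_q_weight.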
  // a, b offset
  int pre_rows_8 = min(rows_8, offset_k);
  int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
  int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
  int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
  int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
  int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
  int qk = 0;
  qk += pre_rows_8 / 32 * 8;
  qk += pre_rows_6 / 32 * 6;
  qk += pre_rows_5 / 32 * 5;
  qk += pre_rows_4 / 32 * 4;
  qk += pre_rows_3 / 32 * 3;
  qk += pre_rows_2 / 32 * 2;

  const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
  const half* a_ptr = &block_a[0][0];
  int a_stride = EXL2_BLOCK_KN_SIZE;

  // Initial group
  int scales_idx = 0;
  half qs_h0 = scales[scales_idx][0];
  half qs_h1 = scales[scales_idx][1];
  half qs_h2 = scales[scales_idx][2];
  half qs_h3 = scales[scales_idx][3];
  int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];

  // Column result
  half block_c[m_count][4] = {};

  // Dequantize groups
  int k = offset_k;
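
  // 8-bit rows: each iteration consumes 32 rows of k in four steps of 8. Every step
  // loads two consecutive packed rows (one int4 each, covering four columns) and
  // dequantizes eight weights per column before accumulating into block_c.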
  while (k < rows_8 && k < end_k)
  {
    if (k == nextgroup)
    {
      group++;
      scales_idx++;
      qs_h0 = scales[scales_idx][0];
      qs_h1 = scales[scales_idx][1];
      qs_h2 = scales[scales_idx][2];
      qs_h3 = scales[scales_idx][3];
      nextgroup += b_q_group_map[k * 2 + 1];
    }

    #pragma unroll
    for (int j = 0; j < 4; j++)
    {
      int4 load_int4[2];
      load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
      load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;

      half2 dq[4][4];
      dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
      dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
      dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
      dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);

      for (int m = 0; m < m_count; m++)
      {
        block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
        block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
        block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
        block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
      }
      a_ptr += 8;
    }
    k += 32;
  }
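
  // 6-bit rows: each iteration covers 32 rows in two steps of 16; every step loads
  // three packed rows and dequantizes 16 weights per column.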
  while (k < rows_6 && k < end_k)
  {
    if (k == nextgroup)
    {
      group++;
      scales_idx++;
      qs_h0 = scales[scales_idx][0];
      qs_h1 = scales[scales_idx][1];
      qs_h2 = scales[scales_idx][2];
      qs_h3 = scales[scales_idx][3];
      nextgroup += b_q_group_map[k * 2 + 1];
    }

    #pragma unroll
    for (int j = 0; j < 2; j++)
    {
      int4 load_int4[3];
      load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
      load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
      load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;

      half2 dq[4][8];
      dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
      dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
      dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
      dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);

      for (int m = 0; m < m_count; m++)
      {
        block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
        block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
        block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
        block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
      }
      a_ptr += 16;
    }
    k += 32;
  }
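
  // 5-bit rows: one step of 32 rows per iteration, loading five packed rows and
  // dequantizing 32 weights per column.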
  while (k < rows_5 && k < end_k)
  {
    if (k == nextgroup)
    {
      group++;
      scales_idx++;
      qs_h0 = scales[scales_idx][0];
      qs_h1 = scales[scales_idx][1];
      qs_h2 = scales[scales_idx][2];
      qs_h3 = scales[scales_idx][3];
      nextgroup += b_q_group_map[k * 2 + 1];
    }

    #pragma unroll
    for (int j = 0; j < 1; j++)
    {
      int4 load_int4[5];
      load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
      load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
      load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
      load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
      load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;

      half2 dq[4][16];
      dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
      dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
      dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
      dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);

      for (int m = 0; m < m_count; m++)
      {
        block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
        block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
        block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
        block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
      }
      a_ptr += 32;
    }
    k += 32;
  }
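
  // 4-bit rows: four steps of 8 rows per iteration, each loading a single packed row
  // and dequantizing eight weights per column.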
  while (k < rows_4 && k < end_k)
  {
    if (k == nextgroup)
    {
      group++;
      scales_idx++;
      qs_h0 = scales[scales_idx][0];
      qs_h1 = scales[scales_idx][1];
      qs_h2 = scales[scales_idx][2];
      qs_h3 = scales[scales_idx][3];
      nextgroup += b_q_group_map[k * 2 + 1];
    }

    #pragma unroll
    for (int j = 0; j < 4; j++)
    {
      int4 load_int4[1];
      load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;

      half2 dq[4][4];
      dequant_4bit_8(load_int4[0].x, dq[0], size_n);
      dequant_4bit_8(load_int4[0].y, dq[1], size_n);
      dequant_4bit_8(load_int4[0].z, dq[2], size_n);
      dequant_4bit_8(load_int4[0].w, dq[3], size_n);

      for (int m = 0; m < m_count; m++)
      {
        block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
        block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
        block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
        block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
      }
      a_ptr += 8;
    }
    k += 32;
  }
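
  // 3-bit rows: one step of 32 rows per iteration, loading three packed rows and
  // dequantizing 32 weights per column.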
  while (k < rows_3 && k < end_k)
  {
    if (k == nextgroup)
    {
      group++;
      scales_idx++;
      qs_h0 = scales[scales_idx][0];
      qs_h1 = scales[scales_idx][1];
      qs_h2 = scales[scales_idx][2];
      qs_h3 = scales[scales_idx][3];
      nextgroup += b_q_group_map[k * 2 + 1];
    }

    #pragma unroll
    for (int j = 0; j < 1; j++)
    {
      int4 load_int4[3];
      load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
      load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
      load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;

      half2 dq[4][16];
      dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
      dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
      dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
      dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);

      for (int m = 0; m < m_count; m++)
      {
        block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
        block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
        block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
        block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
      }
      a_ptr += 32;
    }
    k += 32;
  }
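
  // 2-bit rows: a single packed row yields 16 weights per column, so each iteration
  // advances k by 16 rather than 32.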
  while (k < rows_2 && k < end_k)
  {
    if (k == nextgroup)
    {
      group++;
      scales_idx++;
      qs_h0 = scales[scales_idx][0];
      qs_h1 = scales[scales_idx][1];
      qs_h2 = scales[scales_idx][2];
      qs_h3 = scales[scales_idx][3];
      nextgroup += b_q_group_map[k * 2 + 1];
    }

    #pragma unroll
    for (int j = 0; j < 1; j++)
    {
      int4 load_int4[1];
      load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;

      half2 dq[4][8];
      dequant_2bit_16(load_int4[0].x, dq[0], size_n);
      dequant_2bit_16(load_int4[0].y, dq[1], size_n);
      dequant_2bit_16(load_int4[0].z, dq[2], size_n);
      dequant_2bit_16(load_int4[0].w, dq[3], size_n);

      for (int m = 0; m < m_count; m++)
      {
        block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
        block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
        block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
        block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
      }
      a_ptr += 16;
    }
    k += 16;
  }

  // Accumulate column sums in c
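  // atomicAdd is required because blocks with different blockIdx.z accumulate partial
  // sums into the same output elements.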
  for (int m = 0; m < m_count; m++)
  {
    half2* out = (half2*)c_.item_ptr(offset_m + m, n);
    half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
    half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
    atomicAdd(out    , result01);
    atomicAdd(out + 1, result23);
    // *out = result01;
    // *(out + 1) = result23;
  }
}
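
// Maps a runtime m_count (1..EXL2_BLOCK_M_SIZE_MAX) to the matching kernel template
// instantiation, or NULL if m_count is out of range.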
struct map_m_count_exl2 {
  static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
  {
    #if EXL2_BLOCK_M_SIZE_MAX >= 1
    if (m_count == 1) return gemm_half_q_half_kernel<1>;
    #endif
    #if EXL2_BLOCK_M_SIZE_MAX >= 2
    if (m_count == 2) return gemm_half_q_half_kernel<2>;
    #endif
    #if EXL2_BLOCK_M_SIZE_MAX >= 3
    if (m_count == 3) return gemm_half_q_half_kernel<3>;
    #endif
    #if EXL2_BLOCK_M_SIZE_MAX >= 4
    if (m_count == 4) return gemm_half_q_half_kernel<4>;
    #endif
    #if EXL2_BLOCK_M_SIZE_MAX >= 5
    if (m_count == 5) return gemm_half_q_half_kernel<5>;
    #endif
    #if EXL2_BLOCK_M_SIZE_MAX >= 6
    if (m_count == 6) return gemm_half_q_half_kernel<6>;
    #endif
    #if EXL2_BLOCK_M_SIZE_MAX >= 7
    if (m_count == 7) return gemm_half_q_half_kernel<7>;
    #endif
    #if EXL2_BLOCK_M_SIZE_MAX >= 8
    if (m_count == 8) return gemm_half_q_half_kernel<8>;
    #endif
    return NULL;
  }
};

fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
{
  return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count);
}

}  // namespace exl2
}  // namespace aphrodite