// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    const int blocks_per_warp  = WARP_SIZE / qi;

    const int & ncols_dst = ncols_y;

    // each thread block computes one mmq_y x mmq_x tile of dst
    const int row_dst_0 = blockIdx.x*mmq_y;
    const int & row_x_0 = row_dst_0;

    const int col_dst_0 = blockIdx.y*mmq_x;
    const int & col_y_0 = col_dst_0;

    int   * tile_x_ql = nullptr;
    half2 * tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);

    __shared__ int   tile_y_qs[mmq_x * WARP_SIZE];
    __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];

    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};

    // march along the shared inner dimension, one batch of quant blocks per iteration
    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

        load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);

#pragma unroll
        for (int ir = 0; ir < qr; ++ir) {
            const int kqs = ir*WARP_SIZE + threadIdx.x;
            const int kbxd = kqs / QI8_1;

            // stage the q8_1 quants of y into shared memory
#pragma unroll
            for (int i = 0; i < mmq_x; i += nwarps) {
                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses

                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];

                const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
            }

            // stage the q8_1 scale/sum pairs of y into shared memory
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
                const int col_y_eff = min(col_y_0 + ids, ncols_y-1);

                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = __low2float(*dsi_src);
                }
            }

            __syncthreads();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
#pragma unroll
                for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
                        sum[i/WARP_SIZE][j/nwarps] += vec_dot(
                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
                            threadIdx.x + i, threadIdx.y + j, k);
                    }
                }
            }

            __syncthreads();
        }
    }

    // write back the accumulated tile, skipping rows/columns past the matrix edge
#pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
        const int col_dst = col_dst_0 + j + threadIdx.y;

        if (col_dst >= ncols_dst) {
            return;
        }

#pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
            const int row_dst = row_dst_0 + threadIdx.x + i;

            if (row_dst >= nrows_dst) {
                continue;
            }

            dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]);
        }
    }
}
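
// --- Illustrative sketch, not part of the upstream file. ---
// Each thread block of mul_mat_q computes one mmq_y x mmq_x tile of dst, so a
// host launcher needs a grid that covers the dst extents by ceiling division.
// This hypothetical helper restates the arithmetic the launchers below perform
// inline; e.g. nrows_x = 4096, ncols_y = 7 with the non-ROCm Q4_0 tile
// (mmq_y = 32, mmq_x = 4) gives a 128 x 2 grid, and the ragged column edge
// (7 % 4 != 0) is absorbed by the min() clamp and the col_dst bound check
// inside mul_mat_q.
[[maybe_unused]] static dim3 mmq_tile_grid(int nrows_x, int ncols_y, int mmq_y, int mmq_x) {
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; // tiles along dst rows
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; // tiles along dst columns
    return dim3(block_num_x, block_num_y, 1);
}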
#if defined(USE_ROCM)
#define MMQ_X_Q4_0 64
#define MMQ_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
#define MMQ_X_Q4_0 4
#define MMQ_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif
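
// Illustrative arithmetic only (assuming WARP_SIZE == 32 and QI8_1 == 8, as in
// upstream ggml): the Q4_0 kernel below runs WARP_SIZE * NWARPS_Q4_0 threads
// per block, i.e. 32 * 8 = 256 on ROCm and 32 * 4 = 128 otherwise, and its
// shared y tiles occupy MMQ_X_Q4_0 * WARP_SIZE * sizeof(int) bytes of quants
// plus MMQ_X_Q4_0 * (WARP_SIZE / QI8_1) * sizeof(half2) bytes of scales
// (8192 + 1024 bytes on the ROCm config), on top of the x tiles declared by
// allocate_tiles_q4_0.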
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q4_0, 2)
#endif
mul_mat_q4_0(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const int mmq_x  = MMQ_X_Q4_0;
    const int mmq_y  = MMQ_Y_Q4_0;
    const int nwarps = NWARPS_Q4_0;

    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

static void ggml_mul_mat_q4_0_q8_1_cuda(
    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x  = MMQ_X_Q4_0;
    const int mmq_y  = MMQ_Y_Q4_0;
    const int nwarps = NWARPS_Q4_0;

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        const bool need_check = true;
        mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
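
// --- Illustrative usage sketch, not part of the upstream file. ---
// Buffer names and extents here are hypothetical; the call shape mirrors how a
// caller drives the launcher above: vx holds nrows_x rows of Q4_0 blocks
// (ncols_x / QK4_0 == 4096 / 32 == 128 blocks per row in this example), vy
// holds ncols_y activation columns quantized to Q8_1, and dst receives the
// half-precision product with nrows_dst == nrows_x.
[[maybe_unused]] static void example_mul_mat_q4_0(
        const void * d_x, const void * d_y, half * d_dst, cudaStream_t stream) {
    const int ncols_x = 4096;    // row length of x, must be divisible by QK4_0
    const int nrows_x = 4096;    // rows of x == rows of dst
    const int ncols_y = 8;       // batch size (columns of y and dst)
    const int nrows_y = ncols_x; // inner dimensions must match
    ggml_mul_mat_q4_0_q8_1_cuda(d_x, d_y, d_dst, ncols_x, nrows_x, ncols_y,
                                nrows_y, /*nrows_dst=*/nrows_x, stream);
}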
#if defined(USE_ROCM)
#define MMQ_X_Q4_1 64
#define MMQ_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define MMQ_X_Q4_1 4
#define MMQ_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q4_1, 2)
#endif
mul_mat_q4_1(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const int mmq_x  = MMQ_X_Q4_1;
    const int mmq_y  = MMQ_Y_Q4_1;
    const int nwarps = NWARPS_Q4_1;

    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

static void ggml_mul_mat_q4_1_q8_1_cuda(
    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x  = MMQ_X_Q4_1;
    const int mmq_y  = MMQ_Y_Q4_1;
    const int nwarps = NWARPS_Q4_1;

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        const bool need_check = true;
        mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_0 64
#define MMQ_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define MMQ_X_Q5_0 4
#define MMQ_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q5_0, 2)
#endif
mul_mat_q5_0(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const int mmq_x  = MMQ_X_Q5_0;
    const int mmq_y  = MMQ_Y_Q5_0;
    const int nwarps = NWARPS_Q5_0;

    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

static void ggml_mul_mat_q5_0_q8_1_cuda(
    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x  = MMQ_X_Q5_0;
    const int mmq_y  = MMQ_Y_Q5_0;
    const int nwarps = NWARPS_Q5_0;

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        const bool need_check = true;
        mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_1 64
#define MMQ_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define MMQ_X_Q5_1 4
#define MMQ_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q5_1, 2)
#endif
mul_mat_q5_1(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const int mmq_x  = MMQ_X_Q5_1;
    const int mmq_y  = MMQ_Y_Q5_1;
    const int nwarps = NWARPS_Q5_1;

    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

static void ggml_mul_mat_q5_1_q8_1_cuda(
    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x  = MMQ_X_Q5_1;
    const int mmq_y  = MMQ_Y_Q5_1;
    const int nwarps = NWARPS_Q5_1;

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        const bool need_check = true;
        mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
#if defined(USE_ROCM)
#define MMQ_X_Q8_0 64
#define MMQ_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define MMQ_X_Q8_0 4
#define MMQ_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q8_0, 2)
#endif
mul_mat_q8_0(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const int mmq_x  = MMQ_X_Q8_0;
    const int mmq_y  = MMQ_Y_Q8_0;
    const int nwarps = NWARPS_Q8_0;

    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

static void ggml_mul_mat_q8_0_q8_1_cuda(
    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x  = MMQ_X_Q8_0;
    const int mmq_y  = MMQ_Y_Q8_0;
    const int nwarps = NWARPS_Q8_0;

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        const bool need_check = true;
        mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
#if defined(USE_ROCM)
#define MMQ_X_Q2_K 64
#define MMQ_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define MMQ_X_Q2_K 4
#define MMQ_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q2_K, 2)
#endif
mul_mat_q2_K(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const int mmq_x  = MMQ_X_Q2_K;
    const int mmq_y  = MMQ_Y_Q2_K;
    const int nwarps = NWARPS_Q2_K;

    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

static void ggml_mul_mat_q2_K_q8_1_cuda(
    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x  = MMQ_X_Q2_K;
    const int mmq_y  = MMQ_Y_Q2_K;
    const int nwarps = NWARPS_Q2_K;

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        const bool need_check = true;
        mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
#if defined(USE_ROCM)
#define MMQ_X_Q3_K 64
#define MMQ_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define MMQ_X_Q3_K 4
#define MMQ_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q3_K, 2)
#endif
mul_mat_q3_K(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const int mmq_x  = MMQ_X_Q3_K;
    const int mmq_y  = MMQ_Y_Q3_K;
    const int nwarps = NWARPS_Q3_K;

    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

static void ggml_mul_mat_q3_K_q8_1_cuda(
    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x  = MMQ_X_Q3_K;
    const int mmq_y  = MMQ_Y_Q3_K;
    const int nwarps = NWARPS_Q3_K;

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        const bool need_check = true;
        mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_K 64
#define MMQ_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define MMQ_X_Q4_K 4
#define MMQ_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q4_K, 2)
#endif
mul_mat_q4_K(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const int mmq_x  = MMQ_X_Q4_K;
    const int mmq_y  = MMQ_Y_Q4_K;
    const int nwarps = NWARPS_Q4_K;

    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

static void ggml_mul_mat_q4_K_q8_1_cuda(
    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x  = MMQ_X_Q4_K;
    const int mmq_y  = MMQ_Y_Q4_K;
    const int nwarps = NWARPS_Q4_K;

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        const bool need_check = true;
        mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_K 64
#define MMQ_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define MMQ_X_Q5_K 4
#define MMQ_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q5_K, 2)
#endif
mul_mat_q5_K(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const int mmq_x  = MMQ_X_Q5_K;
    const int mmq_y  = MMQ_Y_Q5_K;
    const int nwarps = NWARPS_Q5_K;

    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

static void ggml_mul_mat_q5_K_q8_1_cuda(
    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x  = MMQ_X_Q5_K;
    const int mmq_y  = MMQ_Y_Q5_K;
    const int nwarps = NWARPS_Q5_K;

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        const bool need_check = true;
        mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
#if defined(USE_ROCM)
#define MMQ_X_Q6_K 64
#define MMQ_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define MMQ_X_Q6_K 4
#define MMQ_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q6_K, 2)
#endif
mul_mat_q6_K(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const int mmq_x  = MMQ_X_Q6_K;
    const int mmq_y  = MMQ_Y_Q6_K;
    const int nwarps = NWARPS_Q6_K;

    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

static void ggml_mul_mat_q6_K_q8_1_cuda(
    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x  = MMQ_X_Q6_K;
    const int mmq_y  = MMQ_Y_Q6_K;
    const int nwarps = NWARPS_Q6_K;

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        const bool need_check = true;
        mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
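
// --- Illustrative dispatch sketch, not part of the upstream file. ---
// A caller typically selects one launcher per quantization type. The enum here
// is a hypothetical stand-in (ggml uses its own ggml_type ids); it only shows
// the shape of the dispatch, with every launcher sharing one signature.
enum class mmq_quant_type { q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K, q5_K, q6_K };

[[maybe_unused]] static void example_mul_mat_q_dispatch(
        mmq_quant_type type, const void * vx, const void * vy, half * dst,
        const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
        const int nrows_dst, cudaStream_t stream) {
    switch (type) {
        case mmq_quant_type::q4_0: ggml_mul_mat_q4_0_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, stream); break;
        case mmq_quant_type::q4_1: ggml_mul_mat_q4_1_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, stream); break;
        case mmq_quant_type::q5_0: ggml_mul_mat_q5_0_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, stream); break;
        case mmq_quant_type::q5_1: ggml_mul_mat_q5_1_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, stream); break;
        case mmq_quant_type::q8_0: ggml_mul_mat_q8_0_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, stream); break;
        case mmq_quant_type::q2_K: ggml_mul_mat_q2_K_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, stream); break;
        case mmq_quant_type::q3_K: ggml_mul_mat_q3_K_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, stream); break;
        case mmq_quant_type::q4_K: ggml_mul_mat_q4_K_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, stream); break;
        case mmq_quant_type::q5_K: ggml_mul_mat_q5_K_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, stream); break;
        case mmq_quant_type::q6_K: ggml_mul_mat_q6_K_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, stream); break;
    }
}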