mmvq.cuh 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
  2. template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
  3. static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, const int ncols, const int nrows) {
  4. const int row = blockIdx.x*blockDim.y + threadIdx.y;
  5. if (row >= nrows) {
  6. return;
  7. }
  8. const int blocks_per_row = ncols / qk;
  9. const int blocks_per_warp = vdr * WARP_SIZE / qi;
  10. // partial sum for each thread
  11. float tmp = 0.0f;
  12. const block_q_t * x = (const block_q_t *) vx;
  13. const block_q8_1 * y = (const block_q8_1 *) vy;
  14. for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
  15. const int ibx = row*blocks_per_row + i; // x block index
  16. const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
  17. const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
  18. tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
  19. }
  20. // sum up partial sums and write back result
  21. #pragma unroll
  22. for (int mask = 16; mask > 0; mask >>= 1) {
  23. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  24. }
  25. if (threadIdx.x == 0) {
  26. dst[row] = __float2half(tmp);
  27. }
  28. }
  29. static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  30. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  31. const dim3 block_nums(block_num_y, 1, 1);
  32. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  33. mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
  34. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  35. }
  36. static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  37. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  38. const dim3 block_nums(block_num_y, 1, 1);
  39. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  40. mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
  41. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  42. }
  43. static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  44. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  45. const dim3 block_nums(block_num_y, 1, 1);
  46. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  47. mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
  48. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  49. }
  50. static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  51. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  52. const dim3 block_nums(block_num_y, 1, 1);
  53. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  54. mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
  55. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  56. }
  57. static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  58. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  59. const dim3 block_nums(block_num_y, 1, 1);
  60. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  61. mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
  62. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  63. }
  64. static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  65. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  66. const dim3 block_nums(block_num_y, 1, 1);
  67. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  68. mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
  69. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  70. }
  71. static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  72. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  73. const dim3 block_nums(block_num_y, 1, 1);
  74. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  75. mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
  76. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  77. }
  78. static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  79. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  80. const dim3 block_nums(block_num_y, 1, 1);
  81. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  82. mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
  83. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  84. }
  85. static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  86. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  87. const dim3 block_nums(block_num_y, 1, 1);
  88. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  89. mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
  90. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  91. }
  92. static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  93. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  94. const dim3 block_nums(block_num_y, 1, 1);
  95. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  96. mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
  97. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  98. }
  99. static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  100. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  101. const dim3 block_nums(block_num_y, 1, 1);
  102. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  103. mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
  104. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  105. }
  106. static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  107. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  108. const dim3 block_nums(block_num_y, 1, 1);
  109. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  110. mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
  111. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  112. }
  113. static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  114. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  115. const dim3 block_nums(block_num_y, 1, 1);
  116. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  117. mul_mat_vec_q<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
  118. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  119. }
  120. static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  121. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  122. const dim3 block_nums(block_num_y, 1, 1);
  123. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  124. mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
  125. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  126. }
  127. static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  128. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  129. const dim3 block_nums(block_num_y, 1, 1);
  130. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  131. mul_mat_vec_q<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
  132. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  133. }
  134. static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  135. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  136. const dim3 block_nums(block_num_y, 1, 1);
  137. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  138. mul_mat_vec_q<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
  139. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  140. }
  141. static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  142. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  143. const dim3 block_nums(block_num_y, 1, 1);
  144. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  145. mul_mat_vec_q<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
  146. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  147. }
  148. static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
  149. const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  150. const dim3 block_nums(block_num_y, 1, 1);
  151. const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  152. mul_mat_vec_q<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
  153. <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
  154. }