dequantize.cuh 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/convert.cu
  2. // Dequant functions
  3. static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
  4. const block_q4_0 * x = (const block_q4_0 *) vx;
  5. const dfloat d = x[ib].d;
  6. const int vui = x[ib].qs[iqs];
  7. v.x = __int2half_rn(vui & 0xF);
  8. v.y = __int2half_rn(vui >> 4);
  9. v = __hsub2(v, __floats2half2_rn(8.0f, 8.0f));
  10. v = __hmul2(v, {d, d});
  11. }
  12. static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
  13. const block_q4_1 * x = (const block_q4_1 *) vx;
  14. const dfloat d = __low2half(x[ib].dm);
  15. const dfloat m = __high2half(x[ib].dm);
  16. const int vui = x[ib].qs[iqs];
  17. v.x = __int2half_rn(vui & 0xF);
  18. v.y = __int2half_rn(vui >> 4);
  19. v = __hmul2(v, {d, d});
  20. v = __hadd2(v, {m, m});
  21. }
  22. static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
  23. const block_q5_0 * x = (const block_q5_0 *) vx;
  24. const dfloat d = x[ib].d;
  25. uint32_t qh;
  26. memcpy(&qh, x[ib].qh, sizeof(qh));
  27. const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  28. const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  29. v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
  30. v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1);
  31. v = __hsub2(v, __floats2half2_rn(16.0f, 16.0f));
  32. v = __hmul2(v, {d, d});
  33. }
  34. static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
  35. const block_q5_1 * x = (const block_q5_1 *) vx;
  36. const dfloat d = __low2half(x[ib].dm);
  37. const dfloat m = __high2half(x[ib].dm);
  38. uint32_t qh;
  39. memcpy(&qh, x[ib].qh, sizeof(qh));
  40. const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  41. const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  42. v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
  43. v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1);
  44. v = __hmul2(v, {d, d});
  45. v = __hadd2(v, {m, m});
  46. }
  47. static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
  48. const block_q8_0 * x = (const block_q8_0 *) vx;
  49. const dfloat d = x[ib].d;
  50. v.x = __int2half_rn(x[ib].qs[iqs + 0]);
  51. v.y = __int2half_rn(x[ib].qs[iqs + 1]);
  52. v = __hmul2(v, {d, d});
  53. }
  54. template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
  55. static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
  56. const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
  57. if (i >= k) {
  58. return;
  59. }
  60. const int ib = i/qk; // block index
  61. const int iqs = (i%qk)/qr; // quant index
  62. const int iybs = i - i%qk; // y block start index
  63. const int y_offset = qr == 1 ? 1 : qk/2;
  64. // dequantize
  65. dfloat2 v;
  66. dequantize_kernel(vx, ib, iqs, v);
  67. y[iybs + iqs + 0] = v.x;
  68. y[iybs + iqs + y_offset] = v.y;
  69. }
  70. template<typename dst_t>
  71. static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  72. const int i = blockIdx.x;
  73. const block_q2_K * x = (const block_q2_K *) vx;
  74. const int tid = threadIdx.x;
  75. const int n = tid/32;
  76. const int l = tid - 32*n;
  77. const int is = 8*n + l/16;
  78. const uint8_t q = x[i].qs[32*n + l];
  79. dst_t * y = yy + i*QK_K + 128*n;
  80. half dall = __low2half(x[i].dm);
  81. half dmin = __high2half(x[i].dm);
  82. y[l+ 0] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4)));
  83. y[l+32] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4)));
  84. y[l+64] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4)));
  85. y[l+96] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4)));
  86. }
  87. template<typename dst_t>
  88. static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  89. const int i = blockIdx.x;
  90. const block_q3_K * x = (const block_q3_K *) vx;
  91. const int r = threadIdx.x/4;
  92. const int tid = r/2;
  93. const int is0 = r%2;
  94. const int l0 = 16*is0 + 4*(threadIdx.x%4);
  95. const int n = tid / 4;
  96. const int j = tid - 4*n;
  97. uint8_t m = 1 << (4*n + j);
  98. int is = 8*n + 2*j + is0;
  99. int shift = 2*j;
  100. int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
  101. is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
  102. is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
  103. (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
  104. half d_all = x[i].d;
  105. half dl = __hmul(d_all, __int2half_rn(us - 32));
  106. dst_t * y = yy + i*QK_K + 128*n + 32*j;
  107. const uint8_t * q = x[i].qs + 32*n;
  108. const uint8_t * hm = x[i].hmask;
  109. for (int l = l0; l < l0+4; ++l) y[l] = __hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)));
  110. }
  111. static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
  112. if (j < 4) {
  113. d = q[j] & 63; m = q[j + 4] & 63;
  114. } else {
  115. d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
  116. m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
  117. }
  118. }
  119. template<typename dst_t>
  120. static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  121. const block_q4_K * x = (const block_q4_K *) vx;
  122. const int i = blockIdx.x;
  123. // assume 32 threads
  124. const int tid = threadIdx.x;
  125. const int il = tid/8;
  126. const int ir = tid%8;
  127. const int is = 2*il;
  128. const int n = 4;
  129. dst_t * y = yy + i*QK_K + 64*il + n*ir;
  130. const half dall = __low2half(x[i].dm);
  131. const half dmin = __high2half(x[i].dm);
  132. const uint8_t * q = x[i].qs + 32*il + n*ir;
  133. uint8_t sc, m;
  134. get_scale_min_k4(is + 0, x[i].scales, sc, m);
  135. const half d1 = __hmul(dall, __int2half_rn(sc));
  136. const half m1 = __hmul(dmin, __int2half_rn(m));
  137. get_scale_min_k4(is + 1, x[i].scales, sc, m);
  138. const half d2 = __hmul(dall, __int2half_rn(sc));
  139. const half m2 = __hmul(dmin, __int2half_rn(m));
  140. for (int l = 0; l < n; ++l) {
  141. y[l + 0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1);
  142. y[l +32] = __hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2);
  143. }
  144. }
  145. template<typename dst_t>
  146. static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  147. const block_q5_K * x = (const block_q5_K *) vx;
  148. const int i = blockIdx.x;
  149. // assume 64 threads - this is very slightly better than the one below
  150. const int tid = threadIdx.x;
  151. const int il = tid/16; // il is in 0...3
  152. const int ir = tid%16; // ir is in 0...15
  153. const int is = 2*il; // is is in 0...6
  154. dst_t * y = yy + i*QK_K + 64*il + 2*ir;
  155. const half dall = __low2half(x[i].dm);
  156. const half dmin = __high2half(x[i].dm);
  157. const uint8_t * ql = x[i].qs + 32*il + 2*ir;
  158. const uint8_t * qh = x[i].qh + 2*ir;
  159. uint8_t sc, m;
  160. get_scale_min_k4(is + 0, x[i].scales, sc, m);
  161. const half d1 = __hmul(dall, __int2half_rn(sc)); const half m1 = __hmul(dmin, __int2half_rn(m));
  162. get_scale_min_k4(is + 1, x[i].scales, sc, m);
  163. const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m));
  164. uint8_t hm = 1 << (2*il);
  165. y[ 0] = __hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1);
  166. y[ 1] = __hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1);
  167. hm <<= 1;
  168. y[32] = __hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2);
  169. y[33] = __hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2);
  170. }
  171. template<typename dst_t>
  172. static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  173. const block_q6_K * x = (const block_q6_K *) vx;
  174. const int i = blockIdx.x;
  175. // assume 64 threads - this is very slightly better than the one below
  176. const int tid = threadIdx.x;
  177. const int ip = tid/32; // ip is 0 or 1
  178. const int il = tid - 32*ip; // 0...32
  179. const int is = 8*ip + il/16;
  180. dst_t * y = yy + i*QK_K + 128*ip + il;
  181. const half d = x[i].d;
  182. const uint8_t * ql = x[i].ql + 64*ip + il;
  183. const uint8_t qh = x[i].qh[32*ip + il];
  184. const int8_t * sc = x[i].scales + is;
  185. y[ 0] = __hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
  186. y[32] = __hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
  187. y[64] = __hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
  188. y[96] = __hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
  189. }
  190. template<typename dst_t>
  191. static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  192. const int i = blockIdx.x;
  193. const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
  194. const int tid = threadIdx.x;
  195. const int il = tid/8; // 0...3
  196. const int ib = tid%8; // 0...7
  197. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  198. const uint16_t * q2 = x[i].qs + 4*ib;
  199. const uint8_t * aux8 = (const uint8_t *)q2;
  200. const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
  201. const uint32_t aux32 = q2[2] | (q2[3] << 16);
  202. const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f;
  203. const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
  204. for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
  205. }
  206. template<typename dst_t>
  207. static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  208. const int i = blockIdx.x;
  209. const block_iq2_xs * x = (const block_iq2_xs *) vx;
  210. const int tid = threadIdx.x;
  211. const int il = tid/8; // 0...3
  212. const int ib = tid%8; // 0...7
  213. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  214. const uint16_t * q2 = x[i].qs + 4*ib;
  215. const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
  216. const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
  217. const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
  218. for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
  219. }
  220. template<typename dst_t>
  221. static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  222. const int i = blockIdx.x;
  223. const block_iq2_s * x = (const block_iq2_s *) vx;
  224. const int tid = threadIdx.x;
  225. const int il = tid/8; // 0...3
  226. const int ib = tid%8; // 0...7
  227. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  228. const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
  229. const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
  230. const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
  231. for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
  232. }
  233. template<typename dst_t>
  234. static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  235. const int i = blockIdx.x;
  236. const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
  237. const int tid = threadIdx.x;
  238. const int il = tid/8; // 0...3
  239. const int ib = tid%8; // 0...7
  240. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  241. const uint8_t * q3 = x[i].qs + 8*ib;
  242. const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
  243. const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
  244. const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
  245. const uint32_t aux32 = gas[0] | (gas[1] << 16);
  246. const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f;
  247. const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
  248. for (int j = 0; j < 4; ++j) {
  249. y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
  250. y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
  251. }
  252. }
  253. template<typename dst_t>
  254. static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  255. const int i = blockIdx.x;
  256. const block_iq3_s * x = (const block_iq3_s *) vx;
  257. const int tid = threadIdx.x;
  258. const int il = tid/8; // 0...3
  259. const int ib = tid%8; // 0...7
  260. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  261. const uint8_t * qs = x[i].qs + 8*ib;
  262. const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
  263. const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
  264. const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
  265. const uint8_t signs = x[i].signs[4*ib + il];
  266. for (int j = 0; j < 4; ++j) {
  267. y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
  268. y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
  269. }
  270. }
  271. template<typename dst_t>
  272. static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  273. const int i = blockIdx.x;
  274. const block_iq1_s * x = (const block_iq1_s *) vx;
  275. const int tid = threadIdx.x;
  276. const int il = tid/8; // 0...3
  277. const int ib = tid%8; // 0...7
  278. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  279. const int i8 = 4*ib+il;
  280. uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
  281. const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
  282. const float d = __half2float(x[i].d) * (2*(h & 7) + 1);
  283. for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j]);
  284. }
  285. template<typename dst_t>
  286. static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  287. const int i = blockIdx.x;
  288. const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
  289. const int tid = threadIdx.x;
  290. const int il = tid/8; // 0...3
  291. const int ib = tid%8; // 0...7
  292. dst_t * y = yy + i*QK_K + 32*ib + 4*il;
  293. const uint8_t * q4 = x[ib].qs + 4*il;
  294. const float d = __half2float(x[ib].d);
  295. for (int j = 0; j < 4; ++j) {
  296. y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
  297. y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]);
  298. }
  299. }
  300. template<typename dst_t>
  301. static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  302. const int i = blockIdx.x;
  303. const block_iq4_xs * x = (const block_iq4_xs *)vx;
  304. const int tid = threadIdx.x;
  305. const int il = tid/8; // 0...3
  306. const int ib = tid%8; // 0...7
  307. dst_t * y = yy + i*QK_K + 32*ib + 4*il;
  308. const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
  309. const float d = __half2float(x[i].d) * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
  310. for (int j = 0; j < 4; ++j) {
  311. y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
  312. y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]);
  313. }
  314. }
  315. template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
  316. static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
  317. const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
  318. dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  319. }
  320. template<typename dst_t>
  321. static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  322. const int nb = k / QK_K;
  323. dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
  324. }
  325. template<typename dst_t>
  326. static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  327. const int nb = k / QK_K;
  328. dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
  329. }
  330. template<typename dst_t>
  331. static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  332. const int nb = k / QK_K;
  333. dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
  334. }
  335. template<typename dst_t>
  336. static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  337. const int nb = k / QK_K;
  338. dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
  339. }
  340. template<typename dst_t>
  341. static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  342. const int nb = k / QK_K;
  343. dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
  344. }
  345. template<typename dst_t>
  346. static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  347. const int nb = k / QK_K;
  348. dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
  349. }
  350. template<typename dst_t>
  351. static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  352. const int nb = k / QK_K;
  353. dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
  354. }
  355. template<typename dst_t>
  356. static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  357. const int nb = k / QK_K;
  358. dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
  359. }
  360. template<typename dst_t>
  361. static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  362. const int nb = k / QK_K;
  363. dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
  364. }
  365. template<typename dst_t>
  366. static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  367. const int nb = k / QK_K;
  368. dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
  369. }
  370. template<typename dst_t>
  371. static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  372. const int nb = k / QK_K;
  373. dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
  374. }
  375. template<typename dst_t>
  376. static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  377. const int nb = (k + QK_K - 1) / QK_K;
  378. dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
  379. }
  380. template<typename dst_t>
  381. static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
  382. const int nb = (k + QK_K - 1) / QK_K;
  383. dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
  384. }
  385. static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
  386. switch (type) {
  387. case 2:
  388. return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
  389. case 3:
  390. return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
  391. case 6:
  392. return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
  393. case 7:
  394. return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
  395. case 8:
  396. return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
  397. case 10:
  398. return dequantize_row_q2_K_cuda;
  399. case 11:
  400. return dequantize_row_q3_K_cuda;
  401. case 12:
  402. return dequantize_row_q4_K_cuda;
  403. case 13:
  404. return dequantize_row_q5_K_cuda;
  405. case 14:
  406. return dequantize_row_q6_K_cuda;
  407. case 16:
  408. return dequantize_row_iq2_xxs_cuda;
  409. case 17:
  410. return dequantize_row_iq2_xs_cuda;
  411. case 18:
  412. return dequantize_row_iq3_xxs_cuda;
  413. case 19:
  414. return dequantize_row_iq1_s_cuda;
  415. case 20:
  416. return dequantize_row_iq4_nl_cuda;
  417. case 21:
  418. return dequantize_row_iq3_s_cuda;
  419. case 22:
  420. return dequantize_row_iq2_s_cuda;
  421. case 23:
  422. return dequantize_row_iq4_xs_cuda;
  423. default:
  424. return nullptr;
  425. }
  426. }