qdq_5.cuh 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. /*
  2. * Adapted from https://github.com/turboderp/exllamav2
  3. * Copyright (c) 2024 turboderp
  4. *
  5. * Permission is hereby granted, free of charge, to any person obtaining a copy
  6. * of this software and associated documentation files (the "Software"), to deal
  7. * in the Software without restriction, including without limitation the rights
  8. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. * copies of the Software, and to permit persons to whom the Software is
  10. * furnished to do so, subject to the following conditions:
  11. *
  12. * The above copyright notice and this permission notice shall be included in all
  13. * copies or substantial portions of the Software.
  14. *
  15. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  18. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  20. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  21. * SOFTWARE.
  22. */
  23. #ifndef _qdq_5_cuh
  24. #define _qdq_5_cuh
  25. #include "qdq_util.cuh"
  26. namespace aphrodite {
  27. namespace exl2 {
  28. // Permutation:
  29. //
  30. // v5555533 33311111 u4444422 22200000 (u, v lsb)
  31. // vbbbbb99 99977777 uaaaaa88 88866666
  32. // vhhhhhff fffddddd ugggggee eeeccccc
  33. // vnnnnnll llljjjjj ummmmmkk kkkiiiii
  34. // vtttttrr rrrppppp usssssqq qqqooooo
  // Re-packs, in place, 32 consecutive 5-bit values (160 bits) held in the five
  // 32-bit words q[0], q[stride], ..., q[4 * stride] into the interleaved layout
  // that dequant_5bit_32 reads.
  //
  // On entry the values are stored back to back, value 0 in the lowest bits of
  // the first word. On exit, word k (k = 0..4) holds values 6k .. 6k+5:
  //
  //   bits  0-4, 5-9, 10-14  : values 6k+0, 6k+2, 6k+4
  //   bits 16-20, 21-25, 26-30 : values 6k+1, 6k+3, 6k+5
  //   bit 15                   : bit k of value 30
  //   bit 31                   : bit k of value 31
  //
  // q      : pointer to the first of the five words
  // stride : distance, in uint32_t elements, between successive words
  __forceinline__ __device__ void shuffle_5bit_32
  (
      uint32_t* q,
      int stride
  )
  {
      // Fetch all five words up front; nothing is written until they are read.
      uint32_t w[5];
      #pragma unroll
      for (int k = 0; k < 5; ++k) w[k] = q[k * stride];

      // The top ten bits of the last word are values 30 and 31. They are set
      // aside here and scattered one bit at a time further down.
      const uint32_t tail = w[4] >> 22;

      // Realign so that every word starts on a value boundary and carries six
      // whole values in its low 30 bits. Word k borrows its 2*k low bits from
      // the top of word k-1; going from k = 4 downwards guarantees the donor
      // word is still unmodified when it is read.
      #pragma unroll
      for (int k = 4; k > 0; --k)
          w[k] = (w[k] << (2 * k)) | (w[k - 1] >> (32 - 2 * k));

      #pragma unroll
      for (int k = 0; k < 5; ++k)
      {
          uint32_t src = w[k];
          uint32_t packed = 0;

          // Consume two values per step: the first goes to the low 16-bit
          // half, the second to the same position in the high half.
          #pragma unroll
          for (int pair = 0; pair < 3; ++pair)
          {
              const uint32_t even_val = src & 0x1f;
              const uint32_t odd_val  = (src >> 5) & 0x1f;
              src >>= 10;
              packed |= even_val << (5 * pair);
              packed |= odd_val  << (5 * pair + 16);
          }

          // Bits 15 and 31 are still free: park bit k of value 30 and bit k
          // of value 31 there.
          packed |= ((tail >> k) & 1u) << 15;
          packed |= ((tail >> (k + 5)) & 1u) << 31;

          q[k * stride] = packed;
      }
  }
  // Expands five 32-bit words laid out by shuffle_5bit_32 into 32 half-precision
  // values, written as 16 half2 pairs into dq. Every 5-bit code v comes out as
  // (v - 16).
  //
  // How it works: OR-ing a 5-bit field into the mantissa of the half constant
  // 0x6400 (1024.0) gives 1024 + v when the field sits in bits 0-4 and
  // 1024 + 32 * v when it sits in bits 5-9. The first case only needs an add of
  // (-1024 - 16); the second is rescaled with one fused multiply-add,
  // x * (1/32) + (-1024/32 - 16).
  //
  // q_0 .. q_4 : the five shuffled words
  // dq         : output; dq[i] receives the pair (value 2i, value 2i+1)
  // stride     : not referenced by this function
  __forceinline__ __device__ void dequant_5bit_32
  (
      const uint32_t q_0,
      const uint32_t q_1,
      const uint32_t q_2,
      const uint32_t q_3,
      const uint32_t q_4,
      half2 (&dq)[16],
      int stride
  )
  {
      const uint32_t bias_bits = 0x64006400;   // 1024.0 in both 16-bit halves
      const uint32_t low_field = 0x001f001f;   // bits 0-4 of each half
      const uint32_t mid_field = 0x03e003e0;   // bits 5-9 of each half

      const half inv32_h   = __float2half_rn(1.0f / 32.0f);
      const half off_low_h = __float2half_rn(-1024.0f - 16.0f);
      const half off_mid_h = __float2half_rn(-1024.0f / 32.0f - 16.0f);

      const half2 inv32   = __halves2half2(inv32_h, inv32_h);
      const half2 off_low = __halves2half2(off_low_h, off_low_h);
      const half2 off_mid = __halves2half2(off_mid_h, off_mid_h);

      // Word 0 -> values 0..5
      half2_uint32 a0((q_0 & low_field) | bias_bits);           // (q[ 0], q[ 1])      + 1024
      half2_uint32 a1((q_0 & mid_field) | bias_bits);           // (q[ 2], q[ 3]) * 32 + 1024
      half2_uint32 a2(((q_0 >> 10) & low_field) | bias_bits);   // (q[ 4], q[ 5])      + 1024
      dq[ 0] = __hadd2(a0.as_half2, off_low);
      dq[ 1] = __hfma2(a1.as_half2, inv32, off_mid);
      dq[ 2] = __hadd2(a2.as_half2, off_low);

      // Word 1 -> values 6..11
      half2_uint32 b0((q_1 & low_field) | bias_bits);           // (q[ 6], q[ 7])      + 1024
      half2_uint32 b1((q_1 & mid_field) | bias_bits);           // (q[ 8], q[ 9]) * 32 + 1024
      half2_uint32 b2(((q_1 >> 10) & low_field) | bias_bits);   // (q[10], q[11])      + 1024
      dq[ 3] = __hadd2(b0.as_half2, off_low);
      dq[ 4] = __hfma2(b1.as_half2, inv32, off_mid);
      dq[ 5] = __hadd2(b2.as_half2, off_low);

      // Word 2 -> values 12..17
      half2_uint32 c0((q_2 & low_field) | bias_bits);           // (q[12], q[13])      + 1024
      half2_uint32 c1((q_2 & mid_field) | bias_bits);           // (q[14], q[15]) * 32 + 1024
      half2_uint32 c2(((q_2 >> 10) & low_field) | bias_bits);   // (q[16], q[17])      + 1024
      dq[ 6] = __hadd2(c0.as_half2, off_low);
      dq[ 7] = __hfma2(c1.as_half2, inv32, off_mid);
      dq[ 8] = __hadd2(c2.as_half2, off_low);

      // Word 3 -> values 18..23
      half2_uint32 d0((q_3 & low_field) | bias_bits);           // (q[18], q[19])      + 1024
      half2_uint32 d1((q_3 & mid_field) | bias_bits);           // (q[20], q[21]) * 32 + 1024
      half2_uint32 d2(((q_3 >> 10) & low_field) | bias_bits);   // (q[22], q[23])      + 1024
      dq[ 9] = __hadd2(d0.as_half2, off_low);
      dq[10] = __hfma2(d1.as_half2, inv32, off_mid);
      dq[11] = __hadd2(d2.as_half2, off_low);

      // Word 4 -> values 24..29
      half2_uint32 e0((q_4 & low_field) | bias_bits);           // (q[24], q[25])      + 1024
      half2_uint32 e1((q_4 & mid_field) | bias_bits);           // (q[26], q[27]) * 32 + 1024
      half2_uint32 e2(((q_4 >> 10) & low_field) | bias_bits);   // (q[28], q[29])      + 1024
      dq[12] = __hadd2(e0.as_half2, off_low);
      dq[13] = __hfma2(e1.as_half2, inv32, off_mid);
      dq[14] = __hadd2(e2.as_half2, off_low);

      // The last pair is reassembled from the spare bits 15 and 31 of every
      // word: word k contributes bit k of each of the two remaining values, so
      // it is shifted down by (15 - k) and masked to that single bit position.
      const uint32_t spare =
          ((q_0 >> 15) & 0x00010001) |
          ((q_1 >> 14) & 0x00020002) |
          ((q_2 >> 13) & 0x00040004) |
          ((q_3 >> 12) & 0x00080008) |
          ((q_4 >> 11) & 0x00100010);
      half2_uint32 f0(spare | bias_bits);                       // (q[30], q[31])      + 1024
      dq[15] = __hadd2(f0.as_half2, off_low);
  }
  174. } // namespace exl2
  175. } // namespace aphrodite
  176. #endif