qdq_3.cuh 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. /*
  2. * Adapted from https://github.com/turboderp/exllamav2
  3. * Copyright (c) 2024 turboderp
  4. *
  5. * Permission is hereby granted, free of charge, to any person obtaining a copy
  6. * of this software and associated documentation files (the "Software"), to deal
  7. * in the Software without restriction, including without limitation the rights
  8. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. * copies of the Software, and to permit persons to whom the Software is
  10. * furnished to do so, subject to the following conditions:
  11. *
  12. * The above copyright notice and this permission notice shall be included in all
  13. * copies or substantial portions of the Software.
  14. *
  15. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  18. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  20. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  21. * SOFTWARE.
  22. */
  23. #ifndef _qdq_3_cuh
  24. #define _qdq_3_cuh
  25. #include "qdq_util.cuh"
  26. namespace aphrodite {
  27. namespace exl2 {
  28. // Permutation:
  29. //
  30. // v9997775 55333111 u8886664 44222000 (u, v lsb)
  31. // vjjjhhhf ffdddbbb uiiiggge eecccaaa
  32. // vtttrrrp ppnnnlll usssqqqo oommmkkk
  // Rearranges, in place, 32 packed 3-bit values held in the three 32-bit words
  // q[0], q[stride] and q[2 * stride].
  //
  // On input the values are packed back to back starting at the least
  // significant bit of the first word, so two of them straddle a word boundary.
  // On output each word carries ten whole values: the even-numbered ones of its
  // group in bits 0..14 and the odd-numbered ones in bits 16..30. The two
  // remaining values (u and v in the diagrams below) are spread one bit per
  // word: bit k of u goes to bit 15 of word k, bit k of v to bit 31 of word k.
  //
  // q:      pointer to the first of the three words.
  // stride: distance, in uint32_t elements, between consecutive words.
  __forceinline__ __device__ void shuffle_3bit_32
  (
      uint32_t* q,
      int stride
  )
  {
      uint32_t w0 = q[0 * stride];
      uint32_t w1 = q[1 * stride];
      uint32_t w2 = q[2 * stride];

      // w0: aa999888 77766655 54443332 22111000
      // w1: lkkkjjji iihhhggg fffeeedd dcccbbba
      // w2: vvvuuutt tsssrrrq qqpppooo nnnmmmll

      // Set the last two values aside, then shift the upper words left so that
      // every word starts on a value boundary. w2 must be rebuilt before w1 is
      // modified because it borrows the top bits of the original w1.
      const uint32_t tail = w2 >> 26;
      w2 = (w2 << 4) | (w1 >> 28);
      w1 = (w1 << 2) | (w0 >> 30);

      // w0:   ..999888 77766655 54443332 22111000
      // w1:   ..jjjiii hhhgggff feeedddc ccbbbaaa
      // w2:   ..tttsss rrrqqqpp pooonnnm mmlllkkk
      // tail: vvvuuu

      // Splits the low 30 bits of a word (ten 3-bit values) into even-indexed
      // values in bits 0..14 and odd-indexed values in bits 16..30. The top two
      // bits of the argument are ignored.
      auto split_even_odd = [](uint32_t packed) -> uint32_t
      {
          uint32_t result = 0;
          for (int pair = 0; pair < 5; ++pair)
          {
              const uint32_t even = packed & 0x07;
              const uint32_t odd  = (packed >> 3) & 0x07;
              result |= even << (3 * pair);
              result |= odd  << (3 * pair + 16);
              packed >>= 6;
          }
          return result;
      };

      uint32_t z0 = split_even_odd(w0);
      uint32_t z1 = split_even_odd(w1);
      uint32_t z2 = split_even_odd(w2);

      // z0:   9997775 55333111  8886664 44222000
      // z1:   jjjhhhf ffdddbbb  iiiggge eecccaaa
      // z2:   tttrrrp ppnnnlll  sssqqqo oommmkkk
      // tail: vvvuuu

      // Bits 0..2 of tail (u) fill bit 15 of each word, bits 3..5 (v) bit 31.
      z0 |= ((tail >> 0) & 1u) << 15;
      z1 |= ((tail >> 1) & 1u) << 15;
      z2 |= ((tail >> 2) & 1u) << 15;
      z0 |= ((tail >> 3) & 1u) << 31;
      z1 |= ((tail >> 4) & 1u) << 31;
      z2 |= ((tail >> 5) & 1u) << 31;

      // z0: v9997775 55333111 u8886664 44222000  (u, v lsb)
      // z1: vjjjhhhf ffdddbbb uiiiggge eecccaaa
      // z2: vtttrrrp ppnnnlll usssqqqo oommmkkk

      q[0 * stride] = z0;
      q[1 * stride] = z1;
      q[2 * stride] = z2;
  }
  // Expands 32 3-bit values, supplied as three words in the layout written by
  // shuffle_3bit_32, into 16 half2 pairs.
  //
  // q_0, q_1, q_2: the three shuffled words.
  // dq:            output; dq[n] = half2(value[2n] - 4, value[2n + 1] - 4),
  //                i.e. each 3-bit value 0..7 is mapped to -4..3.
  // stride:        not referenced by this function.
  //
  // The conversion avoids integer-to-float instructions: a masked 3-bit field
  // is OR-ed into the mantissa of fp16 1024.0 (0x6400), which yields
  // 1024 + field exactly. Fields taken from bit 3 or bit 6 arrive pre-scaled by
  // 8 or 64, so a fused multiply-add both rescales them and removes the bias.
  __forceinline__ __device__ void dequant_3bit_32
  (
      const uint32_t q_0,
      const uint32_t q_1,
      const uint32_t q_2,
      half2 (&dq)[16],
      int stride
  )
  {
      // fp16 1024.0 replicated into both 16-bit halves.
      const uint32_t bias_bits = 0x64006400;

      // One 3-bit field per 16-bit half, at bit offsets 0, 3 and 6.
      const uint32_t field_x1  = 0x00070007;  // reads as value
      const uint32_t field_x8  = 0x00380038;  // reads as value * 8
      const uint32_t field_x64 = 0x01c001c0;  // reads as value * 64

      // Builds a half2 with the same rounded fp16 constant in both halves.
      auto splat = [](float v) -> half2
      {
          const half h = __float2half_rn(v);
          return __halves2half2(h, h);
      };

      const half2 inv8  = splat(1.0f / 8.0f);
      const half2 inv64 = splat(1.0f / 64.0f);

      // Offsets that cancel the (possibly rescaled) 1024 bias and subtract 4.
      const half2 off1  = splat(-1024.0f - 4.0f);
      const half2 off8  = splat(-1024.0f / 8.0f - 4.0f);
      const half2 off64 = splat(-1024.0f / 64.0f - 4.0f);

      // Word 0 -> values 0..9
      {
          const uint32_t lo = q_0;
          const uint32_t hi = q_0 >> 6;
          half2_uint32 f0((lo & field_x1)  | bias_bits);  // half2(q[ 0], q[ 1])      + 1024
          half2_uint32 f1((lo & field_x8)  | bias_bits);  // half2(q[ 2], q[ 3]) *  8 + 1024
          half2_uint32 f2((hi & field_x1)  | bias_bits);  // half2(q[ 4], q[ 5])      + 1024
          half2_uint32 f3((hi & field_x8)  | bias_bits);  // half2(q[ 6], q[ 7]) *  8 + 1024
          half2_uint32 f4((hi & field_x64) | bias_bits);  // half2(q[ 8], q[ 9]) * 64 + 1024
          dq[ 0] = __hadd2(f0.as_half2, off1);
          dq[ 1] = __hfma2(f1.as_half2, inv8,  off8);
          dq[ 2] = __hadd2(f2.as_half2, off1);
          dq[ 3] = __hfma2(f3.as_half2, inv8,  off8);
          dq[ 4] = __hfma2(f4.as_half2, inv64, off64);
      }

      // Word 1 -> values 10..19
      {
          const uint32_t lo = q_1;
          const uint32_t hi = q_1 >> 6;
          half2_uint32 f0((lo & field_x1)  | bias_bits);  // half2(q[10], q[11])      + 1024
          half2_uint32 f1((lo & field_x8)  | bias_bits);  // half2(q[12], q[13]) *  8 + 1024
          half2_uint32 f2((hi & field_x1)  | bias_bits);  // half2(q[14], q[15])      + 1024
          half2_uint32 f3((hi & field_x8)  | bias_bits);  // half2(q[16], q[17]) *  8 + 1024
          half2_uint32 f4((hi & field_x64) | bias_bits);  // half2(q[18], q[19]) * 64 + 1024
          dq[ 5] = __hadd2(f0.as_half2, off1);
          dq[ 6] = __hfma2(f1.as_half2, inv8,  off8);
          dq[ 7] = __hadd2(f2.as_half2, off1);
          dq[ 8] = __hfma2(f3.as_half2, inv8,  off8);
          dq[ 9] = __hfma2(f4.as_half2, inv64, off64);
      }

      // Word 2 -> values 20..29
      {
          const uint32_t lo = q_2;
          const uint32_t hi = q_2 >> 6;
          half2_uint32 f0((lo & field_x1)  | bias_bits);  // half2(q[20], q[21])      + 1024
          half2_uint32 f1((lo & field_x8)  | bias_bits);  // half2(q[22], q[23]) *  8 + 1024
          half2_uint32 f2((hi & field_x1)  | bias_bits);  // half2(q[24], q[25])      + 1024
          half2_uint32 f3((hi & field_x8)  | bias_bits);  // half2(q[26], q[27]) *  8 + 1024
          half2_uint32 f4((hi & field_x64) | bias_bits);  // half2(q[28], q[29]) * 64 + 1024
          dq[10] = __hadd2(f0.as_half2, off1);
          dq[11] = __hfma2(f1.as_half2, inv8,  off8);
          dq[12] = __hadd2(f2.as_half2, off1);
          dq[13] = __hfma2(f3.as_half2, inv8,  off8);
          dq[14] = __hfma2(f4.as_half2, inv64, off64);
      }

      // Values 30 and 31 are spread over bits 15 and 31 of the three words.
      // Shifting word k right by (15 - k) drops its bit 15 onto bit k and its
      // bit 31 onto bit 16 + k, so OR-ing the masked results reassembles both
      // 3-bit values side by side.
      const uint32_t tail = ((q_0 >> 15) & 0x00010001)
                          | ((q_1 >> 14) & 0x00020002)
                          | ((q_2 >> 13) & 0x00040004);
      half2_uint32 f_last(tail | bias_bits);              // half2(q[30], q[31])      + 1024
      dq[15] = __hadd2(f_last.as_half2, off1);
  }
  142. } // namespace exl2
  143. } // namespace aphrodite
  144. #endif