qdq_4.cuh
/*
 * Adapted from https://github.com/turboderp/exllamav2
 * Copyright (c) 2024 turboderp
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef _qdq_4_cuh
#define _qdq_4_cuh

#include "qdq_util.cuh"

namespace aphrodite {
namespace exl2 {

// Permutation:
//
// 77775555 33331111 66664444 22220000

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        uint32_t qa0 = qa & 0x0f;         // even nibble of the current byte
        uint32_t qa1 = (qa & 0xf0) >> 4;  // odd nibble of the current byte
        qa >>= 8;
        qb |= (qa1 << (i * 4 + 16));      // odd nibbles go to the upper half-word
        qb |= (qa0 << (i * 4));           // even nibbles go to the lower half-word
    }
    q[0] = qb;
}
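
// Worked example (illustrative sketch, not part of the upstream exllamav2 source):
// with nibble i of the input word set to i, the even nibbles end up in the low
// half-word and the odd nibbles in the high half-word, matching the permutation
// diagram above. The kernel name below is hypothetical and only shows the call
// pattern; the stride argument is unused by this 4-bit overload.
static __global__ void example_shuffle_4bit_8_once(uint32_t* q)
{
    q[0] = 0x76543210;      // nibbles 7..0 in storage order
    shuffle_4bit_8(q, 1);   // in-place permutation of one packed word
    // q[0] is now 0x75316420: nibble order 7,5,3,1 | 6,4,2,0
}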

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half z1_ = __float2half_rn(-1024.0f - 8.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
    const half2 z1 = __halves2half2(z1_, z1_);
    const half2 z16 = __halves2half2(z16_, z16_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y16, z16);
    dq[2] = __hadd2(q2.as_half2, z1);
    dq[3] = __hfma2(q3.as_half2, y16, z16);
}
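
// Illustrative sketch (not part of the upstream exllamav2 source) of why the
// 0x6400 trick works: 0x6400 is the FP16 bit pattern for 1024.0, and FP16 has a
// spacing of exactly 1.0 on [1024, 2048), so OR-ing a 4-bit value q into the low
// mantissa bits yields exactly 1024 + q; adding z1 = -(1024 + 8) then gives q - 8.
// The high-nibble lanes land 4 bits higher, i.e. 1024 + 16*q, so an FMA with 1/16
// and -(1024/16 + 8) recovers q - 8 as well. The kernel below is a hypothetical
// usage sketch; it assumes each word was pre-permuted with shuffle_4bit_8.
static __global__ void example_dequant_4bit_8(half2* out, const uint32_t* q_packed)
{
    half2 dq[4];
    dequant_4bit_8(q_packed[threadIdx.x], dq, 1);   // stride is unused here
    #pragma unroll
    for (int i = 0; i < 4; i++) out[threadIdx.x * 4 + i] = dq[i];
}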

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    half2 scale2 = __half2half2(scale);

    z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
    z1z16[1] = __hmul2(scale2, __half2half2(z16));

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __hmul2(scale2, __half2half2(y1));
    y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    z1z16[0] = __half2half2(z1.as_half);
    z1z16[1] = __half2half2(z16);

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __half2half2(y1);
    y1y16[1] = __half2half2(y16);
}
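
// Illustrative note (not part of the upstream exllamav2 source): 0xe400 is the
// FP16 bit pattern for -1024.0, and for magnitudes in [1024, 2048) the spacing is
// exactly 1.0, so OR-ing a 4-bit zero point (0 <= zero < 16) into the low mantissa
// bits yields -(1024 + zero) with a single integer OR. The helper below is a
// hypothetical sketch that checks both precomputed constants against their
// arithmetic definitions; the name example_check_prep_zero is an assumption.
__forceinline__ __device__ bool example_check_prep_zero(uint32_t zero)
{
    half2 z1z16[2];
    half2 y1y16[2];
    dequant_4bit_8_prep_zero(zero, z1z16, y1y16);
    half z1_expected  = __float2half_rn(-1024.0f - (float)zero);
    half z16_expected = __float2half_rn(-1024.0f / 16.0f - (float)zero);
    return __heq(__low2half(z1z16[0]), z1_expected) &&
           __heq(__low2half(z1z16[1]), z16_expected);
}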

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1z16)[2],
    half2 (&y1y16)[2],
    int stride,
    bool scaled
)
{
    const uint32_t c0 = 0x64006400;

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0]      + 1024, q[1]      + 1024 )
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4]      + 1024, q[5]      + 1024 )
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

    if (scaled)
    {
        dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s )
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s )
        dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
    }
    else
    {
        dq[0] = __hadd2(q0.as_half2, z1z16[0]);           // half2( q[0] - z, q[1] - z )
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
        dq[2] = __hadd2(q2.as_half2, z1z16[0]);           // half2( q[4] - z, q[5] - z )
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
    }
}
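
// Hypothetical usage sketch (not part of the upstream exllamav2 source): a GPTQ
// consumer would typically call one of the prep helpers once per quantization
// group, then reuse the precomputed (z1z16, y1y16) pairs for every packed word in
// that group. All names and the launch geometry below are assumptions made for
// illustration only; q_packed is assumed to be pre-permuted with shuffle_4bit_8.
static __global__ void example_dequant_4bit_8_gptq
(
    half2* out,
    const uint32_t* q_packed,   // packed words, 8 weights each
    const uint32_t zero,        // 4-bit group zero point
    const half scale            // group scale
)
{
    half2 z1z16[2];
    half2 y1y16[2];
    dequant_4bit_8_prep_zero_scale(zero, scale, z1z16, y1y16);

    half2 dq[4];
    dequant_4bit_8_gptq(q_packed[threadIdx.x], dq, z1z16, y1y16, 1, true);
    #pragma unroll
    for (int i = 0; i < 4; i++) out[threadIdx.x * 4 + i] = dq[i];
}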

}  // namespace exl2
}  // namespace aphrodite

#endif