// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is modified from
// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_parallel_dequant.cuh
// To support MSVC, all instances of u_int32_t are changed to uint32_t.

#ifndef UTILS_PARALLELDEQUANT_CUH
#define UTILS_PARALLELDEQUANT_CUH

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
/*
 * Input:   R1 (4 FPx values, one left-aligned in each byte; modified in place)
 * Outputs: R1, R2 (each holding 2 packed FP16 values)
 * Note:    Simplified Exponent calculation is applied: the FPx exponent bits
 *          are copied into the FP16 exponent field unchanged; the bias
 *          difference is corrected later by MultScale().
 */
template <int EXPONENT, int MANTISSA>
__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t* In, uint32_t* Out1,
                                                   uint32_t* Out2) {
  // Shift that moves the FPx exponent+mantissa bits down into the FP16
  // exponent/mantissa field.
  constexpr int RIGHT_SHIFT = 5 - EXPONENT;
  // Mask covering the EXPONENT+MANTISSA bits below the sign bit, replicated
  // for both 16-bit lanes of the register.
  constexpr int MASK1 = 0x80000000;
  constexpr int MASK2 = MASK1 >> (EXPONENT + MANTISSA);
  constexpr int MASK3 = MASK2 & 0x7fffffff;
  constexpr int MASK = MASK3 | MASK3 >> 16;
  // First pair: sign bits plus shifted exponent/mantissa of bytes 3 and 1.
  *Out1 = *In & 0x80008000;
  *Out1 |= ((*In) & MASK) >> RIGHT_SHIFT;
  // Second pair: shift bytes 2 and 0 into position, then repeat.
  *In = (*In) << 8;
  *Out2 = *In & 0x80008000;
  *Out2 |= ((*In) & MASK) >> RIGHT_SHIFT;
}
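// Worked instantiation (FP6, E=3/M=2), for reference: RIGHT_SHIFT = 2 and
// MASK = 0x7c007c00, i.e. the 5 exponent+mantissa bits sitting below the sign
// bit of each 16-bit lane. The first pass converts the FP6 values left-aligned
// in bytes 3 and 1; `*In <<= 8` then exposes bytes 2 and 0 for the second
// pass.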
template <int EXPONENT, int MANTISSA>
__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair,
                                              half Scale) {
  // Gap between the FP16 exponent bias (15) and the FPx exponent bias
  // (2^(EXPONENT-1) - 1); multiplying by 2^BIAS_OFFSET re-biases the exponent.
  constexpr int BIAS_OFFSET = (int(1) << (5 - 1)) - (int(1) << (EXPONENT - 1));
  constexpr int BIAS = int(1) << BIAS_OFFSET;
  //
  half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
  half* FP16_2 = FP16_1 + 1;
  uint32_t output;
  half* output_half_ptr = reinterpret_cast<half*>(&output);
  output_half_ptr[0] =
      __hmul(__hmul(*FP16_1, __float2half(1.0f * BIAS)), Scale);
  output_half_ptr[1] =
      __hmul(__hmul(*FP16_2, __float2half(1.0f * BIAS)), Scale);
  return output;
}
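// Worked instantiation (FP6, E=3/M=2), for reference: BIAS_OFFSET =
// (1 << 4) - (1 << 2) = 12 and BIAS = 1 << 12 = 4096, which is exactly the
// gap between the FP16 exponent bias (15) and the FP6 exponent bias (3).
// Multiplying by 2^12 restores the exponent that FPx_FP16_Cast_4Way copied
// over without re-biasing, and the quantization scale is folded into the same
// pair of half multiplies.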
// MODIFICATION NOTE: to support MSVC
// - u_int32_t __restrict__ Reg[][4] is changed to below.
// - u_int32_t __restrict__ *read_RPTR_1bit is changed to below. Similarly for
//   read_RPTR_2bit and read_RPTR_4bit.
template <int EXPONENT, int MANTISSA>
__device__ __forceinline__ void Dequant_32FP6_4Way(
    uint32_t (*__restrict__ Reg)[4], uint32_t* __restrict__ read_RPTR_1bit,
    uint32_t* __restrict__ read_RPTR_2bit,
    uint32_t* __restrict__ read_RPTR_4bit, uint32_t* Scales) {
  // 1+2+4 weight split: each BIT_WIDTH-bit value is stored as up to three
  // segments of 1, 2, and 4 bits, enabled by the corresponding bits of
  // BIT_WIDTH itself (e.g. FP6 = 0b110 uses the 2-bit and 4-bit segments).
  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
  //
  uint32_t* OutputRegs = reinterpret_cast<uint32_t*>(Reg);
  uint32_t* Frag_PTR_1bit = read_RPTR_1bit;
  uint32_t* Frag_PTR_2bit = read_RPTR_2bit;
  uint32_t* Frag_PTR_4bit = read_RPTR_4bit;
  half* Scale_RPTR = reinterpret_cast<half*>(Scales);
  // Dequantizing 32 FP6, each loop iteration dequantizing 4 FP6.
  #pragma unroll(8)
  for (int i = 0; i < 8; i++) {
    // Reassemble 4 values, each left-aligned in one byte of Packed_FP6. The
    // shift amounts (BIT_WIDTH & 0/1/3) are each segment's bit offset from
    // the top of its byte, i.e. the bits consumed by the preceding segments.
    uint32_t Packed_FP6 = 0;
    uint32_t tmp = 0;
    // 1bit Frag
    if (USE_SEG_1BIT) {
      tmp = (*Frag_PTR_1bit) & 0x80808080;
      Packed_FP6 |= tmp >> (BIT_WIDTH & 0);
      if (i % 8 == 7)
        Frag_PTR_1bit++;
      else
        (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1;
    }
    // 2bit Frag
    if (USE_SEG_2BIT) {
      tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0;
      Packed_FP6 |= tmp >> (BIT_WIDTH & 1);
      if (i % 4 == 3)
        Frag_PTR_2bit++;
      else
        (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2;
    }
    // 4bit Frag
    if (USE_SEG_4BIT) {
      tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0;
      Packed_FP6 |= tmp >> (BIT_WIDTH & 3);
      if (i % 2 == 1)
        Frag_PTR_4bit++;
      else
        (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4;
    }
    // Cast the 4 packed FPx values to 2 registers of packed FP16.
    uint32_t out1, out2;
    FPx_FP16_Cast_4Way<EXPONENT, MANTISSA>(&Packed_FP6, &out1, &out2);
    //
    *OutputRegs = MultScale<EXPONENT, MANTISSA>(
        out1, Scale_RPTR[0]);  // Multiply FP16 scales
    OutputRegs += 1;
    *OutputRegs = MultScale<EXPONENT, MANTISSA>(
        out2, Scale_RPTR[1]);  // Multiply FP16 scales
    OutputRegs += 1;
    // Updating offset for FP16 scales for every two iterations.
    if (i % 2 == 1) Scale_RPTR += 2;
  }
}
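// Illustrative helper (not part of the original header): dequantizes one
// 32-weight FP6 (E3M2) tile. Parameter names and register counts here are a
// sketch; the actual fragment and scale layouts are produced by the fp6_llm
// weight prepacking and shared-memory pipeline. With BIT_WIDTH = 6 the 1-bit
// segment is unused, so its pointer is never dereferenced and nullptr is
// acceptable.
__device__ __forceinline__ void Dequant_32FP6_E3M2(
    uint32_t (*__restrict__ Reg)[4],   // out: 16 registers = 32 FP16 values
    uint32_t* __restrict__ frag_2bit,  // in:  2 registers of 2-bit segments
    uint32_t* __restrict__ frag_4bit,  // in:  4 registers of 4-bit segments
    uint32_t* scales) {                // in:  4 registers = 8 FP16 scales
  Dequant_32FP6_4Way<3, 2>(Reg, nullptr, frag_2bit, frag_4bit, scales);
}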
/*
 * Copies the per-thread FP16 quantization scales from shared memory into
 * registers, distributing them across the warp with register shuffles.
 */
__device__ __forceinline__ void ExtractFromSharedToReg_Scales(
    uint32_t* Scales, half* WARP_SPTR_Scales) {
  int lane_id = threadIdx.x % WARP_SIZE;
  uint32_t* SPTR_uint = reinterpret_cast<uint32_t*>(WARP_SPTR_Scales);
  uint32_t tmpReg = SPTR_uint[lane_id];
  #pragma unroll
  for (int i = 0; i < 4; i++) {
    // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize);
    Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4);
  }
}
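// Usage note (illustrative): with width = 4, __shfl_sync broadcasts within
// independent groups of 4 lanes, so after the loop every lane of a quad holds
// the same 4 registers (8 half scales), namely the ones loaded by the 4 lanes
// of its own quad. WARP_SIZE is used but not defined in this file; the project
// presumably defines it (as 32) in a header included before this one.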
#endif