// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is modified from
// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_parallel_dequant.cuh
// To support MSVC, all instances of u_int32_t are changed to uint32_t.

#ifndef UTILS_PARALLELDEQUANT_CUH
#define UTILS_PARALLELDEQUANT_CUH

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

/*
 * Input: R1
 * Outputs: R1, R2
 * Note: Simplified Exponent calculation is applied.
*/ template __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t* In, uint32_t* Out1, uint32_t* Out2) { // constexpr int RIGHT_SHIFT = 5 - EXPONENT; constexpr int MASK1 = 0x80000000; constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; constexpr int MASK3 = MASK2 & 0x7fffffff; constexpr int MASK = MASK3 | MASK3 >> 16; // *Out1 = *In & 0x80008000; *Out1 |= ((*In) & MASK) >> RIGHT_SHIFT; // *In = (*In) << 8; *Out2 = *In & 0x80008000; *Out2 |= ((*In) & MASK) >> RIGHT_SHIFT; } template __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) { constexpr int BIAS_OFFSET = (int(1) << (5 - 1)) - (int(1) << (EXPONENT - 1)); constexpr int BIAS = int(1) << BIAS_OFFSET; // half* FP16_1 = reinterpret_cast(&PackedFP16Pair); half* FP16_2 = FP16_1 + 1; uint32_t output; half* output_half_ptr = reinterpret_cast(&output); output_half_ptr[0] = __hmul(__hmul(*FP16_1, __float2half(1.0f * BIAS)), Scale); output_half_ptr[1] = __hmul(__hmul(*FP16_2, __float2half(1.0f * BIAS)), Scale); return output; } // MODIFICATION NOTE: to support MSVC // - u_int32_t __restrict__ Reg[][4] is changed to below. // - u_int32_t __restrict__ *read_RPTR_1bit is changed to below. 
similarly for // read_RPTR_2bit and read_RPTR_4bit template __device__ __forceinline__ void Dequant_32FP6_4Way( uint32_t (*__restrict__ Reg)[4], uint32_t* __restrict__ read_RPTR_1bit, uint32_t* __restrict__ read_RPTR_2bit, uint32_t* __restrict__ read_RPTR_4bit, uint32_t* Scales) { // 1+2+4 weight split constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; // uint32_t* OutputRegs = reinterpret_cast(Reg); uint32_t* Frag_PTR_1bit = read_RPTR_1bit; uint32_t* Frag_PTR_2bit = read_RPTR_2bit; uint32_t* Frag_PTR_4bit = read_RPTR_4bit; half* Scale_RPTR = reinterpret_cast(Scales); // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 #pragma unroll(8) for (int i = 0; i < 8; i++) { uint32_t Packed_FP6 = 0; uint32_t tmp = 0; // 1bit Frag if (USE_SEG_1BIT) { tmp = (*Frag_PTR_1bit) & 0x80808080; Packed_FP6 |= tmp >> (BIT_WIDTH & 0); if (i % 8 == 7) Frag_PTR_1bit++; else (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1; } // 2bit Frag if (USE_SEG_2BIT) { tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0; Packed_FP6 |= tmp >> (BIT_WIDTH & 1); if (i % 4 == 3) Frag_PTR_2bit++; else (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2; } // 4bit Frag2 if (USE_SEG_4BIT) { tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0; Packed_FP6 |= tmp >> (BIT_WIDTH & 3); if (i % 2 == 1) Frag_PTR_4bit++; else (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4; } // uint32_t out1, out2; FPx_FP16_Cast_4Way(&Packed_FP6, &out1, &out2); // *OutputRegs = MultScale( out1, Scale_RPTR[0]); // Multiply FP16 scales OutputRegs += 1; *OutputRegs = MultScale( out2, Scale_RPTR[1]); // Multiply FP16 scales OutputRegs += 1; // Updating offset for FP16 scales for every two iterations if (i % 2 == 1) Scale_RPTR += 2; } } /* * */ __device__ __forceinline__ void ExtractFromSharedToReg_Scales( uint32_t* Scales, half* WARP_SPTR_Scales) { int lane_id = threadIdx.x % WARP_SIZE; uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); uint32_t tmpReg = 
SPTR_uint[lane_id]; #pragma unroll for (int i = 0; i < 4; i++) { // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); } } #endif