// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is modified from
// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_core.cuh

#ifndef UTILS_CORE_CUH
#define UTILS_CORE_CUH

#include <assert.h>

#include "configs.h"
#include "ptx_mma.cuh"
#include "utils_parallel_dequant.cuh"

template <int NUM_INT_PER_THREAD>
__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[],
                                                               uint32_t* SPTR,
                                                               int slice_id)
{
    SPTR += slice_id * (NUM_INT_PER_THREAD * WARP_SIZE);
    int lane_id = threadIdx.x % WARP_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_INT_PER_THREAD; i++) {
        Reg[i] = SPTR[lane_id + i * WARP_SIZE];
    }
}

// MODIFICATION NOTE: to support MSVC, half __restrict__
// (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
template <typename TilingConfig, int EXPONENT, int MANTISSA>
__device__ __forceinline__ void initialize_mma_slice(
    uint32_t (*a)[4],
    uint32_t (*b)[4],
    uint32_t* __restrict__ A_1BIT_SPTR_read,
    uint32_t* __restrict__ A_2BIT_SPTR_read,
    uint32_t* __restrict__ A_4BIT_SPTR_read,
    half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
    uint32_t* RPTR_Scales)
{
    // 1+2+4 weight split
    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;

    // Writing registers
    // Registers to store FP6 fragments for a slice (64*16) of the A matrix => 32 FP6
    // per thread => 6 registers per thread;
    uint32_t a_1bit[1];  // NO double buffer
    uint32_t a_2bit[2];  // NO double buffer
    uint32_t a_4bit[4];  // NO double buffer
    if (USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1BIT_SPTR_read, 0);
    if (USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2BIT_SPTR_read, 0);
    if (USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4BIT_SPTR_read, 0);
    Dequant_32FP6_4Way<EXPONENT, MANTISSA>(
        a, a_1bit, a_2bit, a_4bit, RPTR_Scales);  // SIMT Dequant: dequantizing FPx to FP16 at
                                                  // register level, dequantizing a slice each time
    B_FromSharedToReg<TilingConfig>(b, B_SPTR_read, 0);  // Loading B from shared to registers
}
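// Illustrative example (added for clarity, not part of the original FP6-LLM code):
// the 1+2+4 weight split used by initialize_mma_slice() and core_mma_slice() derives
// its segment flags from the total bit width of the quantized format. For FP6
// (1 sign + 3 exponent + 2 mantissa bits) the 4-bit and 2-bit segments are used and
// the 1-bit segment is skipped; a 5-bit format would use the 4-bit and 1-bit segments
// instead. The names below are local to this example; the static_asserts are
// compile-time only and generate no code.
namespace fpx_split_example {
constexpr int EXAMPLE_EXPONENT  = 3;  // FP6: e3m2
constexpr int EXAMPLE_MANTISSA  = 2;
constexpr int EXAMPLE_BIT_WIDTH = 1 + EXAMPLE_EXPONENT + EXAMPLE_MANTISSA;  // 6 == 0b110
static_assert((EXAMPLE_BIT_WIDTH & 1) == 0, "FP6 does not use the 1-bit segment");
static_assert((EXAMPLE_BIT_WIDTH & 2) != 0, "FP6 uses the 2-bit segment");
static_assert((EXAMPLE_BIT_WIDTH & 4) != 0, "FP6 uses the 4-bit segment");
}  // namespace fpx_split_example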
// MODIFICATION NOTE: to support MSVC, half __restrict__
// (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
template <typename TilingConfig, int EXPONENT, int MANTISSA>
__device__ __forceinline__ void core_mma_slice(
    float c[][REG_PER_THREAD_C_TENSOR_16_16],
    uint32_t (*a)[4],
    uint32_t (*b)[4],
    uint32_t* __restrict__ A_1bit_SPTR_read,
    uint32_t* __restrict__ A_2bit_SPTR_read,
    uint32_t* __restrict__ A_4bit_SPTR_read,
    half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
    uint32_t* RPTR_Scales,
    int slice_id)  // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching
{
    // 1+2+4 weight split
    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;

#ifdef DEBUG_MODE
    assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) ||
           (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0));  // if WARP_COL_MMA_TENSORS == 1, the B
                                                            // tile in registers is padded to a
                                                            // 16*16 MMA block
#endif
    const int NumRegSets_a = WARP_ROW_MMA_TENSORS;  // 1 set = 4 registers, containing a 16*16 MMA block
    const int NumRegSets_b =
        (TilingConfig::WARP_COL_MMA_TENSORS == 1)
            ? 1
            : TilingConfig::WARP_COL_MMA_TENSORS / 2;  // 1 set = 4 registers, containing a 16*16 MMA block
    uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] =
        reinterpret_cast<uint32_t (*)[REG_PER_THREAD_C_TENSOR_16_16]>(
            c);  // Registers for accumulated FP32 results

    // Setting RPTRs for double buffers
    uint32_t (*a_read)[4]  = a;
    uint32_t (*a_write)[4] = a;
    uint32_t (*b_read)[4]  = b;
    uint32_t (*b_write)[4] = b;
    if (slice_id % 2 == 1) {
        b_write += NumRegSets_b;
        a_write += NumRegSets_a;
    } else {
        b_read += NumRegSets_b;
        a_read += NumRegSets_a;
    }

    // Reading registers and issuing core tensor core computations
    // (a slice of the A and B tiles in shared memory)
#pragma unroll
    for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
        if (TilingConfig::WARP_COL_MMA_TENSORS == 1) {
            MMA_FP16_M16N8K16(c_uint_ptr[i], a_read[i], b_read[0]);
        } else {
#pragma unroll
            for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) {
                MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j]);
                MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4,
                                  a_read[i],
                                  b_read[j] + 2);  // c+4; b+2
            }
        }
    }

    // Writing registers
    // Registers to store FP6 fragments for a slice (64*16) of the A matrix => 32 FP6
    // per thread => 6 registers per thread;
    uint32_t a_1bit[1];  // NO double buffer
    uint32_t a_2bit[2];  // NO double buffer
    uint32_t a_4bit[4];  // NO double buffer
    if (USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1bit_SPTR_read, slice_id);
    if (USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2bit_SPTR_read, slice_id);
    if (USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4bit_SPTR_read, slice_id);
    Dequant_32FP6_4Way<EXPONENT, MANTISSA>(
        a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales);  // SIMT Dequant: dequantizing FP6 to FP16 at
                                                        // register level, dequantizing a slice each time
    B_FromSharedToReg<TilingConfig>(b_write, B_SPTR_read, slice_id);  // Loading B from shared to registers
}
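// Note (added for clarity, based on the PTX mma.m16n8k16 accumulator layout): each
// thread of a warp owns 4 float accumulators per 16x8 C tile. For lane `lane_id`,
// registers 0/1 map to row (lane_id / 4) and columns (lane_id % 4) * 2 and
// (lane_id % 4) * 2 + 1, while registers 2/3 map to the same columns of row
// (lane_id / 4 + 8). StoreToSharedMemoryFromRegister() below walks exactly this
// mapping and writes C transposed, i.e. indexed as smem_CFrag[column][row].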
template <typename TilingConfig>
__device__ __forceinline__ void StoreToSharedMemoryFromRegister(
    float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4],
    float c[][REG_PER_THREAD_C_TENSOR_16_16])
{
    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warpId  = threadIdx.x / WARP_SIZE;
    int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS);
#pragma unroll
    for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
#pragma unroll
        for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; j++) {  // Dealing with one 16*8 Tensor
            int RegSetID          = i + (j / 2) * WARP_ROW_MMA_TENSORS;
            int RegOffset         = (j % 2) * (REG_PER_THREAD_C_TENSOR_16_16 / 2);
            int Tensor_row_offset = warp_row_offset + i * MMA_16;
            int Tensor_col_offset = j * MMA_8;
#pragma unroll
            for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16 / 2; r++) {
                int row_offset = lane_id / 4;
                if (r >= 2) row_offset += 8;
                int col_offset = (lane_id % 4) * 2;
                if (r % 2 == 1) col_offset += 1;
                smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] =
                    c[RegSetID][r + RegOffset];
            }
        }
    }
}

#endif
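// Usage sketch (an illustration added here, not code from the original kernel; the
// loop structure, buffer sizes, and names are assumptions about how these helpers
// fit together in the caller):
//
//   // Per-warp registers: double-buffered A/B fragments plus FP32 accumulators.
//   // initialize_mma_slice() dequantizes slice 0 into the first buffer; each
//   // core_mma_slice() call then issues tensor-core MMAs on one buffer while
//   // dequantizing A and loading B for the next slice into the other buffer
//   // (slice_id selects the buffer via slice_id % 2).
//   uint32_t a[2 * NumRegSets_a][4];
//   uint32_t b[2 * NumRegSets_b][4];
//   float    c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16] = {};
//   initialize_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
//       a, b, A_1bit_SPTR, A_2bit_SPTR, A_4bit_SPTR, B_SPTR, Scales_RPTR);
//   for (int k = 0; k < NUM_SLICES; k++)
//       core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
//           c, a, b, A_1bit_SPTR, A_2bit_SPTR, A_4bit_SPTR, B_SPTR, Scales_RPTR,
//           (k + 1) % NUM_SLICES);  // k=0 writes slice_id=1, prefetching ahead
//   StoreToSharedMemoryFromRegister<TilingConfig>(smem_CFrag, c);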