// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is modified from
// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_core.cuh
#ifndef UTILS_CORE_CUH
#define UTILS_CORE_CUH

#include <assert.h>

#include "configs.h"
#include "ptx_mma.cuh"
#include "utils_parallel_dequant.cuh"
template <int NUM_INT_PER_THREAD>
__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[],
                                                               uint32_t* SPTR,
                                                               int slice_id) {
  SPTR += slice_id * (NUM_INT_PER_THREAD * WARP_SIZE);
  int lane_id = threadIdx.x % WARP_SIZE;
#pragma unroll
  for (int i = 0; i < NUM_INT_PER_THREAD; i++) {
    Reg[i] = SPTR[lane_id + i * WARP_SIZE];
  }
}

// MODIFICATION NOTE: to support MSVC, half __restrict__
// (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] is changed to the
// form below.
template <typename TilingConfig, int EXPONENT, int MANTISSA>
__device__ __forceinline__ void initialize_mma_slice(
    uint32_t (*a)[4], uint32_t (*b)[4],
    uint32_t* __restrict__ A_1BIT_SPTR_read,
    uint32_t* __restrict__ A_2BIT_SPTR_read,
    uint32_t* __restrict__ A_4BIT_SPTR_read,
    half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
    uint32_t* RPTR_Scales) {
  // 1+2+4 weight split
  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
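  // Worked example of the segment flags above: FP6 (E3M2) gives
  // BIT_WIDTH = 1 + 3 + 2 = 6 = 0b110, so the 2-bit and 4-bit segments are
  // active; FP5 (E2M2) gives 5 = 0b101 (1-bit + 4-bit); FP4 (E2M1) gives
  // 4 = 0b100 (4-bit only). Each weight is reassembled from at most three
  // bit-planes, each stored in its own shared-memory region.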
  // Writing registers
  // Registers to store FPx fragments for a slice (64*16) of the A matrix:
  // 32 FPx values per thread => BIT_WIDTH registers per thread (6 for FP6).
  uint32_t a_1bit[1];  // NO double buffer
  uint32_t a_2bit[2];  // NO double buffer
  uint32_t a_4bit[4];  // NO double buffer
  if (USE_SEG_1BIT)
    CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1BIT_SPTR_read, 0);
  if (USE_SEG_2BIT)
    CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2BIT_SPTR_read, 0);
  if (USE_SEG_4BIT)
    CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4BIT_SPTR_read, 0);
  Dequant_32FP6_4Way<EXPONENT, MANTISSA>(
      a, a_1bit, a_2bit, a_4bit,
      RPTR_Scales);  // SIMT dequant: dequantizing FPx to FP16 at the register
                     // level, one slice at a time
  B_FromSharedToReg<TilingConfig>(b, B_SPTR_read,
                                  0);  // Loading B from shared to registers
}
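
// How initialize_mma_slice and core_mma_slice are meant to compose: an
// illustrative sketch only, assuming four K-slices per shared-memory tile as
// in the original FP6-LLM kernel (argument names follow the parameters in
// this file, not an exact caller):
//
//   initialize_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
//       a, b, A_1bit_SPTR, A_2bit_SPTR, A_4bit_SPTR, B_SPTR, RPTR_Scales);
//   #pragma unroll
//   for (int k = 0; k < 4; k++)  // compute slice k, prefetch slice (k+1)%4
//     core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
//         c, a, b, A_1bit_SPTR, A_2bit_SPTR, A_4bit_SPTR, B_SPTR,
//         RPTR_Scales, (k + 1) % 4);
//
// initialize_mma_slice fills one half of the register double buffer so that
// the first core_mma_slice call can compute immediately while prefetching
// the next slice into the other half.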

// MODIFICATION NOTE: to support MSVC, half __restrict__
// (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] is changed to the
// form below.
template <typename TilingConfig, int EXPONENT, int MANTISSA>
__device__ __forceinline__ void core_mma_slice(
    float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4],
    uint32_t (*b)[4], uint32_t* __restrict__ A_1bit_SPTR_read,
    uint32_t* __restrict__ A_2bit_SPTR_read,
    uint32_t* __restrict__ A_4bit_SPTR_read,
    half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
    uint32_t* RPTR_Scales,
    int slice_id)  // writes slice[slice_id] to registers; at k=0 the caller
                   // passes slice_id=1 for prefetching
{
  // 1+2+4 weight split
  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
#ifdef DEBUG_MODE
  assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) ||
         (TilingConfig::WARP_COL_MMA_TENSORS % 2 ==
          0));  // if WARP_COL_MMA_TENSORS == 1, the B tile in registers is
                // padded to a 16*16 MMA block
#endif
  const int NumRegSets_a =
      WARP_ROW_MMA_TENSORS;  // 1 set = 4 registers, containing a 16*16 MMA
                             // block
  const int NumRegSets_b =
      (TilingConfig::WARP_COL_MMA_TENSORS == 1)
          ? 1
          : TilingConfig::WARP_COL_MMA_TENSORS /
                2;  // 1 set = 4 registers, containing a 16*16 MMA block
  uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] =
      reinterpret_cast<uint32_t(*)[REG_PER_THREAD_C_TENSOR_16_16]>(
          c);  // global registers holding the accumulated FP32 results
  // Setting register pointers for the double buffers
  uint32_t(*a_read)[4] = a;
  uint32_t(*a_write)[4] = a;
  uint32_t(*b_read)[4] = b;
  uint32_t(*b_write)[4] = b;
  if (slice_id % 2 == 1) {
    b_write += NumRegSets_b;
    a_write += NumRegSets_a;
  } else {
    b_read += NumRegSets_b;
    a_read += NumRegSets_a;
  }
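  // Double-buffering at work: `a` and `b` each hold two buffers of
  // NumRegSets_a / NumRegSets_b register sets. For odd slice_id the MMAs
  // below read the lower halves while the dequant/load at the end of this
  // function writes the upper halves; for even slice_id the roles swap,
  // overlapping computation on the current slice with prefetching of the
  // next one.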
  // Reading registers and issuing tensor core computations for the current
  // slice of the A and B tiles (the full tiles reside in shared memory)
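  // Note: each B register set spans a 16*16 tile, i.e. two n=8 halves of an
  // m16n8k16 MMA, so two MMA instructions are issued per (i, j) pair; the
  // second call advances the accumulator by 4 registers and B by 2 (the
  // `c+4; b+2` offsets below).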
#pragma unroll
  for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
    if (TilingConfig::WARP_COL_MMA_TENSORS == 1) {
      MMA_FP16_M16N8K16(c_uint_ptr[i], a_read[i], b_read[0]);
    } else {
#pragma unroll
      for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) {
        MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i],
                          b_read[j]);
        MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4,
                          a_read[i], b_read[j] + 2);  // c+4; b+2
      }
    }
  }
  // Writing registers
  // Registers to store FPx fragments for a slice (64*16) of the A matrix:
  // 32 FPx values per thread => BIT_WIDTH registers per thread (6 for FP6).
  uint32_t a_1bit[1];  // NO double buffer
  uint32_t a_2bit[2];  // NO double buffer
  uint32_t a_4bit[4];  // NO double buffer
  if (USE_SEG_1BIT)
    CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1bit_SPTR_read, slice_id);
  if (USE_SEG_2BIT)
    CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2bit_SPTR_read, slice_id);
  if (USE_SEG_4BIT)
    CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4bit_SPTR_read, slice_id);
  Dequant_32FP6_4Way<EXPONENT, MANTISSA>(
      a_write, a_1bit, a_2bit, a_4bit,
      RPTR_Scales);  // SIMT dequant: dequantizing FPx to FP16 at the register
                     // level, one slice at a time
  B_FromSharedToReg<TilingConfig>(
      b_write, B_SPTR_read, slice_id);  // Loading B from shared to registers
}

template <typename TilingConfig>
__device__ __forceinline__ void StoreToSharedMemoryFromRegister(
    float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4],
    float c[][REG_PER_THREAD_C_TENSOR_16_16]) {
  const int lane_id = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;
  int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS);
#pragma unroll
  for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
#pragma unroll
    for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS;
         j++) {  // Dealing with one 16*8 tensor tile
      int RegSetID = i + (j / 2) * WARP_ROW_MMA_TENSORS;
      int RegOffset = (j % 2) * (REG_PER_THREAD_C_TENSOR_16_16 / 2);
      int Tensor_row_offset = warp_row_offset + i * MMA_16;
      int Tensor_col_offset = j * MMA_8;
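      // Per-lane layout of the m16n8 FP32 accumulator fragment (as defined
      // by the PTX mma instruction): lane_id / 4 selects the row (plus 8 for
      // the upper half, r >= 2) and (lane_id % 4) * 2 selects the column
      // pair. E.g. lane 5 holds c[0] at (row 1, col 2), c[1] at (1, 3),
      // c[2] at (9, 2) and c[3] at (9, 3), matching the indices below.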
#pragma unroll
      for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16 / 2; r++) {
        int row_offset = lane_id / 4;
        if (r >= 2) row_offset += 8;
        int col_offset = (lane_id % 4) * 2;
        if (r % 2 == 1) col_offset += 1;
        smem_CFrag[Tensor_col_offset + col_offset]
                  [Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset];
      }
    }
  }
}

#endif