// Copyright 2024 FP6-LLM authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // This file is modified from // https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh #ifndef UTILS_GMEM_CUH #define UTILS_GMEM_CUH #include #include "configs.h" #include "ptx_cp.async.cuh" /* * Copying A1/A2 from global memory to shared memory. * Usually 1024 or 2048 Bytes */ template __device__ __forceinline__ void CopyFromGlobalToShared_A( uint32_t* SPTR, const uint4* GPTR, bool pred_guard = true) { #ifdef DEBUG_MODE static_assert(SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE % 16 == 0); #endif int lane_id = threadIdx.x % WARP_SIZE; half* SPTR_HALF = reinterpret_cast(SPTR); const half* GPTR_HALF = reinterpret_cast(GPTR); SPTR_HALF += lane_id * 8; GPTR_HALF += lane_id * 8; #pragma unroll for (int i = 0; i < SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE / 16; i++) { cp_async<16>(SPTR_HALF, GPTR_HALF, pred_guard); SPTR_HALF += 256; // Forward 512 Bytes GPTR_HALF += 256; // Forward 512 Bytes } } /* * Copying 64 Quant Scales (FP16) from global memory to shared memory. */ __device__ __forceinline__ void CopyFromGlobalToShared_Scales( half* SPTR_QuantScales, const half* GPTR_A_Scales) { int lane_id = threadIdx.x % WARP_SIZE; int Offset_Shared = lane_id * 2; int Offset_Global = lane_id / 4 + (lane_id % 4) * 16; for (int i = 0; i < 2; i++) SPTR_QuantScales[Offset_Shared + i] = GPTR_A_Scales[Offset_Global + i * 8]; } // MODIFICATION NOTE: to support MSVC, half __restrict__ // (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. /* * (1) Copying X rows * 64 columns of FP16 values, originally in row major * (2) Copying 64 rows * X columns of FP16 values, originally in column major * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 * Threads */ template __device__ __forceinline__ void CopyFromGlobalToShared( half (*__restrict__ SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], const half* GlobalPTR, const int GlobalStride, const int NumOfLinesLeft, // To support arbitrary N dimensions. bool Pred = true) { // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; const int NumOfGroups = NumOfThreads / 8; const int MaxIteration = (MaxNumOfLinesToCopy - 1) / NumOfGroups + 1; // runtime variables const int line_id = threadIdx.x / 8; const int line_offset = (threadIdx.x % 8) * 8; // PTR for source global memory and target shared memory GlobalPTR += line_id * GlobalStride + line_offset; SharedPTR += line_id; #pragma unroll for (int i = 0; i < MaxIteration; i++) { bool AsyncCopyPred = (line_id + i * NumOfGroups) < NumOfLinesLeft && Pred; cp_async<16>(&(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); // GlobalPTR += NumOfGroups * GlobalStride; SharedPTR += NumOfGroups; } } #endif