12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- // Copyright 2024 FP6-LLM authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- //
- // This file is modified from
- // https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
- #ifndef UTILS_GMEM_CUH
- #define UTILS_GMEM_CUH
- #include <assert.h>
- #include "configs.h"
- #include "ptx_cp.async.cuh"
- /*
- * Copying A1/A2 from global memory to shared memory.
- * Usually 1024 or 2048 Bytes
- */
- template <int SMEM_SIZE_IN_BYTES_PER_WARP>
- __device__ __forceinline__ void CopyFromGlobalToShared_A(
- uint32_t* SPTR, const uint4* GPTR, bool pred_guard = true) {
- #ifdef DEBUG_MODE
- static_assert(SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE % 16 == 0);
- #endif
- int lane_id = threadIdx.x % WARP_SIZE;
- half* SPTR_HALF = reinterpret_cast<half*>(SPTR);
- const half* GPTR_HALF = reinterpret_cast<const half*>(GPTR);
- SPTR_HALF += lane_id * 8;
- GPTR_HALF += lane_id * 8;
- #pragma unroll
- for (int i = 0; i < SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE / 16; i++) {
- cp_async<16>(SPTR_HALF, GPTR_HALF, pred_guard);
- SPTR_HALF += 256; // Forward 512 Bytes
- GPTR_HALF += 256; // Forward 512 Bytes
- }
- }
- /*
- * Copying 64 Quant Scales (FP16) from global memory to shared memory.
- */
- __device__ __forceinline__ void CopyFromGlobalToShared_Scales(
- half* SPTR_QuantScales, const half* GPTR_A_Scales) {
- int lane_id = threadIdx.x % WARP_SIZE;
- int Offset_Shared = lane_id * 2;
- int Offset_Global = lane_id / 4 + (lane_id % 4) * 16;
- for (int i = 0; i < 2; i++)
- SPTR_QuantScales[Offset_Shared + i] = GPTR_A_Scales[Offset_Global + i * 8];
- }
- // MODIFICATION NOTE: to support MSVC, half __restrict__
- // (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
- /*
- * (1) Copying X rows * 64 columns of FP16 values, originally in row major
- * (2) Copying 64 rows * X columns of FP16 values, originally in column major
- * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8
- * Threads
- */
- template <int MaxNumOfLinesToCopy, int BLOCK_WARPS>
- __device__ __forceinline__ void CopyFromGlobalToShared(
- half (*__restrict__ SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
- const half* GlobalPTR, const int GlobalStride,
- const int NumOfLinesLeft, // To support arbitrary N dimensions.
- bool Pred = true) {
- // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time
- const int NumOfThreads = BLOCK_WARPS * WARP_SIZE;
- const int NumOfGroups = NumOfThreads / 8;
- const int MaxIteration = (MaxNumOfLinesToCopy - 1) / NumOfGroups + 1;
- // runtime variables
- const int line_id = threadIdx.x / 8;
- const int line_offset = (threadIdx.x % 8) * 8;
- // PTR for source global memory and target shared memory
- GlobalPTR += line_id * GlobalStride + line_offset;
- SharedPTR += line_id;
- #pragma unroll
- for (int i = 0; i < MaxIteration; i++) {
- bool AsyncCopyPred = (line_id + i * NumOfGroups) < NumOfLinesLeft && Pred;
- cp_async<16>(&(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred);
- //
- GlobalPTR += NumOfGroups * GlobalStride;
- SharedPTR += NumOfGroups;
- }
- }
- #endif
|