utils_gmem.cuh 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. // Copyright 2024 FP6-LLM authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // This file is modified from
  16. // https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
  17. #ifndef UTILS_GMEM_CUH
  18. #define UTILS_GMEM_CUH
  19. #include <assert.h>
  20. #include "configs.h"
  21. #include "ptx_cp.async.cuh"
  /*
   * Copy the calling warp's A1/A2 segment (usually 1024 or 2048 bytes) from
   * global memory to shared memory using 16-byte asynchronous copies.
   *
   * Template:
   *   SMEM_SIZE_IN_BYTES_PER_WARP - bytes copied by the calling warp. Must be a
   *     multiple of WARP_SIZE * 16 (one warp-wide chunk) so the loop below
   *     covers every byte; enforced at compile time.
   * Params:
   *   SPTR       - shared-memory destination of this warp's segment.
   *   GPTR       - global-memory source of this warp's segment.
   *   pred_guard - forwarded unchanged to every cp_async call; presumably
   *                false suppresses the copies (see ptx_cp.async.cuh).
   *
   * Layout: in iteration i, lane l copies bytes
   * [i*WARP_SIZE*16 + l*16, i*WARP_SIZE*16 + l*16 + 16), so each iteration moves
   * one contiguous WARP_SIZE*16-byte chunk per warp.
   * NOTE(review): the copies are issued asynchronously; the caller presumably
   * commits/waits on them before reading SPTR — confirm at the call site.
   */
  template <int SMEM_SIZE_IN_BYTES_PER_WARP>
  __device__ __forceinline__ void CopyFromGlobalToShared_A(
      uint32_t* SPTR, const uint4* GPTR, bool pred_guard = true) {
    // Compile-time only, so it is enforced in every build rather than just
    // DEBUG_MODE. A size that is not a whole number of warp-wide chunks would
    // otherwise leave the tail of the segment silently uncopied.
    static_assert(SMEM_SIZE_IN_BYTES_PER_WARP % (WARP_SIZE * 16) == 0,
                  "SMEM_SIZE_IN_BYTES_PER_WARP must be a multiple of "
                  "WARP_SIZE * 16 bytes");
    const int lane_id = threadIdx.x % WARP_SIZE;
    half* SPTR_HALF = reinterpret_cast<half*>(SPTR);
    const half* GPTR_HALF = reinterpret_cast<const half*>(GPTR);
    // Each lane owns 16 bytes (8 halves) of every warp-wide chunk.
    SPTR_HALF += lane_id * 8;
    GPTR_HALF += lane_id * 8;
  #pragma unroll
    for (int i = 0; i < SMEM_SIZE_IN_BYTES_PER_WARP / (WARP_SIZE * 16); i++) {
      cp_async<16>(SPTR_HALF, GPTR_HALF, pred_guard);
      // Advance by one warp-wide chunk (WARP_SIZE * 8 halves = 512 bytes for a
      // 32-lane warp), kept in sync with the trip count above.
      SPTR_HALF += WARP_SIZE * 8;
      GPTR_HALF += WARP_SIZE * 8;
    }
  }
  /*
   * Copy 64 FP16 quantization scales from global memory into shared memory,
   * permuting them on the way (synchronous element-wise loads/stores).
   *
   * Lane L (L = threadIdx.x % WARP_SIZE) performs:
   *   shared[2L]     = global[L/4 + (L%4)*16]
   *   shared[2L + 1] = global[L/4 + (L%4)*16 + 8]
   * With a 32-lane warp this writes shared slots 0..63 and reads every global
   * index 0..63 exactly once. Indices depend only on the lane id, so any warp
   * that calls this writes the same 64 shared slots.
   *
   * Params:
   *   SPTR_QuantScales - shared-memory destination (at least 64 halves).
   *   GPTR_A_Scales    - global-memory source (at least 64 halves).
   */
  __device__ __forceinline__ void CopyFromGlobalToShared_Scales(
      half* SPTR_QuantScales, const half* GPTR_A_Scales) {
    const int lane = threadIdx.x % WARP_SIZE;
    const int src = lane / 4 + (lane % 4) * 16;
    half* dst = SPTR_QuantScales + lane * 2;
    dst[0] = GPTR_A_Scales[src];
    dst[1] = GPTR_A_Scales[src + 8];
  }
  // MODIFICATION NOTE: to support MSVC, half __restrict__
  // (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
  /*
   * Block-cooperative asynchronous copy of up to MaxNumOfLinesToCopy "lines" of
   * 64 FP16 values each from global memory into consecutive shared-memory rows.
   * Use cases:
   *   (1) X rows * 64 columns of FP16 values, originally in row major;
   *   (2) 64 rows * X columns of FP16 values, originally in column major.
   * 16 bytes per thread -> 512 bytes per warp = 4 lines per warp
   * = 1 line per group of 8 threads.
   *
   * Template:
   *   MaxNumOfLinesToCopy - capacity of the shared tile in lines; never
   *                         exceeded, whatever NumOfLinesLeft says.
   *   BLOCK_WARPS         - warps cooperating in the copy. NOTE(review): the
   *                         code assumes a 1-D block with blockDim.x ==
   *                         BLOCK_WARPS * WARP_SIZE; this is not checked here.
   * Params:
   *   SharedPTR      - first destination row; only the first 64 halves of each
   *                    row are written.
   *   GlobalPTR      - first source line.
   *   GlobalStride   - distance between consecutive source lines, in halves.
   *   NumOfLinesLeft - lines still valid in global memory (supports arbitrary N);
   *                    lines at or beyond min(NumOfLinesLeft, MaxNumOfLinesToCopy)
   *                    are not copied.
   *   Pred           - extra predicate ANDed into every copy.
   * Rows whose copy is predicated off are left untouched, presumably (depends
   * on cp_async in ptx_cp.async.cuh).
   */
  template <int MaxNumOfLinesToCopy, int BLOCK_WARPS>
  __device__ __forceinline__ void CopyFromGlobalToShared(
      half (*__restrict__ SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
      const half* GlobalPTR, const int GlobalStride,
      const int NumOfLinesLeft,  // To support arbitrary N dimensions.
      bool Pred = true) {
    // static parameters: 1 Group (8 Threads) copies 1 line (64 FP16) per iteration
    const int NumOfThreads = BLOCK_WARPS * WARP_SIZE;
    const int NumOfGroups = NumOfThreads / 8;
    const int MaxIteration = (MaxNumOfLinesToCopy - 1) / NumOfGroups + 1;  // ceil-div
    // MaxIteration * NumOfGroups can exceed MaxNumOfLinesToCopy. Clamp so that
    // the trailing groups never write rows past the end of the shared tile,
    // even if the caller reports more lines left than the tile can hold.
    const int NumOfLinesToCopy = NumOfLinesLeft < MaxNumOfLinesToCopy
                                     ? NumOfLinesLeft
                                     : MaxNumOfLinesToCopy;
    // runtime variables: which line this thread's group handles, and which
    // 8-half (16-byte) slice of that line this thread copies
    const int line_id = threadIdx.x / 8;
    const int line_offset = (threadIdx.x % 8) * 8;
    // PTR for source global memory and target shared memory
    GlobalPTR += line_id * GlobalStride + line_offset;
    SharedPTR += line_id;
  #pragma unroll
    for (int i = 0; i < MaxIteration; i++) {
      const bool AsyncCopyPred =
          (line_id + i * NumOfGroups) < NumOfLinesToCopy && Pred;
      cp_async<16>(&(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred);
      // Every group advances to its next line, NumOfGroups lines further on.
      GlobalPTR += NumOfGroups * GlobalStride;
      SharedPTR += NumOfGroups;
    }
  }
  88. #endif