// utils_core.cuh
  1. // Copyright 2024 FP6-LLM authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // This file is modified from
  16. // https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_core.cuh
  17. #ifndef UTILS_CORE_CUH
  18. #define UTILS_CORE_CUH
  19. #include <assert.h>
  20. #include "configs.h"
  21. #include "ptx_mma.cuh"
  22. #include "utils_parallel_dequant.cuh"
  23. template <int NUM_INT_PER_THREAD>
  24. __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[],
  25. uint32_t* SPTR,
  26. int slice_id) {
  27. SPTR += slice_id * (NUM_INT_PER_THREAD * WARP_SIZE);
  28. int lane_id = threadIdx.x % WARP_SIZE;
  29. #pragma unroll
  30. for (int i = 0; i < NUM_INT_PER_THREAD; i++) {
  31. Reg[i] = SPTR[lane_id + i * WARP_SIZE];
  32. }
  33. }
  34. // MODIFICATION NOTE: to support MSVC, half __restrict__
  35. // (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
  36. template <typename TilingConfig, int EXPONENT, int MANTISSA>
  37. __device__ __forceinline__ void initialize_mma_slice(
  38. uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1BIT_SPTR_read,
  39. uint32_t* __restrict__ A_2BIT_SPTR_read,
  40. uint32_t* __restrict__ A_4BIT_SPTR_read,
  41. half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
  42. uint32_t* RPTR_Scales) {
  43. // 1+2+4 weight split
  44. constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
  45. constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
  46. constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
  47. constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
  48. // Writing registers
  49. // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6
  50. // per thread => 6 register per thread;
  51. uint32_t a_1bit[1]; // NO double buffer
  52. uint32_t a_2bit[2]; // NO double buffer
  53. uint32_t a_4bit[4]; // NO double buffer
  54. if (USE_SEG_1BIT)
  55. CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1BIT_SPTR_read, 0);
  56. if (USE_SEG_2BIT)
  57. CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2BIT_SPTR_read, 0);
  58. if (USE_SEG_4BIT)
  59. CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4BIT_SPTR_read, 0);
  60. Dequant_32FP6_4Way<EXPONENT, MANTISSA>(
  61. a, a_1bit, a_2bit, a_4bit,
  62. RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register
  63. // level, dequantizing a slice each time
  64. B_FromSharedToReg<TilingConfig>(b, B_SPTR_read,
  65. 0); // Loading B from shared to registers
  66. }
  67. // MODIFICATION NOTE: to support MSVC, half __restrict__
  68. // (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
  69. template <typename TilingConfig, int EXPONENT, int MANTISSA>
  70. __device__ __forceinline__ void core_mma_slice(
  71. float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4],
  72. uint32_t (*b)[4], uint32_t* __restrict__ A_1bit_SPTR_read,
  73. uint32_t* __restrict__ A_2bit_SPTR_read,
  74. uint32_t* __restrict__ A_4bit_SPTR_read,
  75. half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
  76. uint32_t* RPTR_Scales,
  77. int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1
  78. // for prefetching
  79. {
  80. // 1+2+4 weight split
  81. constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
  82. constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
  83. constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
  84. constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
  85. #ifdef DEBUG_MODE
  86. assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) ||
  87. (TilingConfig::WARP_COL_MMA_TENSORS % 2 ==
  88. 0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded
  89. // to a 16*16 MMA block
  90. #endif
  91. const int NumRegSets_a =
  92. WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA
  93. // block
  94. const int NumRegSets_b =
  95. (TilingConfig::WARP_COL_MMA_TENSORS == 1)
  96. ? 1
  97. : TilingConfig::WARP_COL_MMA_TENSORS /
  98. 2; // 1 set = 4 registers, containing a 16*16 MMA block
  99. uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] =
  100. reinterpret_cast<uint32_t(*)[REG_PER_THREAD_C_TENSOR_16_16]>(
  101. c); // GlobalRegisters for accumulated FP32 results
  102. // Setting RPTRs for double buffers
  103. uint32_t(*a_read)[4] = a;
  104. uint32_t(*a_write)[4] = a;
  105. uint32_t(*b_read)[4] = b;
  106. uint32_t(*b_write)[4] = b;
  107. if (slice_id % 2 == 1) {
  108. b_write += NumRegSets_b;
  109. a_write += NumRegSets_a;
  110. } else {
  111. b_read += NumRegSets_b;
  112. a_read += NumRegSets_a;
  113. }
  114. // Reading registers and issuing core tensor core computations (a slice of A and
  115. // B tile in shared memory)
  116. #pragma unroll
  117. for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
  118. if (TilingConfig::WARP_COL_MMA_TENSORS == 1) {
  119. MMA_FP16_M16N8K16(c_uint_ptr[i], a_read[i], b_read[0]);
  120. } else {
  121. #pragma unroll
  122. for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) {
  123. MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i],
  124. b_read[j]);
  125. MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4,
  126. a_read[i], b_read[j] + 2); // c+4; b+2
  127. }
  128. }
  129. }
  130. // Writing registers
  131. // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6
  132. // per thread => 6 register per thread;
  133. uint32_t a_1bit[1]; // NO double buffer
  134. uint32_t a_2bit[2]; // NO double buffer
  135. uint32_t a_4bit[4]; // NO double buffer
  136. if (USE_SEG_1BIT)
  137. CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1bit_SPTR_read, slice_id);
  138. if (USE_SEG_2BIT)
  139. CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2bit_SPTR_read, slice_id);
  140. if (USE_SEG_4BIT)
  141. CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4bit_SPTR_read, slice_id);
  142. Dequant_32FP6_4Way<EXPONENT, MANTISSA>(
  143. a_write, a_1bit, a_2bit, a_4bit,
  144. RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register
  145. // level, dequantizing a slice each time
  146. B_FromSharedToReg<TilingConfig>(
  147. b_write, B_SPTR_read, slice_id); // Loading B from shared to registers
  148. }
  149. template <typename TilingConfig>
  150. __device__ __forceinline__ void StoreToSharedMemoryFromRegister(
  151. float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4],
  152. float c[][REG_PER_THREAD_C_TENSOR_16_16]) {
  153. const int lane_id = threadIdx.x % WARP_SIZE;
  154. const int warpId = threadIdx.x / WARP_SIZE;
  155. int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS);
  156. #pragma unroll
  157. for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
  158. #pragma unroll
  159. for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS;
  160. j++) { // Dealing with one 16*8 Tensor
  161. int RegSetID = i + (j / 2) * WARP_ROW_MMA_TENSORS;
  162. int RegOffset = (j % 2) * (REG_PER_THREAD_C_TENSOR_16_16 / 2);
  163. int Tensor_row_offset = warp_row_offset + i * MMA_16;
  164. int Tensor_col_offset = j * MMA_8;
  165. #pragma unroll
  166. for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16 / 2; r++) {
  167. int row_offset = lane_id / 4;
  168. if (r >= 2) row_offset += 8;
  169. int col_offset = (lane_id % 4) * 2;
  170. if (r % 2 == 1) col_offset += 1;
  171. smem_CFrag[Tensor_col_offset + col_offset]
  172. [Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset];
  173. }
  174. }
  175. }
  176. }
  177. #endif