1
0

kernel_matmul.cuh 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. // Copyright 2024 FP6-LLM authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // This file is modified from
  16. // https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/kernel_matmul.cuh
  17. #include "configs.h"
  18. #include "utils_gmem.cuh"
  19. #include "utils_core.cuh"
  20. /************************** Bitwidth of Weight Segments
  21. * ************************/
  22. #define BIT_WIDTH_1 1
  23. #define BIT_WIDTH_2 2
  24. #define BIT_WIDTH_4 4
  25. /*************************** 64*64 Weghts of Weight Matrix
  26. * *********************/
  27. #define WEIGHT_PER_WARP (WARP_M * WARP_K) // 64*64 = 4096
  28. #define SMEM_SIZE_PER_WARP_1BIT \
  29. (WEIGHT_PER_WARP * BIT_WIDTH_1 / \
  30. 8) // 512 Bytes, doubleBuffer not taken into consideration
  31. #define SMEM_SIZE_PER_WARP_2BIT \
  32. (WEIGHT_PER_WARP * BIT_WIDTH_2 / \
  33. 8) // 1024 Bytes, doubleBuffer not taken into consideration
  34. #define SMEM_SIZE_PER_WARP_4BIT \
  35. (WEIGHT_PER_WARP * BIT_WIDTH_4 / \
  36. 8) // 2048 Bytes, doubleBuffer not taken into consideration
  37. #define SMEM_SIZE_PER_TB_1BIT \
  38. (SMEM_SIZE_PER_WARP_1BIT * TilingConfig::BLOCK_WARPS * \
  39. PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A
  40. // = 6 KB; double buffer for 2-level pipeline A= 4
  41. // KB.
  42. #define SMEM_SIZE_PER_TB_2BIT \
  43. (SMEM_SIZE_PER_WARP_2BIT * TilingConfig::BLOCK_WARPS * \
  44. PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A
  45. // = 12 KB; double buffer for 2-level pipeline A= 8
  46. // KB.
  47. #define SMEM_SIZE_PER_TB_4BIT \
  48. (SMEM_SIZE_PER_WARP_4BIT * TilingConfig::BLOCK_WARPS * \
  49. PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A
  50. // = 24 KB; double buffer for 2-level pipeline A= 16
  51. // KB.
  52. #define SMEM_SIZE_PER_TB_A_TILE \
  53. (SMEM_SIZE_PER_TB_1BIT + SMEM_SIZE_PER_TB_2BIT + \
  54. SMEM_SIZE_PER_TB_4BIT) // used in fp6_linear.cu, Kernel_Ex().
  55. /******************** Global Memory Layout For QUANTIZED DATA
  56. * *******************/
  57. #define NUM_INT4_PER_WARP_1BIT (WEIGHT_PER_WARP * BIT_WIDTH_1 / 128) // 32
  58. #define NUM_INT4_PER_WARP_2BIT (WEIGHT_PER_WARP * BIT_WIDTH_2 / 128) // 64
  59. #define NUM_INT4_PER_WARP_4BIT (WEIGHT_PER_WARP * BIT_WIDTH_4 / 128) // 128
  60. /*
  61. * C = A*B
  62. * A: row major with ahead-of-time layout transformation, FP6
  63. * B: col major, FP16
  64. * C: col major, FP16
  65. */
  66. template <typename TilingConfig, typename OutputDataType, int EXPONENT,
  67. int MANTISSA>
  68. __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
  69. const half* B, OutputDataType* C,
  70. const size_t M_Global, const size_t N_Global,
  71. const size_t K_Global, int Split_K) {
  72. #ifdef DEBUG_MODE
  73. assert(K_Global % TilingConfig::TILE_K == 0);
  74. assert(M_Global % TilingConfig::TILE_M == 0);
  75. assert(gridDim.y == Split_K * (M_Global / TilingConfig::TILE_M));
  76. #endif
  77. // 1+2+4 weight split
  78. constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
  79. constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
  80. constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
  81. constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
  82. const uint4* Weight_1bit = Weight;
  83. const uint4* Weight_2bit =
  84. Weight_1bit +
  85. (USE_SEG_1BIT ? M_Global * K_Global * BIT_WIDTH_1 / 128 : 0);
  86. const uint4* Weight_4bit =
  87. Weight_2bit +
  88. (USE_SEG_2BIT ? M_Global * K_Global * BIT_WIDTH_2 / 128 : 0);
  89. // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned
  90. extern __shared__ __align__(128) half smem[];
  91. half(*smem_array)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
  92. reinterpret_cast<half(*)[WARP_K + PADDING_SHARED_MEM_FOR_B_8]>(
  93. smem + SMEM_SIZE_PER_TB_A_TILE /
  94. 2); // Dynamic shared memory for FP16 B tiles
  95. __shared__ half
  96. QuantScales[64 *
  97. TilingConfig::BLOCK_WARPS]; // static shared memory for
  98. // quantization scales, 64 row
  99. // per warp * 4 warps = 512 Bytes
  100. // Thread Block Mapping, considering SplitK
  101. const size_t BatchID = blockIdx.y / (M_Global / TilingConfig::TILE_M);
  102. const size_t x =
  103. blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x )
  104. const size_t y =
  105. blockIdx.y %
  106. (M_Global / TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y;
  107. // BlockID_Col = x )
  108. const size_t Tile_Start_M = y * TilingConfig::TILE_M;
  109. const size_t Tile_Start_N = x * TilingConfig::TILE_N;
  110. const size_t NumColumnToCopy =
  111. (N_Global - Tile_Start_N) < TilingConfig::TILE_N
  112. ? (N_Global - Tile_Start_N)
  113. : TilingConfig::TILE_N;
  114. const size_t NumBlock_K = K_Global / TilingConfig::TILE_K;
  115. const size_t AverageNumBlock_K = NumBlock_K / Split_K;
  116. const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K;
  117. size_t NumIter = AverageNumBlock_K;
  118. size_t StartBlockID_K = AverageNumBlock_K * BatchID;
  119. if (BatchID < ExtraNumBlock_K) {
  120. NumIter++;
  121. StartBlockID_K += BatchID;
  122. } else
  123. StartBlockID_K += ExtraNumBlock_K;
  124. // Warp ID.
  125. const int warpId = threadIdx.x / WARP_SIZE;
  126. int WARP_i = warpId / TilingConfig::BLOCK_COL_WARPS; // WARP_i: row number;
  127. // WARP_j: column number
  128. // int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS;
  129. // Global Memory Address for Matrix A (Weight)
  130. // /////////////////////////////////////////////////////////////////////////
  131. // StartPTR for each ThreadBlock(TB)
  132. const uint4* TB_StartGPTR_A_1BIT =
  133. Weight_1bit +
  134. (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_1BIT;
  135. const uint4* TB_StartGPTR_A_2BIT =
  136. Weight_2bit +
  137. (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_2BIT;
  138. const uint4* TB_StartGPTR_A_4BIT =
  139. Weight_4bit +
  140. (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_4BIT;
  141. // StartPTR for each WARP.
  142. const uint4* WARP_StartGPTR_A_1BIT =
  143. TB_StartGPTR_A_1BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_1BIT;
  144. const uint4* WARP_StartGPTR_A_2BIT =
  145. TB_StartGPTR_A_2BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_2BIT;
  146. const uint4* WARP_StartGPTR_A_4BIT =
  147. TB_StartGPTR_A_4BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_4BIT;
  148. // StartPTR for each WARP, considering SplitK
  149. const size_t WARP_Start_UnitID_K = StartBlockID_K;
  150. WARP_StartGPTR_A_1BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_1BIT;
  151. WARP_StartGPTR_A_2BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_2BIT;
  152. WARP_StartGPTR_A_4BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_4BIT;
  153. // Copying A tile from Global to Shared, using double-buffer
  154. // ////////////////////////////////////////////////////////// StartSPTR for
  155. // each ThreadBlock
  156. uint32_t* AFrag_1BIT_SPTR = reinterpret_cast<uint32_t*>(smem);
  157. uint32_t* AFrag_2BIT_SPTR = AFrag_1BIT_SPTR + SMEM_SIZE_PER_TB_1BIT / 4;
  158. uint32_t* AFrag_4BIT_SPTR =
  159. AFrag_2BIT_SPTR +
  160. SMEM_SIZE_PER_TB_2BIT /
  161. 4; // 8 buffers including double buffers, 12 for trible buffers
  162. // StartSPTR for each WARP
  163. AFrag_1BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_1BIT / 4;
  164. AFrag_2BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_2BIT / 4;
  165. AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT / 4;
  166. // Pre-fetch of A tile
  167. for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) {
  168. if (USE_SEG_1BIT)
  169. CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>(
  170. AFrag_1BIT_SPTR + i * SMEM_SIZE_PER_WARP_1BIT / 4 * 4,
  171. WARP_StartGPTR_A_1BIT);
  172. if (USE_SEG_2BIT)
  173. CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>(
  174. AFrag_2BIT_SPTR + i * SMEM_SIZE_PER_WARP_2BIT / 4 * 4,
  175. WARP_StartGPTR_A_2BIT);
  176. if (USE_SEG_4BIT)
  177. CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(
  178. AFrag_4BIT_SPTR + i * SMEM_SIZE_PER_WARP_4BIT / 4 * 4,
  179. WARP_StartGPTR_A_4BIT);
  180. WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT / 16;
  181. WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT / 16;
  182. WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT / 16;
  183. }
  184. // Global Memory Address for Matrix A (QuantScale)
  185. // /////////////////////////////////////////////////////////////////////
  186. const half* TB_StartGPTR_A_Scale =
  187. Scales + (y * TilingConfig::BLOCK_ROW_WARPS) * 64;
  188. const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64;
  189. CopyFromGlobalToShared_Scales(QuantScales + WARP_i * 64,
  190. WARP_StartGPTR_A_Scales);
  191. // Copying B tile from Global to Shared, considering SplitK
  192. // /////////////////////////////////////////////////////////////
  193. const half* BTile_GPTR =
  194. B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K;
  195. for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) {
  196. CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS>(
  197. smem_array + i * TilingConfig::TILE_N, BTile_GPTR, K_Global,
  198. NumColumnToCopy);
  199. BTile_GPTR += TilingConfig::TILE_K;
  200. }
  201. // Register Allocation for A,B, and C, Initilazed to Zeros
  202. // /////////////////////////////////////////////////////////////////////
  203. constexpr int NumRegSets_a =
  204. WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA
  205. // block
  206. constexpr int NumRegSets_b =
  207. (TilingConfig::WARP_COL_MMA_TENSORS == 1)
  208. ? 1
  209. : TilingConfig::WARP_COL_MMA_TENSORS /
  210. 2; // 1 set = 4 registers, containing a 16*16 MMA block
  211. uint32_t a[NumRegSets_a * PIPELINE_LEVEL_SMEM]
  212. [4]; // double/Trible buffer is used // Registers to store
  213. // decompressed FP6
  214. uint32_t b[NumRegSets_b * PIPELINE_LEVEL_SMEM]
  215. [4]; // double/Triple buffer is used // Register to store FP16 B
  216. // matrix (a slice)
  217. float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16];
  218. for (int i = 0; i < NumRegSets_a * NumRegSets_b; i++)
  219. for (int j = 0; j < REG_PER_THREAD_C_TENSOR_16_16; j++) c[i][j] = 0.0f;
  220. //
  221. cp_async_wait_all();
  222. __syncthreads();
  223. /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  224. uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales
  225. ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i * 64);
  226. // Initializing the Software Pipeline: writing registers.
  227. // ////////////////////////////////////////////////////////////////////////////////////////////////
  228. initialize_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
  229. a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array,
  230. Scales_RPTR);
  231. // The outer loop.
  232. // /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  233. #pragma unroll(1)
  234. for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) {
  235. // Trible-Buffer for A Tile
  236. uint32_t* __restrict__ read_SPTR_Frag_1bit =
  237. AFrag_1BIT_SPTR +
  238. ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT / 4 *
  239. 4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
  240. uint32_t* __restrict__ read_SPTR_Frag_2bit =
  241. AFrag_2BIT_SPTR +
  242. ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT / 4 *
  243. 4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
  244. uint32_t* __restrict__ read_SPTR_Frag_4bit =
  245. AFrag_4BIT_SPTR +
  246. ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT / 4 *
  247. 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
  248. uint32_t* __restrict__ read2_SPTR_Frag_1bit =
  249. AFrag_1BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) *
  250. SMEM_SIZE_PER_WARP_1BIT / 4 * 4;
  251. uint32_t* __restrict__ read2_SPTR_Frag_2bit =
  252. AFrag_2BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) *
  253. SMEM_SIZE_PER_WARP_2BIT / 4 * 4;
  254. uint32_t* __restrict__ read2_SPTR_Frag_4bit =
  255. AFrag_4BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) *
  256. SMEM_SIZE_PER_WARP_4BIT / 4 * 4;
  257. uint32_t* __restrict__ write_SPTR_Frag_1bit =
  258. AFrag_1BIT_SPTR +
  259. ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
  260. SMEM_SIZE_PER_WARP_1BIT / 4 *
  261. 4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
  262. uint32_t* __restrict__ write_SPTR_Frag_2bit =
  263. AFrag_2BIT_SPTR +
  264. ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
  265. SMEM_SIZE_PER_WARP_2BIT / 4 *
  266. 4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
  267. uint32_t* __restrict__ write_SPTR_Frag_4bit =
  268. AFrag_4BIT_SPTR +
  269. ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
  270. SMEM_SIZE_PER_WARP_4BIT / 4 *
  271. 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
  272. // Trible-Buffer for B Tile
  273. // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is
  274. // changed to below. similarly for read2_SPTR and write_SPTR.
  275. half(*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
  276. smem_array +
  277. ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
  278. half(*__restrict__ read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
  279. smem_array +
  280. ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
  281. half(*__restrict__ write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
  282. smem_array +
  283. ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
  284. TilingConfig::TILE_N;
  285. //
  286. bool GlobalCopy = (tile_id_k + PIPELINE_LEVEL_GMEM - 1) < NumIter;
  287. // Copying A tile from Global to Register, Bypassing L1, using double-buffer
  288. if (USE_SEG_1BIT)
  289. CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>(
  290. write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy);
  291. if (USE_SEG_2BIT)
  292. CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>(
  293. write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy);
  294. if (USE_SEG_4BIT)
  295. CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(
  296. write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy);
  297. // copying B tile from GlobalMemory to SharedMemory
  298. CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS>(
  299. write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy);
  300. cp_async_group_commit();
  301. core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
  302. c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit,
  303. read_SPTR, Scales_RPTR,
  304. 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each
  305. // WARP; read_SPTR is shared among WARPs
  306. core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
  307. c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit,
  308. read_SPTR, Scales_RPTR, 2);
  309. core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
  310. c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit,
  311. read_SPTR, Scales_RPTR, 3);
  312. // Barriers and Synchronizations
  313. cp_async_wait_group<PIPELINE_LEVEL_GMEM - 2>();
  314. __syncthreads();
  315. core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
  316. c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit,
  317. read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0);
  318. // Updating global PTRs
  319. WARP_StartGPTR_A_1BIT +=
  320. SMEM_SIZE_PER_WARP_1BIT / 16; // 2KB/16=128 (1)/16: int4*+1 = char*+16
  321. WARP_StartGPTR_A_2BIT +=
  322. SMEM_SIZE_PER_WARP_2BIT / 16; // 4KB/16=256 (1)/16: int4*+1 = char*+16
  323. WARP_StartGPTR_A_4BIT +=
  324. SMEM_SIZE_PER_WARP_4BIT / 16; // 8KB/16=512 (1)/16: int4*+1 = char*+16
  325. BTile_GPTR += TilingConfig::TILE_K;
  326. }
  327. /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  328. /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  329. // Store the C fragments to shared memory.
  330. float(*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4] =
  331. reinterpret_cast<
  332. float(*)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4]>(smem);
  333. StoreToSharedMemoryFromRegister<TilingConfig>(smem_CFrag, c);
  334. __syncthreads();
  335. // Now that shared memory contains all the D tiles, stream them to global
  336. // memory.
  337. OutputDataType* BlockGlobalPTR = C + BatchID * (M_Global * N_Global) +
  338. Tile_Start_M + Tile_Start_N * M_Global;
  339. for (size_t i = warpId; i < NumColumnToCopy;
  340. i += TilingConfig::BLOCK_WARPS) // i-th column
  341. #pragma unroll
  342. for (size_t j = threadIdx.x % WARP_SIZE; j < TilingConfig::TILE_M;
  343. j += WARP_SIZE) // j-th row
  344. {
  345. if constexpr (std::is_same<OutputDataType, half>::value)
  346. BlockGlobalPTR[j + i * M_Global] = __float2half_rn(smem_CFrag[i][j]);
  347. else
  348. BlockGlobalPTR[j + i * M_Global] = smem_CFrag[i][j];
  349. }
  350. }