gemm_template.h
/*
 * Adapted from https://github.com/InternLM/lmdeploy
 * Copyright (c) OpenMMLab. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cuda_pipeline_primitives.h>
#include <cuda_bf16.h>
#include <type_traits>  // std::is_same, used in store_accum

#include "common.h"
#include "cta_iterator.h"
#include "warp_iterator.h"

namespace aphrodite {
namespace autoquant {
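// Warp-wide 16x8x16 tensor-core MMA: d = a * b + c, with fp16 (or bf16, below)
// operands and fp32 accumulators. Requires SM80 for the mma.sync.aligned PTX.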
__inline__ __device__ void
mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
{
#if APHRODITE_ARCH_SM80
    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
    float const*    C = reinterpret_cast<float const*>(&c);
    float*          D = reinterpret_cast<float*>(&d);
    asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
    assert(APHRODITE_ARCH_SM80);
#endif
}
__inline__ __device__ void
mma_m16n8k16_row_col(Array<float, 4>& d, const Array<__nv_bfloat16, 8>& a, const Array<__nv_bfloat16, 4>& b, Array<float, 4>& c)
{
#if APHRODITE_ARCH_SM80
    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
    float const*    C = reinterpret_cast<float const*>(&c);
    float*          D = reinterpret_cast<float*>(&d);
    asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0,%1,%2,%3}, "
        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
    assert(APHRODITE_ARCH_SM80);
#endif
}
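// Transpose an 8x8 tile of 16-bit elements held across a warp (one 32-bit
// register per lane). The fallback path uses warp shuffles; on SM75+ with
// CUDA 11.8 or newer, the movmatrix instruction does it in a single step.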
__inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_id)
{
    int  src_lane = lane_id / 8 + lane_id % 4 * 8;
    uint u0       = __shfl_sync(0xffffffff, value, src_lane);
    uint u1       = __shfl_sync(0xffffffff, value, src_lane + 4);
    short2 r;
    if (lane_id % 8 < 4) {
        r.x = ((short2&)u0).x;
        r.y = ((short2&)u1).x;
    }
    else {
        r.x = ((short2&)u0).y;
        r.y = ((short2&)u1).y;
    }
    return (uint&)r;
}
// movmatrix requires CUDA 11.8+ (the check also accepts any CUDA 12.x release)
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
__inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
{
#if APHRODITE_ARCH_SM75
    uint d;
    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(d) : "r"(a));
    return d;
#else
    assert(APHRODITE_ARCH_SM75);
    return 0;
#endif
}
#endif
__inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
{
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
    (void)lane_id;
    return transpose_m8n8_b16_movmatrix(a);
#else
    return transpose_m8n8_b16_warp_shuffle(a, lane_id);
#endif
}
namespace ops {

__inline__ __device__ float4 operator+(const float4& a, const float4& b)
{
    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}

__inline__ __device__ float2 operator+(const float2& a, const float2& b)
{
    return {a.x + b.x, a.y + b.y};
}

}  // namespace ops
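// Tiled mixed-precision GEMM: the problem is tiled into CTA_M x CTA_N x CTA_K
// thread-block tiles and WARP_M x WARP_N x WARP_K warp tiles, with STAGES-deep
// software pipelining of the async global->shared copies and an optional
// sliced-K reduction across kWarpCountK warp groups. T_BC is the activation /
// output element type (half or __nv_bfloat16); T_Q is the element type consumed
// by IteratorQ (the per-GROUP_SIZE quantization parameters).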
template<int CTA_M,
         int CTA_N,
         int CTA_K,
         int WARP_M,
         int WARP_N,
         int WARP_K,
         int STAGES,
         int GROUP_SIZE,
         typename OutputOps,
         typename T_BC,
         typename T_Q>
struct Gemm {

    static constexpr int kWarpCountM  = CTA_M / WARP_M;
    static constexpr int kWarpCountN  = CTA_N / WARP_N;
    static constexpr int kWarpCountK  = CTA_K / WARP_K;
    static constexpr int kWarpCountMN = kWarpCountM * kWarpCountN;
    static constexpr int kWarpCount   = kWarpCountMN * kWarpCountK;

    static constexpr int SLICES  = kWarpCountK;
    static constexpr int SLICE_K = CTA_K / SLICES;

    static_assert(SLICE_K % WARP_K == 0, "infeasible sliced-k setting");

    using IteratorA = aphrodite::autoquant::IteratorA<kWarpCountMN, CTA_M, CTA_N, CTA_K, STAGES, SLICES>;
    using IteratorQ = aphrodite::autoquant::IteratorQ<kWarpCountMN, CTA_M, CTA_N, CTA_K, STAGES, SLICES, GROUP_SIZE, T_Q>;
    using IteratorB = aphrodite::autoquant::IteratorB<kWarpCountMN, CTA_M, CTA_N, CTA_K, STAGES, SLICES, T_BC>;

    static constexpr int OP_M = 16;
    static constexpr int OP_N = 8;
    static constexpr int OP_K = 16;

    using WarpIterA = aphrodite::autoquant::WarpIteratorA<CTA_M,
                                                          CTA_K,
                                                          WARP_M,
                                                          WARP_K,
                                                          OP_M,
                                                          OP_K,
                                                          GROUP_SIZE,
                                                          STAGES,
                                                          IteratorA::kSizePerStage,
                                                          IteratorQ::kSizePerStage,
                                                          T_BC,
                                                          T_Q>;
    using WarpIterB =
        aphrodite::autoquant::WarpIteratorB<CTA_N, CTA_K, WARP_N, WARP_K, OP_N, OP_K, IteratorB::kSmemPadCtaK, STAGES, T_BC>;
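    // One main-loop step over a WARP_K-deep slice: issues ITER_K rounds of
    // m16n8k16 MMAs over the warp tile while double-buffering operand fragments
    // and overlapping the next stage's async global->shared prefetches.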
    __device__ void warp_mma(IteratorA& iter_A,
                             IteratorQ& iter_Q,
                             IteratorB& iter_B,
                             WarpIterA& warp_iter_A,
                             WarpIterB& warp_iter_B,
                             float*     accum,
                             int        slice_id,
                             int&       gemm_iter)
    {
        constexpr int ITER_M = WARP_M / OP_M;
        constexpr int ITER_N = WARP_N / OP_N;
        constexpr int ITER_K = WARP_K / OP_K;

        constexpr int kBatchA = (IteratorA::kIterCount + ITER_K - 1) / ITER_K;
        constexpr int kBatchQ = (IteratorQ::kIterCount + ITER_K - 1) / ITER_K;
        constexpr int kBatchB = (IteratorB::kIterCount + ITER_K - 1) / ITER_K;

        auto frag_C_ptr = (Array<float, 4>*)accum;  // [ITER_N, ITER_M]

        PRAGMA_UNROLL
        for (int iter_k = 0; iter_k < ITER_K; ++iter_k) {

            warp_iter_A.load(warp_frag_A_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
            warp_iter_B.load(warp_frag_B_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);

            auto warp_frag_A = warp_frag_A_[iter_k % 2];
            auto warp_frag_B = warp_frag_B_[iter_k % 2];

            PRAGMA_UNROLL
            for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
                PRAGMA_UNROLL
                for (int iter_n = 0; iter_n < ITER_N; ++iter_n) {
                    auto& frag_A = warp_frag_A[iter_m];
                    auto& frag_B = warp_frag_B[iter_n];
                    auto& frag_C = frag_C_ptr[iter_n * ITER_M + iter_m];
                    mma_m16n8k16_row_col(frag_C, frag_A, frag_B, frag_C);
                }
            }

            if (iter_k < ITER_K - 1) {
                iter_A.prefetch_batch(iter_k, kBatchA, gemm_iter > 0);
                iter_Q.prefetch_batch(iter_k, kBatchQ, gemm_iter > 0);
                iter_B.prefetch_batch(iter_k, kBatchB, gemm_iter > 0);
            }

            if (iter_k == ITER_K - 2) {
                iter_A.prefetch_batch(iter_k + 1, kBatchA, gemm_iter > 0);
                iter_Q.prefetch_batch(iter_k + 1, kBatchQ, gemm_iter > 0);
                iter_B.prefetch_batch(iter_k + 1, kBatchB, gemm_iter > 0);
                __pipeline_commit();
                __pipeline_wait_prior(STAGES - 2);
                sync_slice(slice_id);
                iter_A.next_stage();
                iter_Q.next_stage();
                iter_B.next_stage();
                warp_iter_A.next_stage();
                warp_iter_B.next_stage();
                --gemm_iter;
            }
        }
    }
    template<typename T, int N>
    __device__ static void copy(T (&dst)[N], const T (&src)[N])
    {
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            dst[i] = src[i];
        }
    }

    template<typename T, int N>
    __device__ static void clear(T (&dst)[N])
    {
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            dst[i] = T{};
        }
    }

    __device__ void sync_slice(int slice_id)
    {
        if constexpr (SLICES == 1) {
            __syncthreads();
        }
        else {
            constexpr int      SLICE_GROUP = (SLICES + 7) / 8;
            constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE;
            const uint32_t     barrier_id  = slice_id / SLICE_GROUP + 1;
            asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
        }
    }
    __device__ void load_partial(float* tb_frag_C, const float* partial_C, int cta, int slice_id)
    {
        if (slice_id == 0) {
            PRAGMA_UNROLL
            for (int i = 0; i < CTA_N; ++i) {
                tb_frag_C[i] += partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x];
            }
        }
    }

    __device__ void store_partial(float* partial_C, const float* tb_frag_C, int cta, int slice_id)
    {
        if (slice_id == 0) {
            PRAGMA_UNROLL
            for (int i = 0; i < CTA_N; ++i) {
                partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x] = tb_frag_C[i];
            }
        }
    }
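    // Write the fp32 accumulators back to C: each 8x8 sub-tile is converted to
    // T_BC (half or bf16), transposed in-register with transpose_m8n8_b16, and
    // handed to the selected OutputOps policy for the global-memory store.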
    template<int Index>
    __device__ void store_accum(float* tb_frag_C,
                                float* tb_smem_C,
                                T_BC*  C,
                                int    m,
                                int    n,
                                int    cta_m,
                                int    cta_n,
                                int    warp_id_m,
                                int    warp_id_n,
                                int    lane_id,
                                int    slice_id)
    {
        if (slice_id != 0) {
            return;
        }

        // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c
        PRAGMA_UNROLL
        for (int i = 0; i < WARP_N / OP_N; ++i) {
            const float2* frag_C = (float2*)&tb_frag_C[i * WARP_M / OP_M * 4];
            const int     nn     = cta_n + warp_id_n * WARP_N + i * OP_N + lane_id / 4;
            PRAGMA_UNROLL
            for (int j = 0; j < WARP_M / OP_M; ++j) {
                PRAGMA_UNROLL
                for (int x = 0; x < 2; ++x) {
                    const int mm = cta_m + warp_id_m * WARP_M + j * OP_M + x * 8 + lane_id % 4 * 2;
                    if constexpr (std::is_same<T_BC, half>::value) {
                        // convert to half, clamping values that would overflow to inf
                        float2 frag_c = frag_C[j * 2 + x];
                        frag_c.x      = clamp_inf_for_half(frag_c.x);
                        frag_c.y      = clamp_inf_for_half(frag_c.y);
                        half2 half_C  = __float22half2_rn(frag_c);
                        // transpose 8x8 accum tile
                        uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id);
                        // store to global memory
                        OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n);
                    }
                    else {
                        // convert to bfloat16
                        auto bf16_C = float22bfloat162_rn(frag_C[j * 2 + x]);
                        // transpose 8x8 accum tile
                        uint trans_C = transpose_m8n8_b16((uint&)bf16_C, lane_id);
                        // store to global memory
                        OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n);
                    }
                }
            }
        }
    }
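    // Sliced-K reduction: each slice adds its accumulators into a shared-memory
    // tile in turn; slice 0 then reads the summed result back into registers.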
    __device__ void
    sum_slices(float* tb_frag_C, float* tb_smem_C, int warp_id_m, int warp_id_n, int lane_id, int slice_id)
    {
        int offset_m = warp_id_m * WARP_M / OP_M;
        int offset_n = warp_id_n * WARP_N / OP_N;

        PRAGMA_UNROLL
        for (int z = 0; z < SLICES; ++z) {
            if (slice_id == z) {
                PRAGMA_UNROLL
                for (int i = 0; i < WARP_N / OP_N; ++i) {
                    PRAGMA_UNROLL
                    for (int j = 0; j < WARP_M / OP_M; ++j) {
                        PRAGMA_UNROLL
                        for (int x = 0; x < 4; ++x) {
                            int src = (i * WARP_M / OP_M + j) * 4 + x;
                            int dst = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
                            if (z > 0) {
                                using namespace ops;
                                tb_frag_C[src] = tb_smem_C[dst * WARP_SIZE + lane_id] + tb_frag_C[src];
                            }
                            tb_smem_C[dst * WARP_SIZE + lane_id] = tb_frag_C[src];
                        }
                    }
                }
            }
            __syncthreads();
        }

        if (slice_id == 0) {
            PRAGMA_UNROLL
            for (int i = 0; i < WARP_N / OP_N; ++i) {
                PRAGMA_UNROLL
                for (int j = 0; j < WARP_M / OP_M; ++j) {
                    PRAGMA_UNROLL
                    for (int x = 0; x < 4; ++x) {
                        int src = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
                        int dst = (i * WARP_M / OP_M + j) * 4 + x;
                        tb_frag_C[dst] = tb_smem_C[src * WARP_SIZE + lane_id];
                    }
                }
            }
        }
    }
    //Array<half, 8> warp_frag_A_[2][WARP_M / OP_M];
    //Array<half, 4> warp_frag_B_[2][WARP_N / OP_N];
    Array<T_BC, 8> warp_frag_A_[2][WARP_M / OP_M];
    Array<T_BC, 4> warp_frag_B_[2][WARP_N / OP_N];
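    // Kernel body: sets up the shared-memory partitions and iterators, primes the
    // STAGES-1 deep prefetch pipeline, runs the main loop via warp_mma, reduces
    // across K-slices if needed, and stores the result through OutputOps.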
    __device__ void run_v2(T_BC* __restrict__ C,
                           const uint* __restrict__ A,
                           const T_BC* __restrict__ B,
                           const T_Q* __restrict__ Q,
                           int M,
                           int N,
                           int K,
                           int output_op_idx)
    {
        static_assert(WARP_M % OP_N == 0);

        float tb_frag_C[(WARP_N / OP_N) * (WARP_M / OP_M) * 4];

        extern __shared__ uint8_t smem[];

        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        const int warp_id_m  = warp_id % kWarpCountM;
        const int warp_id_nk = warp_id / kWarpCountM;
        const int warp_id_n  = warp_id_nk % kWarpCountN;
        const int warp_id_k  = warp_id_nk / kWarpCountN;
        const int warp_id_mn = warp_id_n * kWarpCountM + warp_id_m;

        const int slice_id = warp_id_k;

        const int cta_k = slice_id * SLICE_K;  // sliced-k offset
        const int cta_m = blockIdx.x * CTA_M;
        const int cta_n = blockIdx.y * CTA_N;

        // each slice has its own partition of smem
        uint4* const tb_smem_A = (uint4*)(smem + IteratorA::kSmemByteSize * slice_id);
        T_BC* const  tb_smem_B = (T_BC*)(smem + IteratorA::kSmemByteSize * SLICES + IteratorB::kSmemByteSize * slice_id);

        // [CTA_N / OP_N, CTA_M / OP_M, 4, WARP_SIZE], all mn fragments in CTA
        float* const tb_smem_C = (float*)smem;

        __shared__ typename IteratorQ::Storage tb_smem_Q_storage;
        auto tb_smem_Q = tb_smem_Q_storage.data[slice_id];

        IteratorA iter_A{A, tb_smem_A, M, K, cta_m, cta_k, warp_id_mn, lane_id};
        IteratorQ iter_Q{Q, tb_smem_Q, M, K, cta_m, cta_k, warp_id_mn, lane_id};
        IteratorB iter_B{B, tb_smem_B, K, N, cta_n, cta_k, warp_id_mn, lane_id};

        const int offset_m = warp_id_m * WARP_M + lane_id;

        WarpIterA warp_iter_A(iter_A.smem_, iter_Q.smem_, warp_id, lane_id, offset_m, cta_k);
        WarpIterB warp_iter_B(iter_B.smem_int_ptr_, warp_id_n, lane_id, 0);

        int gemm_iter = (K + CTA_K - 1) / CTA_K;

        PRAGMA_UNROLL
        for (int stage = 0; stage < STAGES - 1; ++stage, --gemm_iter) {
            iter_A.prefetch_stage(gemm_iter > 0);
            iter_Q.prefetch_stage(gemm_iter > 0);
            iter_B.prefetch_stage(gemm_iter > 0);
            __pipeline_commit();
        }

        clear(tb_frag_C);

        __pipeline_wait_prior(STAGES - 2);
        sync_slice(slice_id);

        warp_iter_A.load(warp_frag_A_[0], 0);
        warp_iter_B.load(warp_frag_B_[0], 0);

        PRAGMA_NO_UNROLL
        for (; gemm_iter > -STAGES + 1;) {
            warp_mma(iter_A, iter_Q, iter_B, warp_iter_A, warp_iter_B, tb_frag_C, slice_id, gemm_iter);
        }

        __pipeline_commit();
        __pipeline_wait_prior(0);
        __syncthreads();

        if constexpr (SLICES > 1) {
            sum_slices(tb_frag_C, tb_smem_C, warp_id_m, warp_id_n, lane_id, slice_id);
        }

        switch (output_op_idx) {
            case 0:
                store_accum<0>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
                break;
            case 1:
                store_accum<1>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
                break;
            default:
                return;
        }
    }
};
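// Global kernel entry: each thread block computes one CTA_M x CTA_N tile of C,
// with blockIdx.x indexing M tiles and blockIdx.y indexing N tiles (see run_v2).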
template<typename Gemm, typename T_BC, typename T_Q>
__global__ void gemm_s4_f16_nn(T_BC* __restrict__ C,
                               const uint* __restrict__ A,
                               const T_BC* __restrict__ B,
                               const T_Q* __restrict__ Q,
                               int M,
                               int N,
                               int K,
                               int output_op_idx)
{
    Gemm{}.run_v2(C, A, B, Q, M, N, K, output_op_idx);
}

}  // namespace autoquant
}  // namespace aphrodite