marlin_moe_kernel.h (55 KB, 1425 lines)
  1. #pragma once
  2. #include <torch/all.h>
  3. #include <ATen/cuda/CUDAContext.h>
  4. #include <c10/cuda/CUDAGuard.h>
  5. #include <cuda.h>
  6. #include <cuda_fp16.h>
  7. #include <cuda_runtime.h>
  8. #include <iostream>
  9. #include "core/scalar_type.hpp"
  10. namespace marlin_moe {
  // Integer ceiling division: the number of `b`-sized chunks needed to cover
  // `a` items. Intended for non-negative `a` and positive `b`; callers must
  // keep `a + b - 1` within `int` range.
  // Marked __host__ __device__ so kernels can call it with runtime values
  // (e.g. `gridDim.x`) without relying on nvcc's --expt-relaxed-constexpr.
  // Being constexpr, it still works in constant expressions such as shared
  // memory tile-size computations.
  __host__ __device__ constexpr int ceildiv(int a, int b) {
    return (a + b - 1) / b;
  }
  12. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // `Vec` is a fixed-size array meant to be kept entirely in >>registers<<,
  // for example as the operands of tensor core instructions. To stay in
  // registers, every index passed to operator[] must be a compile-time
  // constant. That is why the kernel code uses `#pragma unroll` so heavily.
  template <typename T, int n>
  struct Vec {
    T elems[n];
    __device__ T& operator[](int idx) { return elems[idx]; }
  };
  // Four 32-bit integers: the same size as one `int4` (16 bytes).
  typedef Vec<int, 4> I4;
  // Register fragments for the m16n8k16 tensor core MMA. Their exact
  // per-thread layout is documented at:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  typedef Vec<half2, 4> FragA;  // A operand: 8 fp16 values per thread
  typedef Vec<half2, 2> FragB;  // B operand: 4 fp16 values per thread
  typedef Vec<float, 4> FragC;  // fp32 accumulator: 4 values per thread
  typedef Vec<half2, 1> FragS;  // quantization scales: 2 fp16 values
  // Issues one asynchronous 16-byte global->shared copy (`cp.async.cg`) that
  // only executes when `pred` is true. When `pred` is false, no copy is issued
  // and the shared destination is left as-is. Used, for example, for A tiles,
  // where predication handles batch sizes that are not multiples of 16.
  __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                        bool pred = true) {
    constexpr int kCopyBytes = 16;
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    const int pred_flag = pred ? 1 : 0;
    asm volatile(
        "{\n"
        " .reg .pred p;\n"
        " setp.ne.b32 p, %0, 0;\n"
        " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
        "}\n" ::"r"(pred_flag),
        "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
  }
  // Issues one unconditional asynchronous 16-byte global->shared copy
  // (`cp.async.cg`). Completion must be awaited via commit/wait groups.
  __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
    constexpr int kCopyBytes = 16;
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
        "{\n"
        " cp.async.cg.shared.global [%0], [%1], %2;\n"
        "}\n" ::"r"(smem_addr),
        "l"(glob_ptr), "n"(kCopyBytes));
  }
  // Closes the current batch of cp.async operations issued by this thread
  // into one commit group, so it can later be waited on as a unit.
  __device__ inline void cp_async_fence() {
    asm volatile("cp.async.commit_group;\n" ::);
  }
  // Blocks the calling thread until at most `max_pending` of its committed
  // cp.async groups are still outstanding.
  template <int max_pending>
  __device__ inline void cp_async_wait() {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(max_pending));
  }
  // Tensor core MMA of shape m16n8k16: frag_c += a_frag * frag_b, with fp16
  // inputs and fp32 accumulation. The accumulator registers are passed as
  // both inputs (%10-%13) and outputs (%0-%3).
  __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                             FragC& frag_c) {
    const uint32_t* a_regs = reinterpret_cast<const uint32_t*>(&a_frag);
    const uint32_t* b_regs = reinterpret_cast<const uint32_t*>(&frag_b);
    float* acc = reinterpret_cast<float*>(&frag_c);
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
          "r"(b_regs[0]), "r"(b_regs[1]),
          "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]));
  }
  // Loads a full 16x16 A-operand fragment from shared memory straight into
  // tensor core register layout, using `ldmatrix ... .x4` (four 8x8 b16
  // matrices). Each thread supplies the shared-memory address in `smem_ptr`.
  __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    uint32_t* regs = reinterpret_cast<uint32_t*>(&frag_a);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3])
                 : "r"(smem_addr));
  }
  // Arbitrary 3-input bitwise function via the `lop3` instruction. `lut` is
  // the 8-bit truth table (built from the masks a=0xf0, b=0xcc, c=0xaa).
  // Emitted explicitly because the compiler does not always recognize these
  // patterns during dequantization.
  template <int lut>
  __device__ inline int lop3(int a, int b, int c) {
    int out;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(out)
                 : "r"(a), "r"(b), "r"(c), "n"(lut));
    return out;
  }
  // Byte permute (`prmt.b32`). Builds the result by selecting bytes from the
  // 8-byte pool {a, start_byte}: selector values 0-3 pick bytes of `a`, and
  // 4-7 pick bytes of the compile-time constant `start_byte`. `mask` holds
  // one 4-bit selector per destination byte, least significant first.
  template <int start_byte, int mask>
  __device__ inline uint32_t prmt(uint32_t a) {
    uint32_t out;
    asm volatile("prmt.b32 %0, %1, %2, %3;\n"
                 : "=r"(out)
                 : "r"(a), "n"(start_byte), "n"(mask));
    return out;
  }
  // Dequantizes packed weights in `q` into a B fragment of fp16 values.
  // No generic definition is given here; only the weight types specialized
  // below are supported.
  template <aphrodite::ScalarTypeId w_type_id>
  __device__ inline FragB dequant(int q);
  // kU4B8: 4-bit values stored with a +8 bias. This mostly follows the
  // FasterTransformer strategy, with small changes:
  // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
  // Only the nibbles at bits 0-7 and 16-23 of `q` are converted by one call.
  template <>
  __device__ inline FragB dequant<aphrodite::kU4B8.id()>(int q) {
    const int kLowNibbles = 0x000f000f;   // bits 0-3 and 16-19
    const int kHighNibbles = 0x00f000f0;  // bits 4-7 and 20-23
    const int kFp16Exp = 0x64006400;      // fp16 1024.0 in each half
    // The LUT (0xf0 & 0xcc) | 0xaa encodes (a & b) | c, forcing a single LOP3.
    // ORing a masked value into the mantissa of 1024.0 yields 1024 + bits.
    int lo_bits = lop3<(0xf0 & 0xcc) | 0xaa>(q, kLowNibbles, kFp16Exp);
    int hi_bits = lop3<(0xf0 & 0xcc) | 0xaa>(q, kHighNibbles, kFp16Exp);
    // The -8 zero point is folded into the constants:
    //   low halves:  (1024 + v) - 1032 (0x6408)            = v - 8
    //   high halves: (1024 + 16 v) * 1/16 (0x2c00) - 72 (0xd480) = v - 8
    const int kLoBias = 0x64086408;
    const int kHiScale = 0x2c002c00;
    const int kHiBias = 0xd480d480;
    FragB frag_b;
    frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo_bits),
                        *reinterpret_cast<const half2*>(&kLoBias));
    frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi_bits),
                        *reinterpret_cast<const half2*>(&kHiScale),
                        *reinterpret_cast<const half2*>(&kHiBias));
    return frag_b;
  }
  // kU8B128: four 8-bit values with a +128 bias -> four fp16 values.
  // Reference:
  // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
  template <>
  __device__ inline FragB dequant<aphrodite::kU8B128.id()>(int q) {
    // prmt selectors: place q.byte0/q.byte2 (even) or q.byte1/q.byte3 (odd)
    // in the low byte of each half, with 0x64 from the constant as the high
    // byte. Each half then reads as fp16 1024 + byte.
    static constexpr uint32_t kSelEven = 0x5250;
    static constexpr uint32_t kSelOdd = 0x5351;
    static constexpr uint32_t kExpBytes = 0x64646464;
    uint32_t even = prmt<kExpBytes, kSelEven>(q);
    uint32_t odd = prmt<kExpBytes, kSelOdd>(q);
    // Subtracting 1152 (0x6480) gives byte - 128.
    static constexpr uint32_t kBias = 0x64806480;
    const half2 bias = *reinterpret_cast<const half2*>(&kBias);
    FragB frag_b;
    frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&even), bias);
    frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&odd), bias);
    return frag_b;
  }
  // Grouped quantization: multiplies all four dequantized values in `frag_b`
  // by the single fp16 scale at half-index `i` of `frag_s`.
  __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
    const __half s_half = reinterpret_cast<__half*>(&frag_s)[i];
    const half2 s_pair = __half2half2(s_half);
  #pragma unroll
    for (int j = 0; j < 2; j++) frag_b[j] = __hmul2(frag_b[j], s_pair);
  }
  // Multiplies c[0] and c[1] by the two fp16 scales held in `s`, in fp32
  // with round-to-nearest.
  __device__ inline void scale_float(float* c, FragS& s) {
    const __half* halves = reinterpret_cast<__half*>(&s);
  #pragma unroll
    for (int j = 0; j < 2; j++) c[j] = __fmul_rn(c[j], __half2float(halves[j]));
  }
  // act_order variant of `scale`: each of the four K rows in `frag_b` gets
  // its own scale. Value j is multiplied by half-index `i` of frag_s_(j+1).
  __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
                                FragS& frag_s_3, FragS& frag_s_4, int i) {
    auto pick = [i](FragS& f) { return reinterpret_cast<__half*>(&f)[i]; };
    // __halves2half2(lo, hi): lo -> .x, hi -> .y
    frag_b[0] =
        __hmul2(frag_b[0], __halves2half2(pick(frag_s_1), pick(frag_s_2)));
    frag_b[1] =
        __hmul2(frag_b[1], __halves2half2(pick(frag_s_3), pick(frag_s_4)));
  }
  // Spin until `*lock` equals `count`, then let the whole block proceed.
  // Only thread 0 polls. It uses a gpu-scope acquire load, so its later
  // memory accesses are ordered after the observed value. The trailing
  // __syncthreads() holds the rest of the block until thread 0 has seen
  // `count`.
  __device__ inline void barrier_acquire(int* lock, int count) {
    if (threadIdx.x == 0) {
      int observed;
      for (;;) {
        asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                     : "=r"(observed)
                     : "l"(lock));
        if (observed == count) break;
      }
    }
    __syncthreads();
  }
  // Releases the barrier taken with barrier_acquire. After a block-wide sync,
  // thread 0 either resets `*lock` to 0 (when `reset` is true) or publishes
  // the block's writes with a gpu-scope acq_rel fence and then atomically
  // adds 1 to `*lock`.
  // NOTE(review): the reset path is a plain store with no fence; presumably
  // its visibility is only needed after the kernel finishes — confirm.
  __device__ inline void barrier_release(int* lock, bool reset = false) {
    __syncthreads();
    if (threadIdx.x != 0) return;
    if (reset) {
      *lock = 0;
      return;
    }
    const int increment = 1;
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(increment));
  }
  209. template <const aphrodite::ScalarTypeId w_type_id, // weight ScalarType id
  210. const int threads, // number of threads in a threadblock
  211. const int thread_m_blocks, // number of 16x16 blocks in the m
  212. // dimension (batchsize) of the
  213. // threadblock
  214. const int thread_n_blocks, // same for n dimension (output)
  215. const int thread_k_blocks, // same for k dimension (reduction)
  216. const int stages, // number of stages for the async global->shared
  217. // fetch pipeline
  218. const bool has_act_order, // whether act_order is enabled
  219. const int group_blocks = -1 // number of consecutive 16x16 blocks
  220. // with a separate quantization scale
  221. >
  222. __device__ inline void MarlinMoESingle(
  223. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  224. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  225. int4* __restrict__ C, // fp16 output buffer of shape mxn
  226. const int* __restrict__ sorted_ids, // int32 sorted ids of experts
  227. const float* __restrict__ topk_weights, // float topk weights
  228. const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
  229. // (k/groupsize)xn
  230. const int* __restrict__ g_idx, // int32 group indices of shape k
  231. const int* __restrict__ expert_offsets,
  232. int num_groups, // number of scale groups per output channel
  233. int expert_idx, // idx of current expert
  234. int num_experts, // number of experts
  235. int topk, // topk parameter of moe
  236. int prob_m, // batch dimension m
  237. int prob_n, // output dimension n
  238. int prob_k, // reduction dimension k
  239. int tot_m, // total number of rows in A and C
  240. int* locks, // extra global storage for barrier synchronization
  241. bool replicate_input, // do we use the same input for each expert?
  242. bool apply_weights, // apply weights to output
  243. int current_m_block // current m block to start kernel computation from
  244. ) {
  245. static constexpr auto w_type = aphrodite::ScalarType::from_id(w_type_id);
  246. constexpr int pack_factor = 32 / w_type.size_bits();
  247. // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  248. // better partitioning with less reductions
  249. int parallel = 1;
  250. if (prob_m > 16 * thread_m_blocks) {
  251. parallel = prob_m / (16 * thread_m_blocks);
  252. prob_m = 16 * thread_m_blocks;
  253. }
  254. int k_tiles = prob_k / 16 / thread_k_blocks;
  255. int n_tiles = prob_n / 16 / thread_n_blocks;
  256. int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
  257. if constexpr (!has_act_order && group_blocks != -1) {
  258. if (group_blocks >= thread_k_blocks) {
  259. // Ensure that the number of tiles in each stripe is a multiple of the
  260. // groupsize; this avoids an annoying special case where a stripe starts
  261. // in the middle of group.
  262. iters = (group_blocks / thread_k_blocks) *
  263. ceildiv(iters, (group_blocks / thread_k_blocks));
  264. }
  265. }
  266. int slice_row = (iters * blockIdx.x) % k_tiles;
  267. int slice_col_par = (iters * blockIdx.x) / k_tiles;
  268. int slice_col = slice_col_par;
  269. int slice_iters; // number of threadblock tiles in the current slice
  270. int slice_count =
  271. 0; // total number of active threadblocks in the current slice
  272. int slice_idx; // index of threadblock in current slice; numbered bottom to
  273. // top
  274. // We can easily implement parallel problem execution by just remapping
  275. // indices and advancing global pointers
  276. if (slice_col_par >= n_tiles) {
  277. locks += (slice_col_par / n_tiles) * n_tiles;
  278. slice_col = slice_col_par % n_tiles;
  279. sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks;
  280. }
  281. // Compute all information about the current slice which is required for
  282. // synchronization.
  283. auto init_slice = [&]() {
  284. slice_iters =
  285. iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
  286. if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
  287. if (slice_iters == 0) return;
  288. if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
  289. slice_count = 1;
  290. slice_idx = 0;
  291. int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
  292. if (col_first <= k_tiles * (slice_col_par + 1)) {
  293. int col_off = col_first - k_tiles * slice_col_par;
  294. slice_count = ceildiv(k_tiles - col_off, iters);
  295. if (col_off > 0) slice_count++;
  296. int delta_first = iters * blockIdx.x - col_first;
  297. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  298. slice_idx = slice_count - 1;
  299. else {
  300. slice_idx = slice_count - 1 - delta_first / iters;
  301. if (col_off > 0) slice_idx--;
  302. }
  303. }
  304. if (slice_col == n_tiles) {
  305. sorted_ids += 16 * thread_m_blocks;
  306. locks += n_tiles;
  307. slice_col = 0;
  308. }
  309. };
  310. init_slice();
  311. // A sizes/strides
  312. // stride of the A matrix in global memory
  313. int a_gl_stride = prob_k / 8;
  314. // stride of an A matrix tile in shared memory
  315. constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
  316. // delta between subsequent A tiles in global memory
  317. constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
  318. // between subsequent accesses within a tile
  319. int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
  320. // between shared memory writes
  321. constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
  322. // between shared memory tile reads
  323. constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
  324. // within a shared memory tile
  325. constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
  326. // overall size of a tile
  327. constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
  328. // number of shared write iterations for a tile
  329. constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
  330. // B sizes/strides
  331. int b_gl_stride = 16 * prob_n / (pack_factor * 4);
  332. constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
  333. constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
  334. constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
  335. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  336. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
  337. constexpr int b_sh_wr_delta = threads * b_thread_vecs;
  338. constexpr int b_sh_rd_delta = threads * b_thread_vecs;
  339. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  340. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
  341. // Scale sizes/strides without act_order
  342. int s_gl_stride = prob_n / 8;
  343. constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  344. constexpr int s_tb_groups =
  345. !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
  346. ? thread_k_blocks / group_blocks
  347. : 1;
  348. constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
  349. int s_gl_rd_delta = s_gl_stride;
  350. // Scale size/strides with act_order
  351. constexpr int tb_k = 16 * thread_k_blocks;
  352. constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
  353. // constexpr int act_s_row_stride = 1;
  354. // int act_s_col_stride = act_s_row_stride * num_groups;
  355. int act_s_col_stride = 1;
  356. int act_s_col_warp_stride = act_s_col_stride * 8;
  357. int tb_n_warps = thread_n_blocks / 4;
  358. int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
  359. constexpr int sorted_sh_stride = threads;
  360. constexpr int sorted_gl_stride = threads;
  361. // Global A read index of current thread.
  362. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  363. (threadIdx.x % a_gl_rd_delta_o);
  364. a_gl_rd += a_gl_rd_delta_o * slice_row;
  365. // Shared write index of current thread.
  366. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
  367. (threadIdx.x % a_gl_rd_delta_o);
  368. // Shared read index.
  369. int a_sh_rd =
  370. a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  371. a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  372. int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
  373. (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
  374. b_gl_rd += b_sh_stride * slice_col;
  375. b_gl_rd += b_gl_rd_delta_o * slice_row;
  376. int b_sh_wr = threadIdx.x * b_thread_vecs;
  377. int b_sh_rd = threadIdx.x * b_thread_vecs;
  378. // For act_order
  379. constexpr int k_iter_size = tb_k / b_sh_wr_iters;
  380. int slice_k_start = tb_k * slice_row;
  381. int slice_k_finish = slice_k_start + tb_k * slice_iters;
  382. int slice_k_start_shared_fetch = slice_k_start;
  383. int slice_n_offset = act_s_col_tb_stride * slice_col;
  384. // No act_order
  385. int s_gl_rd;
  386. if constexpr (!has_act_order) {
  387. if constexpr (group_blocks == -1) {
  388. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  389. } else {
  390. s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
  391. s_sh_stride * slice_col + threadIdx.x;
  392. }
  393. }
  394. int s_sh_wr = threadIdx.x;
  395. bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
  396. // We use a different scale layout for grouped and column-wise quantization as
  397. // we scale a `half2` tile in column-major layout in the former and in
  398. // row-major in the latter case.
  399. int s_sh_rd;
  400. if constexpr (group_blocks != -1)
  401. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  402. (threadIdx.x % 32) / 4;
  403. else
  404. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  405. (threadIdx.x % 32) % 4;
  406. int sh_first_group_id = -1;
  407. int sh_num_groups = -1;
  408. constexpr int sh_max_num_groups = 32;
  409. int shs_size;
  410. if constexpr (has_act_order)
  411. shs_size = sh_max_num_groups * s_sh_stride + threads;
  412. else
  413. shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;
  414. extern __shared__ int4 sh[];
  415. // Shared memory storage for global fetch pipelines.
  416. int4* sh_a = sh;
  417. int4* sh_b = sh_a + (stages * a_sh_stage);
  418. int4* sh_g_idx = sh_b + (stages * b_sh_stage);
  419. int4* sh_s = sh_g_idx + (stages * g_idx_stage);
  420. int* sh_sorted = (int*)(sh_s + shs_size);
  421. // Precompute which thread should not read memory in which iterations; this is
  422. // needed if there are more threads than required for a certain tilesize or
  423. // when the batchsize is not a multiple of 16.
  424. bool a_sh_wr_pred[a_sh_wr_iters];
  425. #pragma unroll
  426. for (int i = 0; i < a_sh_wr_iters; i++) {
  427. int a_idx = a_sh_wr_delta * i + a_sh_wr;
  428. int row = a_idx / a_gl_rd_delta_o;
  429. if (row >= prob_m) {
  430. a_sh_wr_pred[i] = false;
  431. } else {
  432. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  433. }
  434. }
  435. // To ensure that writing and reading A tiles to/from shared memory, the
  436. // latter in fragment format, is fully bank conflict free, we need to use a
  437. // rather fancy XOR-based layout. The key here is that neither reads nor
  438. // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  439. // same shared memory banks. Further, it seems (based on NSight-Compute) that
  440. // each warp must also write a consecutive memory segment?
  441. auto transform_a = [&](int i) {
  442. int row = i / a_gl_rd_delta_o;
  443. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  444. };
  445. // Since the computation of this remapping is non-trivial and, due to our main
  446. // loop unrolls, all shared memory accesses are static, we simply precompute
  447. // both transformed reads and writes.
  448. int a_sh_wr_trans[a_sh_wr_iters];
  449. #pragma unroll
  450. for (int i = 0; i < a_sh_wr_iters; i++)
  451. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  452. int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  453. #pragma unroll
  454. for (int i = 0; i < b_sh_wr_iters; i++) {
  455. #pragma unroll
  456. for (int j = 0; j < thread_m_blocks; j++)
  457. a_sh_rd_trans[i][j] =
  458. transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  459. }
  460. // Since B-accesses have non-constant stride they have to be computed at
  461. // runtime; we break dependencies between subsequent accesses with a tile by
  462. // maintining multiple pointers (we have enough registers), a tiny
  463. // optimization.
  464. const int4* B_ptr[b_sh_wr_iters];
  465. #pragma unroll
  466. for (int i = 0; i < b_sh_wr_iters; i++)
  467. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  468. // Register storage for double buffer of shared memory reads.
  469. FragA frag_a[2][thread_m_blocks];
  470. I4 frag_b_quant[2][b_thread_vecs];
  471. FragC frag_c[thread_m_blocks][4][2];
  472. FragS frag_s[2][4]; // No act-order
  473. FragS act_frag_s[2][4][4]; // For act-order
  474. // Zero accumulators.
  475. auto zero_accums = [&]() {
  476. #pragma unroll
  477. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  478. reinterpret_cast<float*>(frag_c)[i] = 0;
  479. };
  480. auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
  481. int last_group_id) {
  482. sh_first_group_id = first_group_id;
  483. sh_num_groups = last_group_id - first_group_id + 1;
  484. if (sh_num_groups < sh_max_num_groups) {
  485. sh_num_groups = sh_max_num_groups;
  486. }
  487. if (sh_first_group_id + sh_num_groups > num_groups) {
  488. sh_num_groups = num_groups - sh_first_group_id;
  489. }
  490. int row_offset = first_group_id * s_gl_stride;
  491. if (is_async) {
  492. for (int i = 0; i < sh_num_groups; i++) {
  493. if (threadIdx.x < s_sh_stride) {
  494. cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
  495. &scales_ptr[row_offset + (i * s_gl_stride) +
  496. slice_n_offset + threadIdx.x]);
  497. }
  498. }
  499. } else {
  500. for (int i = 0; i < sh_num_groups; i++) {
  501. if (threadIdx.x < s_sh_stride) {
  502. sh_s[(i * s_sh_stride) + threadIdx.x] =
  503. scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
  504. threadIdx.x];
  505. }
  506. }
  507. }
  508. };
  509. // Asynchronously fetch the next A, B and s tile from global to the next
  510. // shared memory pipeline location.
  511. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
  512. if (pred) {
  513. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  514. #pragma unroll
  515. for (int i = 0; i < a_sh_wr_iters; i++) {
  516. int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
  517. int row = a_idx / a_gl_stride;
  518. int sorted_row =
  519. replicate_input ? sorted_ids[row] / topk : sorted_ids[row];
  520. int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
  521. if (sorted_row < tot_m * (replicate_input ? 1 : topk) &&
  522. new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) {
  523. cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx],
  524. a_sh_wr_pred[i]);
  525. }
  526. }
  527. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  528. #pragma unroll
  529. for (int i = 0; i < b_sh_wr_iters; i++) {
  530. #pragma unroll
  531. for (int j = 0; j < b_thread_vecs; j++) {
  532. cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
  533. }
  534. B_ptr[i] += b_gl_rd_delta_o;
  535. }
  536. if constexpr (has_act_order) {
  537. // Fetch g_idx thread-block portion
  538. int full_pipe = a_off;
  539. int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
  540. if (cur_k < prob_k && cur_k < slice_k_finish) {
  541. int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
  542. int4 const* cur_g_idx_stage_ptr =
  543. reinterpret_cast<int4 const*>(&g_idx[cur_k]);
  544. if (threadIdx.x < g_idx_stage) {
  545. cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
  546. &cur_g_idx_stage_ptr[threadIdx.x]);
  547. }
  548. }
  549. } else {
  550. if constexpr (group_blocks != -1) {
  551. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  552. if constexpr (group_blocks >= thread_k_blocks) {
  553. // Only fetch scales if this tile starts a new group
  554. if (pipe % (group_blocks / thread_k_blocks) == 0) {
  555. if (s_sh_wr_pred) {
  556. cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
  557. }
  558. s_gl_rd += s_gl_rd_delta;
  559. }
  560. } else {
  561. for (int i = 0; i < s_tb_groups; i++) {
  562. if (s_sh_wr_pred) {
  563. cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
  564. &scales_ptr[s_gl_rd]);
  565. }
  566. s_gl_rd += s_gl_rd_delta;
  567. }
  568. }
  569. }
  570. }
  571. }
  572. // Insert a fence even when we are winding down the pipeline to ensure that
  573. // waiting is also correct at this point.
  574. cp_async_fence();
  575. };
  576. // TODO we are currently hitting illegal memory accesses when fetching
  577. // sorted_ids to shared data: fix this
  578. auto fetch_sorted_ids_to_shared = [&]() {
  579. const int mpt = ceildiv(prob_m, threads);
  580. for (int i = 0; i < mpt; i++) {
  581. if ((i * sorted_gl_stride) + threadIdx.x < prob_m) {
  582. sh_sorted[(i * sorted_sh_stride) + threadIdx.x] =
  583. sorted_ids[(i * sorted_gl_stride) + threadIdx.x];
  584. }
  585. }
  586. };
  587. // Wait until the next thread tile has been loaded to shared memory.
  588. auto wait_for_stage = [&]() {
  589. // We only have `stages - 2` active fetches since we are double buffering
  590. // and can only issue the next fetch when it is guaranteed that the previous
  591. // shared memory load is fully complete (as it may otherwise be
  592. // overwritten).
  593. cp_async_wait<stages - 2>();
  594. __syncthreads();
  595. };
  // Move sub-tile `k` of shared-memory pipeline stage `pipe` into register
  // double-buffer slot k % 2: the A fragments are loaded with ldsm4, the
  // quantized B fragments with plain I4 vector reads.
  auto fetch_to_registers = [&](int k, int pipe) {
    const int reg_slot = k % 2;              // register double-buffer slot
    const int sub_tile = k % b_sh_wr_iters;  // sub-tile index inside the stage
    int4* a_stage_base = sh_a + pipe * a_sh_stage;
    int4* b_stage_base = sh_b + pipe * b_sh_stage;

    // A operand: one ldsm4 per m block.
    #pragma unroll
    for (int mb = 0; mb < thread_m_blocks; ++mb) {
      ldsm4(frag_a[reg_slot][mb], &a_stage_base[a_sh_rd_trans[sub_tile][mb]]);
    }

    // B operand: b_thread_vecs consecutive quantized vectors for this thread.
    const int b_rd = b_sh_rd + b_sh_rd_delta * sub_tile;
    #pragma unroll
    for (int v = 0; v < b_thread_vecs; ++v) {
      frag_b_quant[reg_slot][v] = *reinterpret_cast<I4*>(&b_stage_base[b_rd + v]);
    }
  };
  // Per-stage act-order cache: whether the first and last g_idx entries staged
  // for a pipeline stage name the same scale group, and that group's id.
  // (Treating "first == last" as "whole tile in one group" presumably relies
  // on g_idx being non-decreasing within a tile -- TODO confirm.)
  bool is_same_group[stages];
  int same_group_id[stages];
  // Refresh is_same_group / same_group_id for stage `pipe` from the g_idx
  // values staged in shared memory; without act-order the entries are reset.
  auto init_same_group = [&](int pipe) {
    if constexpr (has_act_order) {
      const int* stage_g_idx =
          reinterpret_cast<const int*>(sh_g_idx + pipe * g_idx_stage);
      const int first_group = stage_g_idx[0];
      const int last_group = stage_g_idx[tb_k - 1];
      same_group_id[pipe] = first_group;
      is_same_group[pipe] = (first_group == last_group);
    } else {
      same_group_id[pipe] = 0;
      is_same_group[pipe] = false;
    }
  };
  // Load the quantization scales needed by sub-tile `k` into register slot
  // k % 2: frag_s without act-order, act_frag_s with act-order.
  // `full_pipe` is the un-wrapped pipeline counter; `full_pipe % stages`
  // selects the shared-memory stage, while the full value is needed to locate
  // the thread-block tile in k for the act-order bounds check.
  auto fetch_scales_to_registers = [&](int k, int full_pipe) {
    int pipe = full_pipe % stages;
    if constexpr (!has_act_order) {
      // No act-order case
      // Channelwise scales (group_blocks == -1) are not loaded here at all.
      if constexpr (group_blocks != -1) {
        if constexpr (group_blocks >= thread_k_blocks) {
          // Scales are only staged on stages where pipe is a multiple of
          // group_blocks / thread_k_blocks (see the fetch into sh_s), so round
          // pipe down to that stage.
          int4* sh_s_stage =
              sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
                                   (pipe / (group_blocks / thread_k_blocks)));
          reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
        } else {
          // Several groups per stage: pick the group containing this warp's
          // k position within the current sub-tile.
          int warp_id = threadIdx.x / 32;
          int n_warps = thread_n_blocks / 4;
          int warp_row = warp_id / n_warps;
          int cur_k = warp_row * 16;
          cur_k += k_iter_size * (k % b_sh_wr_iters);
          int k_blocks = cur_k / 16;
          int cur_group_id = k_blocks / group_blocks;
          int4* sh_s_stage = sh_s + s_sh_stage * pipe;
          reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
              sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
        }
      }
      return;
    }
    // Act-order case
    // Determine K of the "current" thread-block
    int cur_k = slice_k_start + tb_k * full_pipe;
    // Past the end of the problem or of this slice: leave act_frag_s as is.
    if (cur_k >= prob_k || cur_k >= slice_k_finish) {
      return;
    }
    // Reset (to current thread-block) since we read g_idx portion from the
    // shared memory
    cur_k = 0;
    // Progress to current iteration
    cur_k += k_iter_size * (k % b_sh_wr_iters);
    // Determine "position" inside the thread-block (based on warp and
    // thread-id)
    int warp_id = threadIdx.x / 32;
    int n_warps =
        thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N
    int warp_row = warp_id / n_warps;
    int warp_col = warp_id % n_warps;
    cur_k += warp_row * 16;
    int th_id = threadIdx.x % 32;
    cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix
    int s_col_shift =
        /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
        (th_id / 4) * act_s_col_stride;
    if (is_same_group[pipe]) {
      // Whole stage uses one group: load it once (even k) and replicate it
      // into all four per-k slots. sh_s rows are relative to sh_first_group_id.
      // NOTE(review): odd k reuses slot (k - 1) % 2, which assumes the
      // previous call loaded the same group -- confirm for calls that straddle
      // a stage change.
      if (k % 2 == 0) {
        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
            sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
                 s_col_shift];
      } else {
        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
            *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
      }
      for (int i = 1; i < 4; i++) {
        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
            *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
      }
      return;
    }
    // General case: look up the group of each of this thread's four k
    // positions in the staged g_idx and fetch that group's scales.
    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
    constexpr int k_frag_offsets[4] = {0, 1, 8,
                                       9};  // Tensor core offsets per thread
    #pragma unroll
    for (int i = 0; i < 4; i++) {
      int actual_k = cur_k + k_frag_offsets[i];
      int group_id = sh_g_idx_int_ptr[actual_k];
      int rel_group_id = group_id - sh_first_group_id;
      *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
          sh_s[rel_group_id * s_sh_stride + s_col_shift];
    }
  };
  // Tensor-core step for register slot k % 2. For each of the four B column
  // fragments: unpack the quantized words, dequantize, apply the active scale
  // scheme, then accumulate against every m block of A. The m loop is the
  // innermost one so dequantization can overlap with the MMAs.
  auto matmul = [&](int k) {
    const int slot = k % 2;
    #pragma unroll
    for (int j = 0; j < 4; j++) {
      // Two packed quantized words for column fragment j.
      int q_lo;
      int q_hi;
      if constexpr (w_type.size_bits() == 4) {
        q_lo = frag_b_quant[slot][0][j];
        q_hi = q_lo >> 8;
      } else {
        static_assert(w_type.size_bits() == 8);
        const int* packed = reinterpret_cast<const int*>(frag_b_quant[slot]);
        q_lo = packed[2 * j];
        q_hi = packed[2 * j + 1];
      }
      FragB b_lo = dequant<w_type_id>(q_lo);
      FragB b_hi = dequant<w_type_id>(q_hi);

      // Act-order: per-k scales; grouped: one scale per group. Channelwise
      // (group_blocks == -1) scales are applied after the reduction instead.
      if constexpr (has_act_order) {
        scale4(b_lo, act_frag_s[slot][0][j], act_frag_s[slot][1][j],
               act_frag_s[slot][2][j], act_frag_s[slot][3][j], 0);
        scale4(b_hi, act_frag_s[slot][0][j], act_frag_s[slot][1][j],
               act_frag_s[slot][2][j], act_frag_s[slot][3][j], 1);
      } else if constexpr (group_blocks != -1) {
        scale(b_lo, frag_s[slot][j], 0);
        scale(b_hi, frag_s[slot][j], 1);
      }

      #pragma unroll
      for (int mb = 0; mb < thread_m_blocks; mb++) {
        mma(frag_a[slot][mb], b_lo, frag_c[mb][j][0]);
        mma(frag_a[slot][mb], b_hi, frag_c[mb][j][1]);
      }
    }
  };
  // Since we slice across the k dimension of a tile in order to increase the
  // number of warps while keeping the n dimension of a tile reasonable, we have
  // multiple warps that accumulate their partial sums of the same output
  // location; which we have to reduce over in the end. We do in shared memory.
  //
  // Threads are grouped by red_idx = threadIdx.x / b_sh_stride_threads; the
  // reduction only runs when there are at least two such groups
  // (red_off >= 1). Afterwards the fully reduced sums live in frag_c of the
  // red_idx == 0 threads; other threads' frag_c hold partial values.
  // The __syncthreads() calls sit in loops whose bounds are compile-time
  // constants, so every thread of the block reaches them.
  auto thread_block_reduce = [&]() {
    constexpr int red_off = threads / b_sh_stride_threads / 2;
    if (red_off >= 1) {
      int red_idx = threadIdx.x / b_sh_stride_threads;
      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
      constexpr int red_sh_delta = b_sh_stride_threads;
      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
                      (threadIdx.x % b_sh_stride_threads);
      // Parallel logarithmic shared memory reduction. We make sure to avoid any
      // unnecessary read or write iterations, e.g., for two warps we write only
      // once by warp 1 and read only once by warp 0.
      #pragma unroll
      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
        #pragma unroll
        for (int i = red_off; i > 0; i /= 2) {
          // Groups [i, 2i) fold their partials into the slot of group idx - i.
          if (i <= red_idx && red_idx < 2 * i) {
            #pragma unroll
            for (int j = 0; j < 4 * 2; j++) {
              int red_sh_wr =
                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
              // On every round but the first, add what the upper half stored
              // in the previous round (both own slot and target slot).
              if (i < red_off) {
                float* c_rd =
                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
                #pragma unroll
                for (int k = 0; k < 4; k++)
                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                      c_rd[k] + c_wr[k];
              }
              sh[red_sh_wr] =
                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
            }
          }
          __syncthreads();
        }
        // Group 0 collects the final partial sum from shared memory.
        if (red_idx == 0) {
          #pragma unroll
          for (int i = 0; i < 4 * 2; i++) {
            float* c_rd =
                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
            #pragma unroll
            for (int j = 0; j < 4; j++)
              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
                  c_rd[j];
          }
        }
        __syncthreads();
      }
    }
  };
  // Since multiple threadblocks may process parts of the same column slice, we
  // finally have to globally reduce over the results. As the striped
  // partitioning minimizes the number of such reductions and our outputs are
  // usually rather small, we perform this reduction serially in L2 cache.
  //
  // first: no earlier block of this slice has written a partial sum to C yet,
  //        so nothing is read back.
  // last:  this block's sums are not written back to C here (the caller
  //        emits them via write_result).
  // Output rows are routed through sorted_ids; ids >= tot_m * topk are padding
  // and are neither read nor written.
  auto global_reduce = [&](bool first = false, bool last = false) {
    // We are very careful here to reduce directly in the output buffer to
    // maximize L2 cache utilization in this step. To do this, we write out
    // results in FP16 (but still reduce with FP32 compute).
    constexpr int active_threads = 32 * thread_n_blocks / 4;
    if (threadIdx.x < active_threads) {
      int c_gl_stride = prob_n / 8;
      int c_gl_wr_delta_o = 8 * c_gl_stride;
      int c_gl_wr_delta_i = 4 * (active_threads / 32);
      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
                    4 * (threadIdx.x / 32) + threadIdx.x % 4;
      c_gl_wr += (2 * thread_n_blocks) * slice_col;
      constexpr int c_sh_wr_delta = active_threads;
      int c_sh_wr = threadIdx.x;
      int row = (threadIdx.x % 32) / 4;
      if (!first) {
        // Interestingly, doing direct global accesses here really seems to mess up
        // the compiler and lead to slowdowns, hence we also use async-copies even
        // though these fetches are not actually asynchronous.
        #pragma unroll
        for (int i = 0; i < thread_m_blocks * 4; i++) {
          int c_idx =
              c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
          // Only dereference sorted_ids for rows inside this problem: rows at
          // or beyond prob_m are outside its sorted_ids range and need not be
          // backed by valid memory (the copy below is predicated off for them
          // anyway).
          bool in_range = 8 * (i / 2) + row < prob_m;
          int sorted_row = in_range ? sorted_ids[c_idx / c_gl_stride] : 0;
          int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
          cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx],
                         in_range && sorted_row < tot_m * topk &&
                             (i < (thread_m_blocks - 1) * 4 ||
                              sorted_ids[8 * (i / 2) + row] < tot_m * topk));
        }
        cp_async_fence();
        cp_async_wait<0>();
      }
      #pragma unroll
      for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (8 * (i / 2) + row < prob_m &&
            (i < (thread_m_blocks - 1) * 4 ||
             sorted_ids[8 * (i / 2) + row] < tot_m * topk)) {
          if (!first) {
            // Add the partial sum read back from C (fp16 -> fp32).
            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
            #pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<float*>(
                  &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
                  __half2float(reinterpret_cast<__half*>(&c_red)[j]);
            }
          }
          if (!last) {
            // Publish the updated partial sum for the next block (fp32 -> fp16).
            int4 c;
            #pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<__half*>(&c)[j] =
                  __float2half(reinterpret_cast<float*>(
                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
            }
            int c_idx =
                c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
            int out_row = sorted_ids[c_idx / c_gl_stride];
            if (out_row < tot_m * topk) {
              int new_idx = out_row * c_gl_stride + c_idx % c_gl_stride;
              C[new_idx] = c;
            }
          }
        }
      }
    }
  };
  // Write out the reduce final result in the correct layout. We only actually
  // reshuffle matrix fragments in this step, the reduction above is performed
  // in fragment layout.
  //
  // Phase 1: convert the fp32 accumulators to fp16 (applying 4-bit
  // channelwise scales) and stage them row-major in shared memory.
  // Phase 2: copy staged rows to C, routing each row through sorted_ids and,
  // when apply_weights is set, multiplying it by topk_weights[row]. Rows whose
  // sorted id is >= tot_m * topk are padding and are skipped.
  // Contains __syncthreads(): every thread of the block must call this.
  auto write_result = [&]() {
    int c_gl_stride = prob_n / 8;
    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
    constexpr int c_sh_rd_delta =
        c_sh_stride * (threads / (2 * thread_n_blocks));
    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));
    c_gl_wr += (2 * thread_n_blocks) * slice_col;
    int c_sh_wr =
        (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
    c_sh_wr += 32 * (threadIdx.x / 32);
    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));
    int c_gl_wr_end = c_gl_stride * prob_m;
    // We first reorder in shared memory to guarantee the most efficient final
    // global write patterns
    auto write = [&](int idx, float c0, float c1, FragS& s) {
      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
      // For per-column quantization we finally apply the scale here (only for
      // 4-bit)
      if constexpr (!has_act_order && group_blocks == -1 &&
                    w_type.size_bits() == 4) {
        res = __hmul2(res, s[0]);
      }
      reinterpret_cast<half2*>(sh)[idx] = res;
    };
    if (threadIdx.x / 32 < thread_n_blocks / 4) {
      #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
        #pragma unroll
        for (int j = 0; j < 4; j++) {
          int wr = c_sh_wr + 8 * j;
          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
                frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
                frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
                frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
                frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
        }
        c_sh_wr += 16 * (4 * c_sh_stride);
      }
    }
    __syncthreads();
    #pragma unroll
    for (int i = 0;
         i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
         i++) {
      if (c_gl_wr < c_gl_wr_end) {
        int row = sorted_ids[c_gl_wr / c_gl_stride];
        if (row < tot_m * topk) {
          int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
          if (!apply_weights) {
            C[off] = sh[c_sh_rd];
          } else {
            // Scale in fp32, round to fp16, then issue a single 16-byte store.
            const float weight = topk_weights[row];
            const __half* csrc = reinterpret_cast<const __half*>(&sh[c_sh_rd]);
            int4 weighted;
            __half* cdst = reinterpret_cast<__half*>(&weighted);
            #pragma unroll
            for (int j = 0; j < 8; ++j) {
              cdst[j] = __float2half(weight * __half2float(csrc[j]));
            }
            C[off] = weighted;
          }
        }
        // Advance even past a padding row: if the walk stopped here, every
        // later row handled by this thread would be silently dropped.
        c_gl_wr += c_gl_wr_delta;
        c_sh_rd += c_sh_rd_delta;
      }
    }
  };
  // Start global fetch and register load pipelines.
  // Issues the global->shared fetches for the first `stages - 1` stages (each
  // predicated on i < slice_iters), zeroes the accumulators, waits for stage 0
  // and loads its first sub-tile (A/B fragments and scales) into registers.
  // Finally advances a_gl_rd and slice_k_start_shared_fetch past the stages
  // already issued. Calls wait_for_stage (which contains __syncthreads()), so
  // every thread of the block must reach it.
  auto start_pipes = [&]() {
    // TODO re-enable after fixing this function
    // fetch_sorted_ids_to_shared();
    // __syncthreads();
    #pragma unroll
    for (int i = 0; i < stages - 1; i++) {
      if (has_act_order && i == 0) {
        // Act-order: preload the scale groups covering g_idx from
        // slice_k_start up to slice_k_start + 2 * stages * tb_k (clamped to the
        // last valid k index).
        int last_g_idx = slice_k_start + stages * tb_k * 2;
        if (last_g_idx >= prob_k) {
          last_g_idx = prob_k - 1;
        }
        fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
      }
      fetch_to_shared(i, i, i < slice_iters);
    }
    zero_accums();
    wait_for_stage();
    init_same_group(0);
    fetch_to_registers(0, 0);
    fetch_scales_to_registers(0, 0);
    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
    slice_k_start_shared_fetch += tb_k * (stages - 1);
  };
  967. if (slice_iters) {
  968. start_pipes();
  969. }
  970. // Main loop.
  971. while (slice_iters) {
  972. // We unroll over both the global fetch and the register load pipeline to
  973. // ensure all shared memory accesses are static. Note that both pipelines
  974. // have even length meaning that the next iteration will always start at
  975. // index 0.
  976. #pragma unroll
  977. for (int pipe = 0; pipe < stages;) {
  978. #pragma unroll
  979. for (int k = 0; k < b_sh_wr_iters; k++) {
  980. fetch_to_registers(k + 1, pipe % stages);
  981. fetch_scales_to_registers(k + 1, pipe);
  982. if (k == b_sh_wr_iters - 2) {
  983. fetch_to_shared((pipe + stages - 1) % stages, pipe,
  984. slice_iters >= stages);
  985. pipe++;
  986. wait_for_stage();
  987. init_same_group(pipe % stages);
  988. }
  989. matmul(k);
  990. }
  991. slice_iters--;
  992. if (slice_iters == 0) {
  993. break;
  994. }
  995. }
  996. a_gl_rd += a_gl_rd_delta_o * stages;
  997. slice_k_start += tb_k * stages;
  998. slice_k_start_shared_fetch += tb_k * stages;
  999. if constexpr (has_act_order) {
  1000. int first_group_id = g_idx[slice_k_start];
  1001. int last_g_idx = slice_k_start + stages * tb_k * 2;
  1002. if (last_g_idx >= prob_k) {
  1003. last_g_idx = prob_k - 1;
  1004. }
  1005. int last_group_id = g_idx[last_g_idx];
  1006. if (last_group_id >= sh_first_group_id + sh_num_groups) {
  1007. fetch_scales_to_shared(false, first_group_id, last_group_id);
  1008. __syncthreads();
  1009. }
  1010. }
  1011. // Process results and, if necessary, proceed to the next column slice.
  1012. // While this pattern may not be the most readable, other ways of writing
  1013. // the loop seemed to noticeably worse performance after compilation.
  1014. if (slice_iters == 0) {
  1015. cp_async_wait<0>();
  1016. bool last = slice_idx == slice_count - 1;
  1017. if constexpr (!has_act_order && group_blocks == -1) {
  1018. if constexpr (w_type.size_bits() == 8) {
  1019. if (s_sh_wr_pred) {
  1020. cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
  1021. }
  1022. cp_async_fence();
  1023. } else {
  1024. // For 4-bit per-column scales, we only fetch them here in the
  1025. // final step before write-out
  1026. if (last) {
  1027. if (s_sh_wr_pred) {
  1028. cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
  1029. }
  1030. cp_async_fence();
  1031. }
  1032. }
  1033. }
  1034. thread_block_reduce();
  1035. if constexpr (!has_act_order && group_blocks == -1) {
  1036. if constexpr (w_type.size_bits() == 8) {
  1037. cp_async_wait<0>();
  1038. __syncthreads();
  1039. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  1040. reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
  1041. reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
  1042. }
  1043. } else {
  1044. if (last) {
  1045. cp_async_wait<0>();
  1046. __syncthreads();
  1047. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  1048. reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
  1049. reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
  1050. }
  1051. }
  1052. }
  1053. }
  1054. // For 8-bit channelwise, we apply the scale before the global reduction
  1055. // that converts the fp32 results to fp16 (so that we avoid possible
  1056. // overflow in fp16)
  1057. if constexpr (!has_act_order && group_blocks == -1 &&
  1058. w_type.size_bits() == 8) {
  1059. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  1060. #pragma unroll
  1061. for (int i = 0; i < thread_m_blocks; i++) {
  1062. #pragma unroll
  1063. for (int j = 0; j < 4; j++) {
  1064. scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][0]),
  1065. frag_s[j / 2][2 * (j % 2) + 0]);
  1066. scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][2]),
  1067. frag_s[j / 2][2 * (j % 2) + 0]);
  1068. scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][0]),
  1069. frag_s[j / 2][2 * (j % 2) + 1]);
  1070. scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][2]),
  1071. frag_s[j / 2][2 * (j % 2) + 1]);
  1072. }
  1073. }
  1074. }
  1075. }
  1076. if (slice_count > 1) { // only globally reduce if there is more than one
  1077. // block in a slice
  1078. barrier_acquire(&locks[slice_col], slice_idx);
  1079. global_reduce(slice_idx == 0, last);
  1080. barrier_release(&locks[slice_col], last);
  1081. }
  1082. if (last) // only the last block in a slice actually writes the result
  1083. write_result();
  1084. slice_row = 0;
  1085. slice_col_par++;
  1086. slice_col++;
  1087. init_slice();
  1088. if (slice_iters) {
  1089. a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  1090. (threadIdx.x % a_gl_rd_delta_o);
  1091. #pragma unroll
  1092. for (int i = 0; i < b_sh_wr_iters; i++)
  1093. B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
  1094. if (slice_col == 0) {
  1095. #pragma unroll
  1096. for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
  1097. }
  1098. // Update slice k/n for scales loading
  1099. if constexpr (has_act_order) {
  1100. slice_k_start = tb_k * slice_row;
  1101. slice_k_finish = slice_k_start + tb_k * slice_iters;
  1102. slice_k_start_shared_fetch = slice_k_start;
  1103. slice_n_offset = act_s_col_tb_stride * slice_col;
  1104. } else {
  1105. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  1106. }
  1107. start_pipes();
  1108. }
  1109. }
  1110. }
  1111. }
  // Per-expert entry kernel (SM80+ path).
  // Rows routed to expert `expert_idx` occupy
  // sorted_ids_base[expert_offsets[expert_idx] .. expert_offsets[expert_idx+1]).
  // Starting at m block `current_m_block` (16 rows per m block), this launch
  // covers at most `cfg_max_m_blocks` m blocks, widened by a parallelism factor
  // of at most `max_par`. It then dispatches to the MarlinMoESingle
  // instantiation specialised for 1..4 m blocks. Returns immediately when the
  // expert has no rows or current_m_block is already past them.
  template <const aphrodite::ScalarTypeId w_type_id,  // weight ScalarType id
            const int threads,          // number of threads in a threadblock
            const int thread_n_blocks,  // same for n dimension (output)
            const int thread_k_blocks,  // same for k dimension (reduction)
            const int stages,  // number of stages for the async global->shared
                               // fetch pipeline
            const bool has_act_order,    // whether act_order is enabled
            const int group_blocks = -1  // number of consecutive 16x16 blocks
                                         // with a separate quantization scale
            >
  __global__ void MarlinMoE(
      const int4* __restrict__ A,  // fp16 input matrix of shape mxk
      const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
      int4* __restrict__ C,        // fp16 output buffer of shape mxn
      const int* __restrict__ sorted_ids_base,  // int32 sorted ids of experts
      const float* __restrict__ topk_weights,   // float topk weights
      const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                            // (k/groupsize)xn
      const int* __restrict__ g_idx,        // int32 group indices of shape k
      const int* __restrict__ expert_offsets,
      int num_groups,        // number of scale groups per output channel
      int expert_idx,        // idx of current expert
      int num_experts,       // number of experts
      int topk,              // topk parameter of moe
      int prob_m,            // batch dimension m
      int prob_n,            // output dimension n
      int prob_k,            // reduction dimension k
      int tot_m,             // total number of rows in A and C
      int* locks,            // extra global storage for barrier synchronization
      bool replicate_input,  // do we use the same input for each expert?
      bool apply_weights,    // apply weights to output
      int current_m_block,   // current m block to start kernel computation from
      int max_par,           // maximum parallelism
      int cfg_max_m_blocks   // upper bound on m blocks
  ) {
    const int m_block_ctr = current_m_block;
    const int expert_begin = expert_offsets[expert_idx];
    const int tot_its = expert_offsets[expert_idx + 1] - expert_begin;
    if (tot_its == 0) {
      return;
    }
    const int tot_m_blocks = ceildiv(tot_its, 16);
    if (m_block_ctr >= tot_m_blocks) {
      return;
    }
    // Rows of padding in the last (partial) m block.
    const int pad = 16 * tot_m_blocks - tot_its;
    // NOTE(review): the row offset m_block_ctr * 4 * max_par equals the first
    // row of m block m_block_ctr (16 * m_block_ctr) only when max_par == 4;
    // confirm how the launcher steps current_m_block before changing it.
    const int* sorted_ids_expert =
        sorted_ids_base + expert_begin + m_block_ctr * 4 * max_par;

    int max_block = tot_m_blocks - m_block_ctr;
    prob_m = tot_its - 16 * m_block_ctr;
    if (max_block > cfg_max_m_blocks) {
      // Note that parallel > 1 currently only works for inputs without any
      // padding
      int par = (16 * max_block - pad) / (16 * cfg_max_m_blocks);
      if (par > max_par) par = max_par;
      prob_m = (16 * cfg_max_m_blocks) * par;
      max_block = cfg_max_m_blocks;
    }

    // Instantiate the single-tile kernel for the number of m blocks needed;
    // anything other than 1..3 uses the 4-block variant.
    switch (max_block) {
      case 1:
        MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
                        stages, has_act_order, group_blocks>(
            A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
            expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
            prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
            current_m_block);
        break;
      case 2:
        MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
                        stages, has_act_order, group_blocks>(
            A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
            expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
            prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
            current_m_block);
        break;
      case 3:
        MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
                        stages, has_act_order, group_blocks>(
            A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
            expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
            prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
            current_m_block);
        break;
      default:
        MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
                        stages, has_act_order, group_blocks>(
            A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
            expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
            prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
            current_m_block);
        break;
    }
  }
  1201. #else
  // Stub for the #else branch of the architecture guard: same signature as the
  // real kernel so dispatch code still compiles, but the body only asserts.
  // If launched with device assertions enabled, the kernel traps.
  template <const aphrodite::ScalarTypeId w_type_id,  // weight ScalarType id
            const int threads,          // number of threads in a threadblock
            const int thread_n_blocks,  // same for n dimension (output)
            const int thread_k_blocks,  // same for k dimension (reduction)
            const int stages,  // number of stages for the async global->shared
                               // fetch pipeline
            const bool has_act_order,    // whether act_order is enabled
            const int group_blocks = -1  // number of consecutive 16x16 blocks
                                         // with a separate quantization scale
            >
  __global__ void MarlinMoE(
      const int4* __restrict__ A,  // fp16 input matrix of shape mxk
      const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
      int4* __restrict__ C,        // fp16 output buffer of shape mxn
      const int* __restrict__ sorted_ids,      // int32 sorted ids of experts
      const float* __restrict__ topk_weights,  // float topk weights
      const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                            // (k/groupsize)xn
      const int* __restrict__ g_idx,        // int32 group indices of shape k
      const int* __restrict__ expert_offsets,
      int num_groups,        // number of scale groups per output channel
      int expert_idx,        // idx of current expert
      int num_experts,       // number of experts
      int topk,              // topk parameter of moe
      int prob_m,            // batch dimension m
      int prob_n,            // output dimension n
      int prob_k,            // reduction dimension k
      int tot_m,             // total number of rows in A and C
      int* locks,            // extra global storage for barrier synchronization
      bool replicate_input,  // do we use the same input for each expert?
      bool apply_weights,    // apply weights to output
      int current_m_block,   // current m block to start kernel computation from
      int max_par,           // maximum parallelism
      int cfg_max_m_blocks   // upper bound on m blocks
  ) {
    // Marlin is not implemented yet for SM < 8.0
    assert(false);
    return;
  }
  1241. #endif
  // 8 warps are a good choice since every SM has 4 schedulers and having more
  // than 1 warp per schedule allows some more latency hiding. At the same time,
  // we want relatively few warps to have many registers per warp and small tiles.
  // Block size (8 warps) -- only used when the caller supplies thread_k/n.
  constexpr int USER_THREADS = 256;
  // Depth of the async global->shared pipeline; 4 stages fit into shared memory.
  constexpr int STAGES = 4;
  // Former fixed shared-memory budget, kept for reference: 96 KiB, described
  // as the maximum on compute capability 8.6 (smaller than on 8.0).
  // const int SHARED_MEM = 96 * 1024;
  // Smallest supported thread-tile extents (presumably for thread_n and
  // thread_k -- TODO confirm at the use sites).
  static constexpr int min_thread_n = 64;
  static constexpr int min_thread_k = 64;
  // Expands to one `else if` arm of a dispatch chain, so it must follow an
  // `if` (or another arm) in the caller. When the runtime configuration
  // (q_type, thread_n_blocks, thread_k_blocks, has_act_order, group_blocks,
  // num_threads) matches the macro arguments, it raises the instantiation's
  // dynamic shared-memory limit to max_shared_mem and launches it with
  // `blocks` blocks on `stream`. All other names (A_ptr, C_ptr, m_block, ...)
  // are locals of the caller.
  // NOTE(review): neither cudaFuncSetAttribute's return value nor launch
  // errors are checked here; also, identifiers starting with a double
  // underscore are reserved in C++.
  #define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
                        GROUP_BLOCKS, NUM_THREADS)                               \
    else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS &&           \
             thread_k_blocks == THREAD_K_BLOCKS &&                               \
             has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&   \
             num_threads == NUM_THREADS) {                                       \
      cudaFuncSetAttribute(                                                      \
          MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,  \
                    STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>,                        \
          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);          \
      MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,      \
                STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>                             \
          <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                     \
              A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr,      \
              g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx,             \
              num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks,           \
              replicate_input, apply_weights, m_block, max_par,                  \
              cfg_max_m_blocks);                                                 \
    }

  // Expands to the five dispatch arms for one (N_BLOCKS, K_BLOCKS, NUM_THREADS)
  // tile shape: act-order (group_blocks 0), then channelwise (-1) and grouped
  // (2, 4, 8) scales without act-order.
  #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)          \
    __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)          \
    __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)        \
    __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)         \
    __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)         \
    __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
  1277. } // namespace marlin_moe