marlin_moe_ops.cu 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546
  1. /*
  2. * Modified by Neural Magic
  3. * Copyright (C) Marlin.2024 Elias Frantar
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #include <torch/all.h>
  18. #include <ATen/cuda/CUDAContext.h>
  19. #include <c10/cuda/CUDAGuard.h>
  20. #include <cuda.h>
  21. #include <cuda_fp16.h>
  22. #include <cuda_runtime.h>
  23. #include <iostream>
  // Decimal string form of a numeric value; a thin wrapper over
  // std::to_string, so `Num` must be a type std::to_string accepts.
  template <typename Num>
  inline std::string str(Num value) {
    std::string text = std::to_string(value);
    return text;
  }
  28. namespace marlin_moe {
  // Ceiling division a / b, intended for a >= 0 and b > 0 (e.g. the number of
  // b-sized tiles needed to cover a elements). It is constexpr so that it can
  // also size compile-time tile parameters.
  constexpr int ceildiv(int a, int b) {
    // Adding b - 1 before the truncating division rounds the quotient up.
    return (a + b - 1) / b;
  }
  30. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // `Vec<T, n>` bundles `n` values of type `T` that are meant to live in
  // >>registers<<, e.g. the operands of tensor core instructions. For the
  // compiler to keep them in registers, every index used to access a Vec has
  // to be a compile-time constant, which is why the kernel code wraps such
  // accesses in `#pragma unroll` loops.
  template <typename T, int n>
  struct Vec {
    T elems[n];

    __device__ T& operator[](int idx) { return elems[idx]; }
  };

  // Four 32-bit ints, i.e. one 16-byte vector.
  using I4 = Vec<int, 4>;
  // Register fragments for the m16n8k16 fp16 tensor core MMA; the per-thread
  // element layout is specified in the PTX ISA:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;  // A operand: 4 x half2 = 8 fp16 values
  using FragB = Vec<half2, 2>;  // B operand: 2 x half2 = 4 fp16 values
  using FragC = Vec<float, 4>;  // fp32 accumulator: 4 floats
  using FragS = Vec<half2, 1>;  // quantization scales: 2 fp16 values
  // Predicated asynchronous 16-byte copy from global to shared memory
  // (`cp.async.cg`, SM80+). When `pred` is false the predicated instruction is
  // not executed, so the shared destination is left untouched. Used for the A
  // operand, whose loads are predicated to handle batch sizes that are not
  // multiples of 16.
  __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                        bool pred = true) {
    constexpr int kCopyBytes = 16;
    const int pred_flag = static_cast<int>(pred);
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
        "{\n"
        " .reg .pred p;\n"
        " setp.ne.b32 p, %0, 0;\n"
        " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
        "}\n" ::"r"(pred_flag),
        "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
  }
  // Unpredicated asynchronous 16-byte copy from global to shared memory
  // (`cp.async.cg`, SM80+). Completion is tracked through the commit/wait
  // group helpers.
  __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
    constexpr int kCopyBytes = 16;
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
        "{\n"
        " cp.async.cg.shared.global [%0], [%1], %2;\n"
        "}\n" ::"r"(smem_addr),
        "l"(glob_ptr), "n"(kCopyBytes));
  }
  // Blocks the calling thread until at most `n` of its most recently committed
  // cp.async groups are still pending (`cp.async.wait_group n`).
  template <int n>
  __device__ inline void cp_async_wait() {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
  }

  // Commits every cp.async issued so far by this thread as one group
  // (`cp.async.commit_group`); cp_async_wait<n>() waits on these groups.
  __device__ inline void cp_async_fence() {
    asm volatile("cp.async.commit_group;\n" ::);
  }
  // One m16n8k16 tensor core MMA: frag_c += a_frag * frag_b. The fp16 operands
  // are A row-major and B column-major (`.row.col`), and accumulation is fp32.
  // `mma.sync.aligned` is warp-wide, so all 32 lanes must execute it together.
  __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                             FragC& frag_c) {
    const uint32_t* a_regs = reinterpret_cast<const uint32_t*>(&a_frag);
    const uint32_t* b_regs = reinterpret_cast<const uint32_t*>(&frag_b);
    float* acc = reinterpret_cast<float*>(&frag_c);
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
          "r"(b_regs[0]), "r"(b_regs[1]),
          "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]));
  }
  // Loads a full 16x16 fp16 A-operand fragment from shared memory directly in
  // tensor core layout, using `ldmatrix ... .x4` (four 8x8 b16 matrices). The
  // instruction is warp-wide, and each lane passes one row address.
  __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
    uint32_t* regs = reinterpret_cast<uint32_t*>(&frag_a);
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3])
                 : "r"(smem_addr));
  }
  // Emits one `lop3.b32` instruction: an arbitrary 3-input bitwise function of
  // (a, b, c), selected by the 8-bit truth table `lut`. It is used explicitly
  // during dequantization because the compiler does not reliably fuse such
  // expressions into a LOP3 by itself.
  template <int lut>
  __device__ inline int lop3(int a, int b, int c) {
    int result;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(result)
                 : "r"(a), "r"(b), "r"(c), "n"(lut));
    return result;
  }
  // Dequantizes four signed 4-bit values packed in `q` into one B fragment
  // (two half2 = four fp16 values). This mostly follows the FasterTransformer
  // interleaved-conversion trick, with small changes:
  // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
  //
  // Only bits [3:0], [7:4], [19:16] and [23:20] of `q` are read. Each nibble v
  // is OR-ed into the mantissa of fp16 1024.0 (0x6400) and shifted back with
  // fp16 arithmetic. The symmetric zero point 8 is folded into the SUB/ADD
  // constants, so every output equals v - 8:
  //   frag_b[0] = {bits [3:0], bits [19:16]} : (1024 + v) - 1032
  //   frag_b[1] = {bits [7:4], bits [23:20]} : (1024 + 16 v) * (1/16) - 72
  __device__ inline FragB dequant(int q) {
    // Truth table of `(a & b) | c` in lop3's 0xf0/0xcc/0xaa encoding. It is
    // spelled out so that each mask-and-merge is guaranteed to be one LOP3.
    constexpr int kAndOrLut = (0xf0 & 0xcc) | 0xaa;
    const int EX = 0x64006400;  // fp16 1024.0 in both halves
    const int LO = 0x000f000f;  // low nibble of each 16-bit half
    const int HI = 0x00f000f0;  // second nibble of each 16-bit half
    int lo = lop3<kAndOrLut>(q, LO, EX);
    int hi = lop3<kAndOrLut>(q, HI, EX);

    const int SUB = 0x64086408;  // fp16 1032.0 = 1024 + zero point 8
    const int MUL = 0x2c002c00;  // fp16 1/16
    const int ADD = 0xd480d480;  // fp16 -72.0 = -(1024 / 16 + 8)

    const half2 lo_pair = *reinterpret_cast<half2*>(&lo);
    const half2 hi_pair = *reinterpret_cast<half2*>(&hi);
    FragB frag_b;
    frag_b[0] = __hsub2(lo_pair, *reinterpret_cast<const half2*>(&SUB));
    frag_b[1] = __hfma2(hi_pair, *reinterpret_cast<const half2*>(&MUL),
                        *reinterpret_cast<const half2*>(&ADD));
    return frag_b;
  }
  // Grouped quantization only: multiplies all four fp16 values of `frag_b` by
  // the `i`-th fp16 scale held in `frag_s`. FragS holds a single half2, so i is
  // 0 or 1. The scale is broadcast to both lanes of a half2.
  __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
    const __half s_elem = reinterpret_cast<__half*>(&frag_s)[i];
    const half2 s_pair = __half2half2(s_elem);
  #pragma unroll
    for (int j = 0; j < 2; j++) {
      frag_b[j] = __hmul2(frag_b[j], s_pair);
    }
  }
  // Scales two fp32 values in place: c[j] *= float(j-th half of `s`) for
  // j = 0, 1, using round-to-nearest fp32 multiplies.
  __device__ inline void scale_float(float* c, FragS& s) {
    const __half* halves = reinterpret_cast<__half*>(&s);
  #pragma unroll
    for (int j = 0; j < 2; ++j) {
      c[j] = __fmul_rn(c[j], __half2float(halves[j]));
    }
  }
  // act_order variant of `scale`: every K row has its own scale, so each of the
  // four fp16 values in `frag_b` is multiplied by the `i`-th half of a
  // different FragS. frag_s_1/frag_s_2 scale frag_b[0] (low/high lane), and
  // frag_s_3/frag_s_4 scale frag_b[1].
  __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
                                FragS& frag_s_3, FragS& frag_s_4, int i) {
    auto pick = [i](FragS& fs) { return reinterpret_cast<__half*>(&fs)[i]; };
    // __halves2half2(lo, hi) stores `lo` in .x and `hi` in .y.
    const __half2 s_first = __halves2half2(pick(frag_s_1), pick(frag_s_2));
    const __half2 s_second = __halves2half2(pick(frag_s_3), pick(frag_s_4));
    frag_b[0] = __hmul2(frag_b[0], s_first);
    frag_b[1] = __hmul2(frag_b[1], s_second);
  }
  // Thread 0 spins on the global counter `*lock` with GPU-scope acquire loads
  // until it equals `count`. Afterwards the whole block synchronizes, so no
  // thread proceeds before the barrier is held. This function contains
  // __syncthreads() and must therefore be called by every thread of the block.
  __device__ inline void barrier_acquire(int* lock, int count) {
    if (threadIdx.x == 0) {
      for (;;) {
        int observed;
        asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                     : "=r"(observed)
                     : "l"(lock));
        if (observed == count) break;
      }
    }
    __syncthreads();
  }
  // Releases the barrier after a block-wide __syncthreads(), so it must be
  // called by every thread of the block. Thread 0 then does one of two things:
  // - With `reset`, it stores 0 to `*lock` with a plain store and no fence.
  // - Otherwise, it issues a GPU-scope acq_rel fence followed by a relaxed
  //   atomic add of 1. The fence makes writes since acquiring the barrier
  //   globally visible before the visitation count advances.
  __device__ inline void barrier_release(int* lock, bool reset = false) {
    __syncthreads();
    if (threadIdx.x != 0) return;
    if (!reset) {
      const int increment = 1;
      asm volatile("fence.acq_rel.gpu;\n");
      asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                   :
                   : "l"(lock), "r"(increment));
    } else {
      lock[0] = 0;
    }
  }
  // For a row-major fp16 matrix `a` of shape [size_m, size_k], writes the
  // column-permuted copy out[row][k] = a[row][perm[k]] for every row and k.
  //
  // Launch layout: threadblock b handles rows
  // [b * block_rows, min((b + 1) * block_rows, size_m)). Within each row, the
  // block's threads cover k = threadIdx.x, threadIdx.x + blockDim.x, ... while
  // k < size_k.
  //
  // The buffers are passed as int4 pointers but are addressed as `half`
  // elements. Row offsets are computed in half elements with size_t
  // arithmetic, which has two consequences:
  // - size_k does not need to be a multiple of 8 (an int4-granular row stride
  //   of size_k * 2 / 16 would truncate and address the wrong rows).
  // - row * size_k cannot overflow `int` for very large matrices.
  // `a` and `out` are __restrict__, so they must not alias.
  // NOTE(review): assumes every perm[k] lies in [0, size_k); confirm at the
  // call site.
  __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                      int const* __restrict__ perm_int_ptr,
                                      int4* __restrict__ out_int4_ptr,
                                      int size_m, int size_k, int block_rows) {
    const int start_row = block_rows * blockIdx.x;
    // Clamp the last block to the matrix height.
    const int end_row = (start_row + block_rows < size_m)
                            ? start_row + block_rows
                            : size_m;
    const int k_step = static_cast<int>(blockDim.x);

    half const* a_half = reinterpret_cast<half const*>(a_int4_ptr);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr);

    for (int row = start_row; row < end_row; ++row) {
      const size_t row_offset =
          static_cast<size_t>(row) * static_cast<size_t>(size_k);
      half const* a_row = a_half + row_offset;
      half* out_row = out_half + row_offset;
      // Block-stride loop over K: this covers the full iterations and the
      // remainder in one pass.
      for (int k = static_cast<int>(threadIdx.x); k < size_k; k += k_step) {
        out_row[k] = a_row[perm_int_ptr[k]];
      }
    }
  }
  // Builds padded per-expert row offsets from `topk_ids`. blockDim.x is taken
  // as the number of experts and threadIdx.x as the expert id.
  // NOTE(review): assumes a single-block launch (gridDim.x == 1) and that
  // `expert_offsets` holds blockDim.x + 1 ints; confirm at the launch site.
  //
  // Phase 1: thread e counts occurrences of e in topk_ids[0, topk_length) and
  // stores the count in expert_offsets[e + 1].
  // Phase 2: thread 0 rounds each count up to a multiple of `block_size` and
  // prefix-sums them, so that expert_offsets[0] = 0 and expert_offsets[e] is
  // the total padded count of experts 0..e-1.
  __global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                         int* __restrict__ expert_offsets,
                                         int topk_length, int block_size) {
    const int my_expert = threadIdx.x;
    int count = 0;
    for (int idx = 0; idx < topk_length; ++idx) {
      if (topk_ids[idx] == my_expert) ++count;
    }
    expert_offsets[my_expert + 1] = count;
    __syncthreads();

    if (my_expert == 0) {
      const int n_experts = blockDim.x;
      int running = 0;
      expert_offsets[0] = 0;
      for (int e = 1; e <= n_experts; ++e) {
        running += ceildiv(expert_offsets[e], block_size) * block_size;
        expert_offsets[e] = running;
      }
    }
    __syncthreads();
  }
  258. template <const int threads, // number of threads in a threadblock
  259. const int thread_m_blocks, // number of 16x16 blocks in the m
  260. // dimension (batchsize) of the
  261. // threadblock
  262. const int thread_n_blocks, // same for n dimension (output)
  263. const int thread_k_blocks, // same for k dimension (reduction)
  264. const int stages, // number of stages for the async global->shared
  265. // fetch pipeline
  266. const bool has_act_order, // whether act_order is enabled
  267. const int group_blocks = -1 // number of consecutive 16x16 blocks
  268. // with a separate quantization scale
  269. >
  270. __device__ inline void MarlinMoESingle(
  271. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  272. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  273. int4* __restrict__ C, // fp16 output buffer of shape mxn
  274. const int* __restrict__ sorted_ids, // int32 sorted ids of experts
  275. const float* __restrict__ topk_weights, // float topk weights
  276. const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
  277. // (k/groupsize)xn
  278. const int* __restrict__ g_idx, // int32 group indices of shape k
  279. const int* __restrict__ expert_offsets,
  280. int num_groups, // number of scale groups per output channel
  281. int expert_idx, // idx of current expert
  282. int num_experts, // number of experts
  283. int topk, // topk parameter of moe
  284. int prob_m, // batch dimension m
  285. int prob_n, // output dimension n
  286. int prob_k, // reduction dimension k
  287. int tot_m, // total number of rows in A and C
  288. int* locks, // extra global storage for barrier synchronization
  289. bool replicate_input, // do we use the same input for each expert?
  290. bool apply_weights, // apply weights to output
  291. int current_m_block // current m block to start kernel computation from
  292. ) {
  293. // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  294. // better partitioning with less reductions
  295. int parallel = 1;
  296. if (prob_m > 16 * thread_m_blocks) {
  297. parallel = prob_m / (16 * thread_m_blocks);
  298. prob_m = 16 * thread_m_blocks;
  299. }
  300. int k_tiles = prob_k / 16 / thread_k_blocks;
  301. int n_tiles = prob_n / 16 / thread_n_blocks;
  302. int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
  303. if constexpr (!has_act_order && group_blocks != -1) {
  304. if (group_blocks >= thread_k_blocks) {
  305. // Ensure that the number of tiles in each stripe is a multiple of the
  306. // groupsize; this avoids an annoying special case where a stripe starts
  307. // in the middle of group.
  308. iters = (group_blocks / thread_k_blocks) *
  309. ceildiv(iters, (group_blocks / thread_k_blocks));
  310. }
  311. }
  312. int slice_row = (iters * blockIdx.x) % k_tiles;
  313. int slice_col_par = (iters * blockIdx.x) / k_tiles;
  314. int slice_col = slice_col_par;
  315. int slice_iters; // number of threadblock tiles in the current slice
  316. int slice_count =
  317. 0; // total number of active threadblocks in the current slice
  318. int slice_idx; // index of threadblock in current slice; numbered bottom to
  319. // top
  320. // We can easily implement parallel problem execution by just remapping
  321. // indices and advancing global pointers
  322. if (slice_col_par >= n_tiles) {
  323. locks += (slice_col_par / n_tiles) * n_tiles;
  324. slice_col = slice_col_par % n_tiles;
  325. sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks;
  326. }
  327. // Compute all information about the current slice which is required for
  328. // synchronization.
  329. auto init_slice = [&]() {
  330. slice_iters =
  331. iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
  332. if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
  333. if (slice_iters == 0) return;
  334. if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
  335. slice_count = 1;
  336. slice_idx = 0;
  337. int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
  338. if (col_first <= k_tiles * (slice_col_par + 1)) {
  339. int col_off = col_first - k_tiles * slice_col_par;
  340. slice_count = ceildiv(k_tiles - col_off, iters);
  341. if (col_off > 0) slice_count++;
  342. int delta_first = iters * blockIdx.x - col_first;
  343. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  344. slice_idx = slice_count - 1;
  345. else {
  346. slice_idx = slice_count - 1 - delta_first / iters;
  347. if (col_off > 0) slice_idx--;
  348. }
  349. }
  350. if (slice_col == n_tiles) {
  351. sorted_ids += 16 * thread_m_blocks;
  352. locks += n_tiles;
  353. slice_col = 0;
  354. }
  355. };
  356. init_slice();
  357. // A sizes/strides
  358. // stride of the A matrix in global memory
  359. int a_gl_stride = prob_k / 8;
  360. // stride of an A matrix tile in shared memory
  361. constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
  362. // delta between subsequent A tiles in global memory
  363. constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
  364. // between subsequent accesses within a tile
  365. int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
  366. // between shared memory writes
  367. constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
  368. // between shared memory tile reads
  369. constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
  370. // within a shared memory tile
  371. constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
  372. // overall size of a tile
  373. constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
  374. // number of shared write iterations for a tile
  375. constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
  376. // B sizes/strides
  377. int b_gl_stride = 16 * prob_n / 32;
  378. constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
  379. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  380. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
  381. constexpr int b_sh_wr_delta = threads;
  382. constexpr int b_sh_rd_delta = threads;
  383. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  384. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
  385. // Scale sizes/strides without act_order
  386. int s_gl_stride = prob_n / 8;
  387. constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  388. constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks
  389. ? thread_k_blocks / group_blocks
  390. : 1;
  391. constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
  392. int s_gl_rd_delta = s_gl_stride;
  393. // Scale size/strides with act_order
  394. constexpr int tb_k = 16 * thread_k_blocks;
  395. constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
  396. // constexpr int act_s_row_stride = 1;
  397. // int act_s_col_stride = act_s_row_stride * num_groups;
  398. int act_s_col_stride = 1;
  399. int act_s_col_warp_stride = act_s_col_stride * 8;
  400. int tb_n_warps = thread_n_blocks / 4;
  401. int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
  402. constexpr int sorted_sh_stride = threads;
  403. constexpr int sorted_gl_stride = threads;
  404. // Global A read index of current thread.
  405. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  406. (threadIdx.x % a_gl_rd_delta_o);
  407. a_gl_rd += a_gl_rd_delta_o * slice_row;
  408. // Shared write index of current thread.
  409. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
  410. (threadIdx.x % a_gl_rd_delta_o);
  411. // Shared read index.
  412. int a_sh_rd =
  413. a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  414. a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  415. int b_gl_rd =
  416. b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  417. b_gl_rd += b_sh_stride * slice_col;
  418. b_gl_rd += b_gl_rd_delta_o * slice_row;
  419. int b_sh_wr = threadIdx.x;
  420. int b_sh_rd = threadIdx.x;
  421. // For act_order
  422. constexpr int k_iter_size = tb_k / b_sh_wr_iters;
  423. int slice_k_start = tb_k * slice_row;
  424. int slice_k_finish = slice_k_start + tb_k * slice_iters;
  425. int slice_k_start_shared_fetch = slice_k_start;
  426. int slice_n_offset = act_s_col_tb_stride * slice_col;
  427. // No act_order
  428. int s_gl_rd;
  429. if constexpr (group_blocks == -1 || group_blocks == 0) {
  430. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  431. } else {
  432. s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
  433. s_sh_stride * slice_col + threadIdx.x;
  434. }
  435. int s_sh_wr = threadIdx.x;
  436. bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
  437. // We use a different scale layout for grouped and column-wise quantization as
  438. // we scale a `half2` tile in column-major layout in the former and in
  439. // row-major in the latter case.
  440. int s_sh_rd;
  441. if constexpr (group_blocks != -1)
  442. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  443. (threadIdx.x % 32) / 4;
  444. else
  445. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  446. (threadIdx.x % 32) % 4;
  447. int sh_first_group_id = -1;
  448. int sh_num_groups = -1;
  449. constexpr int sh_max_num_groups = 32;
  450. int shs_size;
  451. if constexpr (has_act_order)
  452. shs_size = sh_max_num_groups * s_sh_stride + threads;
  453. else
  454. shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;
  455. extern __shared__ int4 sh[];
  456. // Shared memory storage for global fetch pipelines.
  457. int4* sh_a = sh;
  458. int4* sh_b = sh_a + (stages * a_sh_stage);
  459. int4* sh_g_idx = sh_b + (stages * b_sh_stage);
  460. int4* sh_s = sh_g_idx + (stages * g_idx_stage);
  461. int* sh_sorted = (int*)(sh_s + shs_size);
  462. // Precompute which thread should not read memory in which iterations; this is
  463. // needed if there are more threads than required for a certain tilesize or
  464. // when the batchsize is not a multiple of 16.
  465. bool a_sh_wr_pred[a_sh_wr_iters];
  466. #pragma unroll
  467. for (int i = 0; i < a_sh_wr_iters; i++) {
  468. int a_idx = a_sh_wr_delta * i + a_sh_wr;
  469. int row = a_idx / a_gl_rd_delta_o;
  470. if (row >= prob_m) {
  471. a_sh_wr_pred[i] = false;
  472. } else {
  473. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  474. }
  475. }
  476. // To ensure that writing and reading A tiles to/from shared memory, the
  477. // latter in fragment format, is fully bank conflict free, we need to use a
  478. // rather fancy XOR-based layout. The key here is that neither reads nor
  479. // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  480. // same shared memory banks. Further, it seems (based on NSight-Compute) that
  481. // each warp must also write a consecutive memory segment?
  482. auto transform_a = [&](int i) {
  483. int row = i / a_gl_rd_delta_o;
  484. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  485. };
  486. // Since the computation of this remapping is non-trivial and, due to our main
  487. // loop unrolls, all shared memory accesses are static, we simply precompute
  488. // both transformed reads and writes.
  489. int a_sh_wr_trans[a_sh_wr_iters];
  490. #pragma unroll
  491. for (int i = 0; i < a_sh_wr_iters; i++)
  492. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  493. int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  494. #pragma unroll
  495. for (int i = 0; i < b_sh_wr_iters; i++) {
  496. #pragma unroll
  497. for (int j = 0; j < thread_m_blocks; j++)
  498. a_sh_rd_trans[i][j] =
  499. transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  500. }
  501. // Since B-accesses have non-constant stride they have to be computed at
  502. // runtime; we break dependencies between subsequent accesses with a tile by
  503. // maintining multiple pointers (we have enough registers), a tiny
  504. // optimization.
  505. const int4* B_ptr[b_sh_wr_iters];
  506. #pragma unroll
  507. for (int i = 0; i < b_sh_wr_iters; i++)
  508. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  509. // Register storage for double buffer of shared memory reads.
  510. FragA frag_a[2][thread_m_blocks];
  511. I4 frag_b_quant[2];
  512. FragC frag_c[thread_m_blocks][4][2];
  513. FragS frag_s[2][4]; // No act-order
  514. FragS act_frag_s[2][4][4]; // For act-order
  515. // Zero accumulators.
  516. auto zero_accums = [&]() {
  517. #pragma unroll
  518. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  519. reinterpret_cast<float*>(frag_c)[i] = 0;
  520. };
  521. auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
  522. int last_group_id) {
  523. sh_first_group_id = first_group_id;
  524. sh_num_groups = last_group_id - first_group_id + 1;
  525. if (sh_num_groups < sh_max_num_groups) {
  526. sh_num_groups = sh_max_num_groups;
  527. }
  528. if (sh_first_group_id + sh_num_groups > num_groups) {
  529. sh_num_groups = num_groups - sh_first_group_id;
  530. }
  531. int row_offset = first_group_id * s_gl_stride;
  532. if (is_async) {
  533. for (int i = 0; i < sh_num_groups; i++) {
  534. if (threadIdx.x < s_sh_stride) {
  535. cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
  536. &scales_ptr[row_offset + (i * s_gl_stride) +
  537. slice_n_offset + threadIdx.x]);
  538. }
  539. }
  540. } else {
  541. for (int i = 0; i < sh_num_groups; i++) {
  542. if (threadIdx.x < s_sh_stride) {
  543. sh_s[(i * s_sh_stride) + threadIdx.x] =
  544. scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
  545. threadIdx.x];
  546. }
  547. }
  548. }
  549. };
  550. // Asynchronously fetch the next A, B and s tile from global to the next
  551. // shared memory pipeline location.
  552. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
  553. if (pred) {
  554. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  555. #pragma unroll
  556. for (int i = 0; i < a_sh_wr_iters; i++) {
  557. int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
  558. int row = a_idx / a_gl_stride;
  559. int sorted_row =
  560. replicate_input ? sorted_ids[row] / topk : sorted_ids[row];
  561. int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
  562. if (sorted_row < tot_m * (replicate_input ? 1 : topk) &&
  563. new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) {
  564. cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx],
  565. a_sh_wr_pred[i]);
  566. }
  567. }
  568. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  569. #pragma unroll
  570. for (int i = 0; i < b_sh_wr_iters; i++) {
  571. cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
  572. B_ptr[i] += b_gl_rd_delta_o;
  573. }
  574. if constexpr (has_act_order) {
  575. // Fetch g_idx thread-block portion
  576. int full_pipe = a_off;
  577. int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
  578. if (cur_k < prob_k && cur_k < slice_k_finish) {
  579. int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
  580. int4 const* cur_g_idx_stage_ptr =
  581. reinterpret_cast<int4 const*>(&g_idx[cur_k]);
  582. if (threadIdx.x < g_idx_stage) {
  583. cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
  584. &cur_g_idx_stage_ptr[threadIdx.x]);
  585. }
  586. }
  587. } else {
  588. if constexpr (group_blocks != -1) {
  589. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  590. if constexpr (group_blocks >= thread_k_blocks) {
  591. // Only fetch scales if this tile starts a new group
  592. if (pipe % (group_blocks / thread_k_blocks) == 0) {
  593. if (s_sh_wr_pred) {
  594. cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
  595. }
  596. s_gl_rd += s_gl_rd_delta;
  597. }
  598. } else {
  599. for (int i = 0; i < s_tb_groups; i++) {
  600. if (s_sh_wr_pred) {
  601. cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
  602. &scales_ptr[s_gl_rd]);
  603. }
  604. s_gl_rd += s_gl_rd_delta;
  605. }
  606. }
  607. }
  608. }
  609. }
  610. // Insert a fence even when we are winding down the pipeline to ensure that
  611. // waiting is also correct at this point.
  612. cp_async_fence();
  613. };
  614. // TODO we are currently hitting illegal memory accesses when fetching
  615. // sorted_ids to shared data: fix this
  616. auto fetch_sorted_ids_to_shared = [&]() {
  617. const int mpt = ceildiv(prob_m, threads);
  618. for (int i = 0; i < mpt; i++) {
  619. if ((i * sorted_gl_stride) + threadIdx.x < prob_m) {
  620. sh_sorted[(i * sorted_sh_stride) + threadIdx.x] =
  621. sorted_ids[(i * sorted_gl_stride) + threadIdx.x];
  622. }
  623. }
  624. };
  625. // Wait until the next thread tile has been loaded to shared memory.
  626. auto wait_for_stage = [&]() {
  627. // We only have `stages - 2` active fetches since we are double buffering
  628. // and can only issue the next fetch when it is guaranteed that the previous
  629. // shared memory load is fully complete (as it may otherwise be
  630. // overwritten).
  631. cp_async_wait<stages - 2>();
  632. __syncthreads();
  633. };
  634. // Load the next sub-tile from the current location in the shared memory pipe
  635. // into the current register buffer.
  636. auto fetch_to_registers = [&](int k, int pipe) {
  637. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  638. #pragma unroll
  639. for (int i = 0; i < thread_m_blocks; i++)
  640. ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
  641. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  642. frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
  643. &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
  644. };
  645. bool is_same_group[stages];
  646. int same_group_id[stages];
  647. auto init_same_group = [&](int pipe) {
  648. int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
  649. int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
  650. int group_id_1 = sh_g_idx_int_ptr[0];
  651. int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
  652. is_same_group[pipe] = group_id_1 == group_id_2;
  653. same_group_id[pipe] = group_id_1;
  654. };
  655. auto fetch_scales_to_registers = [&](int k, int full_pipe) {
  656. int pipe = full_pipe % stages;
  657. if constexpr (!has_act_order) {
  658. // No act-order case
  659. if constexpr (group_blocks != -1) {
  660. if constexpr (group_blocks >= thread_k_blocks) {
  661. int4* sh_s_stage =
  662. sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
  663. (pipe / (group_blocks / thread_k_blocks)));
  664. reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
  665. } else {
  666. int warp_id = threadIdx.x / 32;
  667. int n_warps = thread_n_blocks / 4;
  668. int warp_row = warp_id / n_warps;
  669. int cur_k = warp_row * 16;
  670. cur_k += k_iter_size * (k % b_sh_wr_iters);
  671. int k_blocks = cur_k / 16;
  672. int cur_group_id = k_blocks / group_blocks;
  673. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  674. reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
  675. sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
  676. }
  677. }
  678. return;
  679. }
  680. // Act-order case
  681. // Determine K of the "current" thread-block
  682. int cur_k = slice_k_start + tb_k * full_pipe;
  683. if (cur_k >= prob_k || cur_k >= slice_k_finish) {
  684. return;
  685. }
  686. // Reset (to current thread-block) since we read g_idx portion from the
  687. // shared memory
  688. cur_k = 0;
  689. // Progress to current iteration
  690. cur_k += k_iter_size * (k % b_sh_wr_iters);
  691. // Determine "position" inside the thread-block (based on warp and
  692. // thread-id)
  693. int warp_id = threadIdx.x / 32;
  694. int n_warps =
  695. thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
  696. int warp_row = warp_id / n_warps;
  697. int warp_col = warp_id % n_warps;
  698. cur_k += warp_row * 16;
  699. int th_id = threadIdx.x % 32;
  700. cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
  701. int s_col_shift =
  702. /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
  703. (th_id / 4) * act_s_col_stride;
  704. if (is_same_group[pipe]) {
  705. if (k % 2 == 0) {
  706. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
  707. sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
  708. s_col_shift];
  709. } else {
  710. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
  711. *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
  712. }
  713. for (int i = 1; i < 4; i++) {
  714. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
  715. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
  716. }
  717. return;
  718. }
  719. int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
  720. int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
  721. constexpr int k_frag_offsets[4] = {0, 1, 8,
  722. 9}; // Tensor core offsets per thread
  723. #pragma unroll
  724. for (int i = 0; i < 4; i++) {
  725. int actual_k = cur_k + k_frag_offsets[i];
  726. int group_id = sh_g_idx_int_ptr[actual_k];
  727. int rel_group_id = group_id - sh_first_group_id;
  728. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
  729. sh_s[rel_group_id * s_sh_stride + s_col_shift];
  730. }
  731. };
  732. // Execute the actual tensor core matmul of a sub-tile.
  733. auto matmul = [&](int k) {
  734. // We have the m dimension as the inner loop in order to encourage overlapping
  735. // dequantization and matmul operations.
  736. #pragma unroll
  737. for (int j = 0; j < 4; j++) {
  738. int b_quant = frag_b_quant[k % 2][j];
  739. int b_quant_shift = b_quant >> 8;
  740. FragB frag_b0 = dequant(b_quant);
  741. // Apply scale to frag_b0
  742. if constexpr (has_act_order) {
  743. scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
  744. act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
  745. } else {
  746. if constexpr (group_blocks != -1) {
  747. scale(frag_b0, frag_s[k % 2][j], 0);
  748. }
  749. }
  750. FragB frag_b1 = dequant(b_quant_shift);
  751. // Apply scale to frag_b1
  752. if constexpr (has_act_order) {
  753. scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
  754. act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
  755. } else {
  756. if constexpr (group_blocks != -1) {
  757. scale(frag_b1, frag_s[k % 2][j], 1);
  758. }
  759. }
  760. #pragma unroll
  761. for (int i = 0; i < thread_m_blocks; i++) {
  762. mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
  763. mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
  764. }
  765. }
  766. };
  767. // Since we slice across the k dimension of a tile in order to increase the
  768. // number of warps while keeping the n dimension of a tile reasonable, we have
  769. // multiple warps that accumulate their partial sums of the same output
  770. // location; which we have to reduce over in the end. We do in shared memory.
  771. auto thread_block_reduce = [&]() {
  772. constexpr int red_off = threads / b_sh_stride / 2;
  773. if (red_off >= 1) {
  774. int red_idx = threadIdx.x / b_sh_stride;
  775. constexpr int red_sh_stride = b_sh_stride * 4 * 2;
  776. constexpr int red_sh_delta = b_sh_stride;
  777. int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
  778. (threadIdx.x % b_sh_stride);
  779. // Parallel logarithmic shared memory reduction. We make sure to avoid any
  780. // unnecessary read or write iterations, e.g., for two warps we write only
  781. // once by warp 1 and read only once by warp 0.
  782. #pragma unroll
  783. for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  784. #pragma unroll
  785. for (int i = red_off; i > 0; i /= 2) {
  786. if (i <= red_idx && red_idx < 2 * i) {
  787. #pragma unroll
  788. for (int j = 0; j < 4 * 2; j++) {
  789. int red_sh_wr =
  790. red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
  791. if (i < red_off) {
  792. float* c_rd =
  793. reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
  794. float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  795. #pragma unroll
  796. for (int k = 0; k < 4; k++)
  797. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
  798. c_rd[k] + c_wr[k];
  799. }
  800. sh[red_sh_wr] =
  801. reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
  802. }
  803. }
  804. __syncthreads();
  805. }
  806. if (red_idx == 0) {
  807. #pragma unroll
  808. for (int i = 0; i < 4 * 2; i++) {
  809. float* c_rd =
  810. reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  811. #pragma unroll
  812. for (int j = 0; j < 4; j++)
  813. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
  814. c_rd[j];
  815. }
  816. }
  817. __syncthreads();
  818. }
  819. }
  820. };
  821. // Since multiple threadblocks may process parts of the same column slice, we
  822. // finally have to globally reduce over the results. As the striped
  823. // partitioning minimizes the number of such reductions and our outputs are
  824. // usually rather small, we perform this reduction serially in L2 cache.
  825. auto global_reduce = [&](bool first = false, bool last = false) {
  826. // We are very careful here to reduce directly in the output buffer to
  827. // maximize L2 cache utilization in this step. To do this, we write out
  828. // results in FP16 (but still reduce with FP32 compute).
  829. constexpr int active_threads = 32 * thread_n_blocks / 4;
  830. if (threadIdx.x < active_threads) {
  831. int c_gl_stride = prob_n / 8;
  832. int c_gl_wr_delta_o = 8 * c_gl_stride;
  833. int c_gl_wr_delta_i = 4 * (active_threads / 32);
  834. int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
  835. 4 * (threadIdx.x / 32) + threadIdx.x % 4;
  836. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  837. constexpr int c_sh_wr_delta = active_threads;
  838. int c_sh_wr = threadIdx.x;
  839. int row = (threadIdx.x % 32) / 4;
  840. if (!first) {
  841. // Interestingly, doing direct global accesses here really seems to mess up
  842. // the compiler and lead to slowdowns, hence we also use async-copies even
  843. // though these fetches are not actually asynchronous.
  844. #pragma unroll
  845. for (int i = 0; i < thread_m_blocks * 4; i++) {
  846. int c_idx =
  847. c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
  848. int sorted_row = sorted_ids[c_idx / c_gl_stride];
  849. int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
  850. cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx],
  851. sorted_row < tot_m * topk &&
  852. (8 * (i / 2) + row < prob_m &&
  853. (i < (thread_m_blocks - 1) * 4 ||
  854. sorted_ids[8 * (i / 2) + row] < tot_m * topk)));
  855. }
  856. cp_async_fence();
  857. cp_async_wait<0>();
  858. }
  859. #pragma unroll
  860. for (int i = 0; i < thread_m_blocks * 4; i++) {
  861. if (8 * (i / 2) + row < prob_m &&
  862. (i < (thread_m_blocks - 1) * 4 ||
  863. sorted_ids[8 * (i / 2) + row] < tot_m * topk)) {
  864. if (!first) {
  865. int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  866. #pragma unroll
  867. for (int j = 0; j < 2 * 4; j++) {
  868. reinterpret_cast<float*>(
  869. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
  870. __half2float(reinterpret_cast<__half*>(&c_red)[j]);
  871. }
  872. }
  873. if (!last) {
  874. int4 c;
  875. #pragma unroll
  876. for (int j = 0; j < 2 * 4; j++) {
  877. reinterpret_cast<__half*>(&c)[j] =
  878. __float2half(reinterpret_cast<float*>(
  879. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
  880. }
  881. int c_idx =
  882. c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
  883. int row = sorted_ids[c_idx / c_gl_stride];
  884. if (row < tot_m * topk) {
  885. int new_idx = row * c_gl_stride + c_idx % c_gl_stride;
  886. C[new_idx] = c;
  887. }
  888. }
  889. }
  890. }
  891. }
  892. };
  893. // Write out the reduce final result in the correct layout. We only actually
  894. // reshuffle matrix fragments in this step, the reduction above is performed
  895. // in fragment layout.
  896. auto write_result = [&]() {
  897. int c_gl_stride = prob_n / 8;
  898. constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
  899. int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
  900. constexpr int c_sh_rd_delta =
  901. c_sh_stride * (threads / (2 * thread_n_blocks));
  902. int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  903. (threadIdx.x % (2 * thread_n_blocks));
  904. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  905. int c_sh_wr =
  906. (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
  907. c_sh_wr += 32 * (threadIdx.x / 32);
  908. int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  909. (threadIdx.x % (2 * thread_n_blocks));
  910. int c_gl_wr_end = c_gl_stride * prob_m;
  911. // We first reorder in shared memory to guarantee the most efficient final
  912. // global write patterns
  913. auto write = [&](int idx, float c0, float c1, FragS& s) {
  914. half2 res = __halves2half2(__float2half(c0), __float2half(c1));
  915. // For per-column quantization we finally apply the scale here
  916. if constexpr (!has_act_order && group_blocks == -1) {
  917. res = __hmul2(res, s[0]);
  918. }
  919. ((half2*)sh)[idx] = res;
  920. };
  921. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  922. #pragma unroll
  923. for (int i = 0; i < thread_m_blocks; i++) {
  924. #pragma unroll
  925. for (int j = 0; j < 4; j++) {
  926. int wr = c_sh_wr + 8 * j;
  927. write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
  928. frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
  929. write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
  930. frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
  931. write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
  932. frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
  933. write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
  934. frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
  935. }
  936. c_sh_wr += 16 * (4 * c_sh_stride);
  937. }
  938. }
  939. __syncthreads();
  940. #pragma unroll
  941. for (int i = 0;
  942. i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
  943. i++) {
  944. if (c_gl_wr < c_gl_wr_end) {
  945. int row = sorted_ids[c_gl_wr / c_gl_stride];
  946. if (row < tot_m * topk) {
  947. int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
  948. if (!apply_weights) {
  949. C[off] = sh[c_sh_rd];
  950. } else {
  951. __half* ctrg = reinterpret_cast<__half*>(&C[off]);
  952. __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]);
  953. for (int j = 0; j < 8; ++j) {
  954. ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j]));
  955. }
  956. }
  957. c_gl_wr += c_gl_wr_delta;
  958. c_sh_rd += c_sh_rd_delta;
  959. }
  960. }
  961. }
  962. };
  963. // Start global fetch and register load pipelines.
  964. auto start_pipes = [&]() {
  965. // TODO re-enable after fixing this function
  966. // fetch_sorted_ids_to_shared();
  967. __syncthreads();
  968. #pragma unroll
  969. for (int i = 0; i < stages - 1; i++) {
  970. if (has_act_order && i == 0) {
  971. int last_g_idx = slice_k_start + stages * tb_k * 2;
  972. if (last_g_idx >= prob_k) {
  973. last_g_idx = prob_k - 1;
  974. }
  975. fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
  976. }
  977. fetch_to_shared(i, i, i < slice_iters);
  978. }
  979. zero_accums();
  980. wait_for_stage();
  981. init_same_group(0);
  982. fetch_to_registers(0, 0);
  983. fetch_scales_to_registers(0, 0);
  984. a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  985. slice_k_start_shared_fetch += tb_k * (stages - 1);
  986. };
  987. if (slice_iters) {
  988. start_pipes();
  989. }
  990. // Main loop.
  991. while (slice_iters) {
  992. // We unroll over both the global fetch and the register load pipeline to
  993. // ensure all shared memory accesses are static. Note that both pipelines
  994. // have even length meaning that the next iteration will always start at
  995. // index 0.
  996. #pragma unroll
  997. for (int pipe = 0; pipe < stages;) {
  998. #pragma unroll
  999. for (int k = 0; k < b_sh_wr_iters; k++) {
  1000. fetch_to_registers(k + 1, pipe % stages);
  1001. fetch_scales_to_registers(k + 1, pipe);
  1002. if (k == b_sh_wr_iters - 2) {
  1003. fetch_to_shared((pipe + stages - 1) % stages, pipe,
  1004. slice_iters >= stages);
  1005. pipe++;
  1006. wait_for_stage();
  1007. init_same_group(pipe % stages);
  1008. }
  1009. matmul(k);
  1010. }
  1011. slice_iters--;
  1012. if (slice_iters == 0) {
  1013. break;
  1014. }
  1015. }
  1016. a_gl_rd += a_gl_rd_delta_o * stages;
  1017. slice_k_start += tb_k * stages;
  1018. slice_k_start_shared_fetch += tb_k * stages;
  1019. if constexpr (has_act_order) {
  1020. int first_group_id = g_idx[slice_k_start];
  1021. int last_g_idx = slice_k_start + stages * tb_k * 2;
  1022. if (last_g_idx >= prob_k) {
  1023. last_g_idx = prob_k - 1;
  1024. }
  1025. int last_group_id = g_idx[last_g_idx];
  1026. if (last_group_id >= sh_first_group_id + sh_num_groups) {
  1027. fetch_scales_to_shared(false, first_group_id, last_group_id);
  1028. __syncthreads();
  1029. }
  1030. }
  1031. // Process results and, if necessary, proceed to the next column slice.
  1032. // While this pattern may not be the most readable, other ways of writing
  1033. // the loop seemed to noticeably worse performance after compilation.
  1034. if (slice_iters == 0) {
  1035. cp_async_wait<0>();
  1036. bool last = slice_idx == slice_count - 1;
  1037. // For per-column scales, we only fetch them here in the final step before
  1038. // write-out
  1039. if constexpr (!has_act_order && group_blocks == -1) {
  1040. if (last) {
  1041. if (s_sh_wr_pred) {
  1042. cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
  1043. }
  1044. cp_async_fence();
  1045. }
  1046. }
  1047. thread_block_reduce();
  1048. if constexpr (!has_act_order && group_blocks == -1) {
  1049. if (last) {
  1050. cp_async_wait<0>();
  1051. __syncthreads();
  1052. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  1053. reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
  1054. reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
  1055. }
  1056. }
  1057. }
  1058. if (slice_count > 1) { // only globally reduce if there is more than one
  1059. // block in a slice
  1060. barrier_acquire(&locks[slice_col], slice_idx);
  1061. global_reduce(slice_idx == 0, last);
  1062. barrier_release(&locks[slice_col], last);
  1063. }
  1064. if (last) // only the last block in a slice actually writes the result
  1065. write_result();
  1066. slice_row = 0;
  1067. slice_col_par++;
  1068. slice_col++;
  1069. init_slice();
  1070. if (slice_iters) {
  1071. a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  1072. (threadIdx.x % a_gl_rd_delta_o);
  1073. #pragma unroll
  1074. for (int i = 0; i < b_sh_wr_iters; i++)
  1075. B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
  1076. if (slice_col == 0) {
  1077. #pragma unroll
  1078. for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
  1079. }
  1080. // Update slice k/n for scales loading
  1081. if constexpr (has_act_order) {
  1082. slice_k_start = tb_k * slice_row;
  1083. slice_k_finish = slice_k_start + tb_k * slice_iters;
  1084. slice_k_start_shared_fetch = slice_k_start;
  1085. slice_n_offset = act_s_col_tb_stride * slice_col;
  1086. } else {
  1087. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  1088. }
  1089. start_pipes();
  1090. }
  1091. }
  1092. }
  1093. }
  1094. template <const int threads, // number of threads in a threadblock
  1095. const int thread_m_blocks, // number of 16x16 blocks in the m
  1096. // dimension (batchsize) of the
  1097. // threadblock
  1098. const int thread_n_blocks, // same for n dimension (output)
  1099. const int thread_k_blocks, // same for k dimension (reduction)
  1100. const int stages, // number of stages for the async global->shared
  1101. // fetch pipeline
  1102. const bool has_act_order, // whether act_order is enabled
  1103. const int group_blocks = -1 // number of consecutive 16x16 blocks
  1104. // with a separate quantization scale
  1105. >
  1106. __global__ void MarlinMoE(
  1107. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  1108. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  1109. int4* __restrict__ C, // fp16 output buffer of shape mxn
  1110. const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts
  1111. const float* __restrict__ topk_weights, // float topk weights
  1112. const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
  1113. // (k/groupsize)xn
  1114. const int* __restrict__ g_idx, // int32 group indices of shape k
  1115. const int* __restrict__ expert_offsets,
  1116. int num_groups, // number of scale groups per output channel
  1117. int expert_idx, // idx of current expert
  1118. int num_experts, // number of experts
  1119. int topk, // topk parameter of moe
  1120. int prob_m, // batch dimension m
  1121. int prob_n, // output dimension n
  1122. int prob_k, // reduction dimension k
  1123. int tot_m, // total number of rows in A and C
  1124. int* locks, // extra global storage for barrier synchronization
  1125. bool replicate_input, // do we use the same input for each expert?
  1126. bool apply_weights, // apply weights to output
  1127. int current_m_block, // current m block to start kernel computation from
  1128. int max_par // maximum parallelism
  1129. ) {
  1130. int m_block_ctr = current_m_block;
  1131. const int* sorted_ids_expert =
  1132. sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par;
  1133. int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx];
  1134. if (tot_its == 0) {
  1135. return;
  1136. }
  1137. int tot_m_blocks = ceildiv(tot_its, 16);
  1138. int pad = 16 * tot_m_blocks - tot_its;
  1139. if (m_block_ctr >= tot_m_blocks) {
  1140. return;
  1141. }
  1142. int max_block = tot_m_blocks - m_block_ctr;
  1143. prob_m = tot_its - 16 * m_block_ctr;
  1144. int par = 1;
  1145. if (max_block > 4) {
  1146. // Note that parallel > 1 currently only works for inputs without any
  1147. // padding
  1148. par = (16 * max_block - pad) / 64;
  1149. par = min((16 * max_block - pad) / 64, max_par);
  1150. prob_m = 64 * par;
  1151. m_block_ctr += 4 * (par - 1);
  1152. max_block = 4;
  1153. }
  1154. if (max_block == 1) {
  1155. MarlinMoESingle<threads, 1, thread_n_blocks, thread_k_blocks, stages,
  1156. has_act_order, group_blocks>(
  1157. A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
  1158. expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
  1159. prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
  1160. current_m_block);
  1161. } else if (max_block == 2) {
  1162. MarlinMoESingle<threads, 2, thread_n_blocks, thread_k_blocks, stages,
  1163. has_act_order, group_blocks>(
  1164. A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
  1165. expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
  1166. prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
  1167. current_m_block);
  1168. } else if (max_block == 3) {
  1169. MarlinMoESingle<threads, 3, thread_n_blocks, thread_k_blocks, stages,
  1170. has_act_order, group_blocks>(
  1171. A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
  1172. expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
  1173. prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
  1174. current_m_block);
  1175. } else {
  1176. MarlinMoESingle<threads, 4, thread_n_blocks, thread_k_blocks, stages,
  1177. has_act_order, group_blocks>(
  1178. A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
  1179. expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
  1180. prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
  1181. current_m_block);
  1182. }
  1183. }
  1184. #else
  1185. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
  1186. int const* __restrict__ perm_int_ptr,
  1187. int4* __restrict__ out_int4_ptr, int size_m,
  1188. int size_k, int block_rows) {
  1189. // Marlin is not implemented yet for SM < 8.0
  1190. assert(false);
  1191. return;
  1192. }
  1193. __global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
  1194. int* __restrict__ expert_offsets,
  1195. int topk_length, int block_size) {
  1196. // Marlin is not implemented yet for SM < 8.0
  1197. assert(false);
  1198. return;
  1199. }
  1200. template <const int threads, // number of threads in a threadblock
  1201. const int thread_m_blocks, // number of 16x16 blocks in the m
  1202. // dimension (batchsize) of the
  1203. // threadblock
  1204. const int thread_n_blocks, // same for n dimension (output)
  1205. const int thread_k_blocks, // same for k dimension (reduction)
  1206. const int stages, // number of stages for the async global->shared
  1207. // fetch pipeline
  1208. const bool has_act_order, // whether act_order is enabled
  1209. const int group_blocks = -1 // number of consecutive 16x16 blocks
  1210. // with a separate quantization scale
  1211. >
  1212. __global__ void MarlinMoE(
  1213. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  1214. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  1215. int4* __restrict__ C, // fp16 output buffer of shape mxn
  1216. const int* __restrict__ sorted_ids, // int32 sorted ids of experts
  1217. const float* __restrict__ topk_weights, // float topk weights
  1218. const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
  1219. // (k/groupsize)xn
  1220. const int* __restrict__ g_idx, // int32 group indices of shape k
  1221. const int* __restrict__ expert_offsets,
  1222. int num_groups, // number of scale groups per output channel
  1223. int expert_idx, // idx of current expert
  1224. int num_experts, // number of experts
  1225. int topk, // topk parameter of moe
  1226. int prob_m, // batch dimension m
  1227. int prob_n, // output dimension n
  1228. int prob_k, // reduction dimension k
  1229. int tot_m, // total number of rows in A and C
  1230. int* locks, // extra global storage for barrier synchronization
  1231. bool replicate_input, // do we use the same input for each expert?
  1232. bool apply_weights, // apply weights to output
  1233. int current_m_block, // current m block to start kernel computation from
  1234. int max_par // maximum parallelism
  1235. ) {
  1236. // Marlin is not implemented yet for SM < 8.0
  1237. assert(false);
  1238. return;
  1239. }
  1240. #endif
  1241. // 8 warps are a good choice since every SM has 4 schedulers and having more
  1242. // than 1 warp per schedule allows some more latency hiding. At the same time,
  1243. // we want relatively few warps to have many registers per warp and small tiles.
  1244. const int USER_THREADS =
  1245. 256; // Note: This is only used with user-provided thread_k/n
  1246. const int STAGES = 4; // 4 pipeline stages fit into shared memory
  1247. // const int SHARED_MEM =
  1248. // 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
  1249. static constexpr int min_thread_n = 64;
  1250. static constexpr int min_thread_k = 64;
// Expands to ONE `else if` branch of a runtime -> compile-time dispatch chain,
// so it must follow a leading `if` in the caller. When the runtime values
// thread_m_blocks, thread_n_blocks, thread_k_blocks, has_act_order,
// group_blocks and num_threads all equal the macro arguments, the branch
//   1. raises the dynamic shared-memory limit of the matching MarlinMoE
//      instantiation to max_shared_mem via cudaFuncSetAttribute, then
//   2. launches that instantiation with `blocks` thread blocks of NUM_THREADS
//      threads, max_shared_mem bytes of dynamic shared memory, on `stream`.
// The expansion refers to caller-scope names (A_ptr, B_ptr, C_ptr,
// sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr, expert_offsets_ptr,
// num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m,
// locks, replicate_input, apply_weights, m_block, max_par, blocks, stream,
// max_shared_mem), so it only works inside a launcher that declares them.
// NOTE(review): neither cudaFuncSetAttribute's return value nor the launch
// status (cudaGetLastError) is checked here.
// NOTE(review): identifiers starting with a double underscore are reserved
// for the implementation in C++; consider renaming together with CALL_IF_MOE.
#define __CALL_IF_MOE(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,      \
                      HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS)               \
  else if (thread_m_blocks == THREAD_M_BLOCKS &&                              \
           thread_n_blocks == THREAD_N_BLOCKS &&                              \
           thread_k_blocks == THREAD_K_BLOCKS &&                              \
           has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&  \
           num_threads == NUM_THREADS) {                                      \
    cudaFuncSetAttribute(                                                     \
        MarlinMoE<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,              \
                  THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>,      \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);         \
    MarlinMoE<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
              STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>                            \
        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                    \
            A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr,     \
            g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx,            \
            num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks,          \
            replicate_input, apply_weights, m_block, max_par);                \
  }
  1270. typedef struct {
  1271. int thread_k;
  1272. int thread_n;
  1273. int num_threads;
  1274. } thread_config_t;
  1275. thread_config_t small_batch_thread_configs[] = {
  1276. // Ordered by priority
  1277. // thread_k, thread_n, num_threads
  1278. {128, 128, 256}, // Default
  1279. {128, 64, 128}, // Reduce N 2X, same K
  1280. {64, 256, 256}, // Reduce K 2X, increase N 2X
  1281. {64, 128, 128}, // Reduce K 2X, same N
  1282. };
  1283. thread_config_t large_batch_thread_configs[] = {
  1284. // Ordered by priority
  1285. // thread_k, thread_n, num_threads
  1286. {64, 256, 256}, // Default
  1287. {128, 128, 256}, // Reduce N 2X, increase K 2X
  1288. {64, 128, 128}, // Reduce N 2X, same K
  1289. {128, 64, 128}, // Reduce N 4X, increase K 2X
  1290. };
  1291. bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
  1292. int prob_k) {
  1293. // Sanity
  1294. if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
  1295. th_config.num_threads == -1) {
  1296. return false;
  1297. }
  1298. // Verify K/N are divisible by thread K/N
  1299. if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
  1300. return false;
  1301. }
  1302. // thread_k can be only 128 or 64 (because it must be less than groupsize
  1303. // which is 128)
  1304. if (th_config.thread_k != 128 && th_config.thread_k != 64) {
  1305. return false;
  1306. }
  1307. // Verify min for thread K/N
  1308. if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
  1309. return false;
  1310. }
  1311. // num_threads must be at least 128 (= 4 warps)
  1312. if (th_config.num_threads < 128) {
  1313. return false;
  1314. }
  1315. return true;
  1316. }
  1317. thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
  1318. if (prob_m <= 16) {
  1319. for (auto th_config : small_batch_thread_configs) {
  1320. if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
  1321. return th_config;
  1322. }
  1323. }
  1324. } else {
  1325. for (auto th_config : large_batch_thread_configs) {
  1326. if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
  1327. return th_config;
  1328. }
  1329. }
  1330. }
  1331. return thread_config_t{-1, -1, -1};
  1332. }
// Expands to the complete set of __CALL_IF_MOE dispatch branches (a run of
// `else if` clauses, so it must follow a leading `if`) for one
// (N_BLOCKS, K_BLOCKS, NUM_THREADS) tile shape:
//   - act-order variants: thread_m_blocks 1..4 with group_blocks = 0;
//   - non-act-order variants: thread_m_blocks 1..4, each with group_blocks
//     -1 (single scale per column, applied at write-out), 2, 4 and 8.
// Every listed combination becomes a template instantiation of MarlinMoE, so
// each additional tile shape passed to this macro adds 20 kernels.
#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS)                    \
  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)            \
  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)            \
  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)            \
  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)            \
                                                                        \
  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)          \
  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)           \
  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)           \
  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)           \
                                                                        \
  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)          \
  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)           \
  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)           \
  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)           \
                                                                        \
  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)          \
  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)           \
  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)           \
  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)           \
                                                                        \
  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)          \
  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)           \
  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)           \
  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
  // Host-side driver for the Marlin MoE GEMM.
  //
  // Steps:
  //   - Validate the problem shape and thread/tile configuration.
  //   - Derive group_blocks (quantization group size in 16-row units of K).
  //   - Launch compute_expert_offsets once.
  //   - For each expert in turn:
  //       * offset the per-expert B / s / g_idx / perm pointers;
  //       * on the act-order path, permute A's columns into a_tmp with
  //         permute_cols_kernel;
  //       * dispatch the matching kernel instantiation through CALL_IF_MOE,
  //         in steps of 16 m-blocks.
  //
  // Everything is enqueued on `stream`; this function never synchronizes with
  // the host.
  //
  // Parameters (as used here):
  //   prob_m/prob_n/prob_k  GEMM sizes. They must be positive, and
  //                         prob_n/prob_k must be divisible by the chosen
  //                         thread_n/thread_k.
  //   thread_k/thread_n     Tile sizes. -1/-1 selects them via
  //                         determine_thread_config(); otherwise they are used
  //                         with USER_THREADS threads.
  //   sms                   Grid size (`blocks`) for the launches made here.
  //                         -1 queries the SM count of `dev`.
  //   group_size            -1: no grouping. 0: act-order without full K.
  //                         Otherwise a quantization group size that must be
  //                         >= 16.
  //   num_groups            Scale rows per expert when group_size == 0. Also
  //                         reported in error messages.
  //   is_k_full             When true, the non-act-order kernel path is used
  //                         even if has_act_order is set (A is still permuted).
  //
  // Throws c10::Error (via TORCH_CHECK) on invalid configurations and on CUDA
  // runtime errors returned by the calls and launches made directly here.
  void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
                           const void* sorted_ids, const void* topk_weights,
                           const void* topk_ids, const void* s,
                           const void* g_idx, const void* perm, void* a_tmp,
                           void* expert_offsets, int prob_m, int prob_n,
                           int prob_k, void* workspace, bool has_act_order,
                           bool is_k_full, int num_groups, int group_size,
                           int num_experts, int topk, int moe_block_size,
                           int dev, cudaStream_t stream, int thread_k,
                           int thread_n, int sms, int max_par,
                           bool replicate_input, bool apply_weights) {
    // Turns a failed CUDA runtime status into a c10::Error at the point of
    // failure. Otherwise it would surface later from an unrelated CUDA call.
    auto check_cuda_status = [](cudaError_t status, const char* what) {
      TORCH_CHECK(status == cudaSuccess, what, " failed: ",
                  cudaGetErrorString(status));
    };

    TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [",
                prob_m, ", ", prob_n, ", ", prob_k, "]");
    if (sms == -1) {
      check_cuda_status(
          cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev),
          "cudaDeviceGetAttribute(cudaDevAttrMultiProcessorCount)");
    }

    // Set thread config
    thread_config_t th_config;
    if (thread_k != -1 && thread_n != -1) {
      // User-defined config
      th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
    } else {
      // Auto config
      th_config = determine_thread_config(prob_m, prob_n, prob_k);
    }
    TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k),
                "Invalid thread config: thread_k = " + str(th_config.thread_k) +
                    ", thread_n = " + str(th_config.thread_n) +
                    ", num_threads = " + str(th_config.num_threads) +
                    " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " +
                    str(prob_n) + "]");
    int num_threads = th_config.num_threads;
    thread_k = th_config.thread_k;
    thread_n = th_config.thread_n;
    int thread_k_blocks = thread_k / 16;
    int thread_n_blocks = thread_n / 16;
    int blocks = sms;
    TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
                " is not divisible by thread_n = ", thread_n);
    TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
                " is not divisible by thread_k = ", thread_k);

    // group_blocks: -1 = no grouping, 0 = act-order without full K, otherwise
    // group_size / 16. A group_size below 16 would make group_blocks 0 and the
    // divisibility check below a division by zero, so it is rejected.
    int group_blocks = 0;
    if (has_act_order) {
      if (is_k_full) {
        TORCH_CHECK(group_size != -1);
        group_blocks = group_size / 16;
        TORCH_CHECK(group_blocks > 0, "group_size = ", group_size,
                    " must be at least 16");
        TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                    " is not divisible by group_blocks = ", group_blocks);
      } else {
        TORCH_CHECK(group_size == 0);
        group_blocks = 0;
      }
    } else if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      TORCH_CHECK(group_blocks > 0, "group_size = ", group_size,
                  " must be at least 16");
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }

    int max_shared_mem = 0;
    check_cuda_status(
        cudaDeviceGetAttribute(&max_shared_mem,
                               cudaDevAttrMaxSharedMemoryPerBlockOptin, dev),
        "cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin)");
    TORCH_CHECK(max_shared_mem > 0);

    int tot_m = prob_m;
    const int* topk_ids_ptr = (const int*)topk_ids;
    int* expert_offsets_ptr = (int*)expert_offsets;
    // One block with one thread per expert. An invalid launch configuration
    // (for example num_experts == 0 or above the per-block thread limit) is
    // reported immediately by the check below.
    compute_expert_offsets<<<1, num_experts, 0, stream>>>(
        topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);
    check_cuda_status(cudaGetLastError(), "compute_expert_offsets launch");

    bool do_permute_a = has_act_order;
    // If we have a full K, then we can run the non-act-order version of Marlin
    // (since the weight rows are reordered by increasing group ids, and by
    // having a full K, we have full original groups)
    if (is_k_full) {
      has_act_order = false;
    }

    // Per-expert strides, hoisted out of the expert loop and computed in 64-bit
    // so that large experts (prob_n * prob_k >= 2^31) cannot overflow `int`.
    // B holds 4-bit weights: prob_n * prob_k / 32 int4 (16-byte) words.
    const int64_t b_expert_stride = static_cast<int64_t>(prob_n) * prob_k / 32;
    int64_t scale_rows_per_expert;
    if (group_size == -1) {
      scale_rows_per_expert = 1;
    } else if (group_size == 0) {
      // Act-order without full K still carries num_groups scale rows per
      // expert (marlin_gemm_moe sets num_groups = b_scales.size(1)). Step over
      // all of them, not a single row.
      TORCH_CHECK(num_groups > 0, "num_groups = ", num_groups,
                  " must be positive when group_size == 0");
      scale_rows_per_expert = num_groups;
    } else {
      scale_rows_per_expert = prob_k / group_size;
    }
    // 8 scale values per int4 word (the original code's stride convention).
    const int64_t s_expert_stride = scale_rows_per_expert * prob_n / 8;
    const int64_t k_expert_stride = prob_k;

    for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
      const int4* A_ptr = (const int4*)A;
      int4* a_tmp_ptr = (int4*)a_tmp;
      const int4* B_ptr = (const int4*)B + b_expert_stride * expert_idx;
      int4* C_ptr = (int4*)C;
      const float* topk_weights_ptr = (const float*)topk_weights;
      const int* sorted_ids_ptr = (const int*)sorted_ids;
      const int4* s_ptr = (const int4*)s + s_expert_stride * expert_idx;
      const int* g_idx_ptr = (const int*)g_idx + k_expert_stride * expert_idx;
      const int* perm_ptr = (const int*)perm + k_expert_stride * expert_idx;
      int* locks = (int*)workspace;
      if (do_permute_a) {
        // Permute A columns into a_tmp with this expert's perm, then read A
        // from a_tmp.
        int topk_rows = replicate_input ? tot_m : tot_m * topk;
        int block_rows = ceildiv(topk_rows, blocks);
        permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
            A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
        check_cuda_status(cudaGetLastError(), "permute_cols_kernel launch");
        A_ptr = a_tmp_ptr;
      }
      int max_m_blocks = ceildiv(tot_m, 16);
      // NOTE(review): each iteration advances 16 m-blocks. This presumably
      // equals max_par (4 in marlin_gemm_moe) * thread_m_blocks (4); confirm
      // against the kernel before changing either value.
      for (int m_block = 0; m_block < max_m_blocks; m_block += 16) {
        // Define kernel configurations
        // make it max possible value
        int thread_m_blocks = 4;
        // `if (false) {}` opens the chain; each CALL_IF_MOE expansion must
        // therefore continue it with `else if` clauses, and the final `else`
        // reports an unsupported configuration.
        if (false) {
        }
        CALL_IF_MOE(16, 4, 256)
        CALL_IF_MOE(8, 8, 256)
        CALL_IF_MOE(8, 4, 128)
        CALL_IF_MOE(4, 8, 128)
        else {
          TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                                 str(prob_n) + ", " + str(prob_k) + "]" +
                                 ", has_act_order = " + str(has_act_order) +
                                 ", num_groups = " + str(num_groups) +
                                 ", group_size = " + str(group_size) +
                                 ", thread_m_blocks = " + str(thread_m_blocks) +
                                 ", thread_n_blocks = " + str(thread_n_blocks) +
                                 ", thread_k_blocks = " + str(thread_k_blocks));
        }
        check_cuda_status(cudaGetLastError(), "MarlinMoE dispatch launch");
      }
    }
  }
  1481. } // namespace marlin_moe
  // PyTorch entry point for the Marlin MoE GEMM.
  //
  // Work done here:
  //   - Allocate the outputs and scratch buffers:
  //       * c: zero-filled, shape {size_m, topk, size_n}, dtype/device of `a`;
  //       * a_tmp: scratch for the column-permuted copy of A on the
  //         act-order path;
  //       * expert_offsets: num_experts + 1 ints.
  //   - Infer act-order and the quantization group size from g_idx/b_scales.
  //   - Forward to marlin_moe::marlin_mm_moe_f16i4 on the current CUDA stream
  //     of a's device, with automatic tile sizes and SM count.
  //
  // Returns c. The kernels are only enqueued; the caller's stream ordering
  // governs when c is ready.
  //
  // Throws c10::Error (via TORCH_CHECK) if:
  //   - a is not a CUDA tensor;
  //   - a, b_q_weights or b_scales is not contiguous;
  //   - b_scales is not rank 3 with size(2) == size_n;
  //   - size_k is not divisible by the number of groups.
  torch::Tensor marlin_gemm_moe(
      const torch::Tensor& a, const torch::Tensor& b_q_weights,
      const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
      const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
      const torch::Tensor& g_idx, const torch::Tensor& perm,
      torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
      bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
      bool replicate_input, bool apply_weights) {
    // The launcher hands raw device pointers to the kernels and steps through
    // them with dense per-expert strides, so host or strided inputs would be
    // read incorrectly.
    TORCH_CHECK(a.is_cuda(), "a is not on a CUDA device");
    TORCH_CHECK(a.is_contiguous(), "a is not contiguous");
    TORCH_CHECK(b_q_weights.is_contiguous(), "b_q_weights is not contiguous");
    TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

    int max_par = 4;
    int dev = a.get_device();
    auto options_dtype =
        torch::TensorOptions().dtype(a.dtype()).device(a.device());
    auto options_int =
        torch::TensorOptions().dtype(torch::kInt).device(a.device());
    torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
    torch::Tensor a_tmp =
        replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
                        : torch::zeros({size_m, topk, size_k}, options_dtype);
    torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);

    // thread_k / thread_n: tile sizes in `weights`; -1 = choose automatically.
    int thread_k = -1;
    int thread_n = -1;
    // sms: number of SMs to use for the kernel; -1 = query the device.
    int sms = -1;

    // Detect group size and act_order.
    int num_groups = -1;
    int group_size = -1;
    bool has_act_order = g_idx.size(1) != 0;
    int b_rank = b_scales.sizes().size();
    TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
    TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ",
                b_scales.size(2), " is not size_n = ", size_n);
    num_groups = b_scales.size(1);
    if (has_act_order) {
      if (is_k_full) {
        TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
        TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                    ", is not divisible by num_groups = ", num_groups);
        group_size = size_k / num_groups;
      } else {
        // Partial K with act-order: the kernel works from g_idx instead of a
        // fixed group size.
        group_size = 0;
      }
    } else {
      if (num_groups > 1) {
        // num_groups comes from b_scales.size(1), so report that dimension.
        TORCH_CHECK(
            size_k % num_groups == 0, "size_k = ", size_k,
            ", is not divisible by b_scales.size(1) = ", b_scales.size(1));
        group_size = size_k / num_groups;
      } else {
        group_size = -1;
      }
    }

    marlin_moe::marlin_mm_moe_f16i4(
        a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
        topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
        g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
        expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
        has_act_order, is_k_full, num_groups, group_size, num_experts, topk,
        moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
        thread_n, sms, max_par, replicate_input, apply_weights);
    return c;
  }