marlin_qqq_gemm_kernel.cu 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243
  1. /*
  2. * Adapted from
  3. * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu
  4. * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp
  5. * Modified by HandH1998
  6. * Copyright (C) 2024 HandH1998
  7. * Copyright (C) Marlin.2024 Elias Frantar
  8. *
  9. * Licensed under the Apache License, Version 2.0 (the "License");
  10. * you may not use this file except in compliance with the License.
  11. * You may obtain a copy of the License at
  12. *
  13. * http://www.apache.org/licenses/LICENSE-2.0
  14. *
  15. * Unless required by applicable law or agreed to in writing, software
  16. * distributed under the License is distributed on an "AS IS" BASIS,
  17. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. * See the License for the specific language governing permissions and
  19. * limitations under the License.
  20. */
  21. #include <torch/all.h>
  22. #include <ATen/cuda/CUDAContext.h>
  23. #include <c10/cuda/CUDAGuard.h>
  24. #include <cuda.h>
  25. #include <cuda_fp16.h>
  26. #include <cuda_runtime.h>
  27. #include <iostream>
  28. #include "../dense/common/base.cuh"
  29. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  30. #include "../dense/common/mem.cuh"
  31. #endif
  // Returns the decimal text representation of `x`, as produced by
  // std::to_string for the argument's arithmetic type.
  template <typename T>
  inline std::string str(T x) {
    std::string text = std::to_string(x);
    return text;
  }
  36. namespace {
  37. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // NOTE(review): `Vec` is not declared in this file; presumably it comes from
  // ../dense/common/base.cuh -- confirm.
  38. using I4 = Vec<int, 4>; // four ints; the kernel keeps the quantized B words in this type
  39. // Matrix fragments for tensor core instructions; their precise layout is
  40. // documented here:
  41. // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type
  42. using FragA = Vec<uint32_t, 2>; // A operand: two 32-bit registers, see mma()
  43. using FragB = Vec<uint32_t, 1>; // B operand: one 32-bit register, see mma()
  44. using FragC = Vec<int, 4>; // accumulator: four int32 registers, see mma()
  45. using FragS_GROUP = Vec<half2, 1>; // weight per-group quantization scales
  // (read as two fp16 values selected by index in dequant_per_group)
  46. using FragS_CHANNEL =
  47. Vec<float, 2>; // weight per-channel quantization scales or activation
  48. // per-token quantization scales
  // Asynchronous global->shared copy of a single activation quantization
  // scale (one entry of s_tok).
  //
  // NOTE: cp.async.cg only supports 16-byte copies, whereas cp.async.ca
  // supports 4, 8 and 16 bytes. s_tok has as many entries as prob_m, so it is
  // kept as float and copied one float (4 bytes) at a time, which is why the
  // `.ca` variant is used here.
  //
  // smem_ptr: destination in shared memory (generic pointer, converted to a
  //           shared-space address below).
  // glob_ptr: source address in global memory.
  __device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) {
    // Copy size in bytes; must be an immediate for the "n" constraint.
    constexpr int kCopyBytes = 4;
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
        "{\n"
        " cp.async.ca.shared.global [%0], [%1], %2;\n"
        "}\n"
        :
        : "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
  }
  // m16n8k16 tensor core mma instruction with int8 inputs and int32
  // output/accumulation: frag_c = a_frag * frag_b + frag_c.
  //
  // a_frag: A operand, two 32-bit registers of packed int8 values.
  // frag_b: B operand, one 32-bit register of packed int8 values.
  // frag_c: four int32 accumulators; read as the addend and overwritten with
  //         the result. The instruction carries the `.satfinite` qualifier.
  __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                             FragC& frag_c) {
    const uint32_t* a_regs = reinterpret_cast<const uint32_t*>(&a_frag);
    const uint32_t* b_regs = reinterpret_cast<const uint32_t*>(&frag_b);
    int* acc = reinterpret_cast<int*>(&frag_c);
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 "
        "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
        : "=r"(acc[0]), "=r"(acc[1]), "=r"(acc[2]), "=r"(acc[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(b_regs[0]), "r"(acc[0]),
          "r"(acc[1]), "r"(acc[2]), "r"(acc[3]));
  }
  // Loads a full 16x16 int8 fragment of operand A from shared memory directly
  // into the tensor core register layout. A 16x16 int8 tile is 256 bytes, i.e.
  // two 8x8 b16 matrices, so despite the function's name this issues the `.x2`
  // form of ldmatrix and fills the two registers of FragA.
  //
  // frag_a:   destination fragment (both 32-bit registers are overwritten).
  // smem_ptr: this thread's source row address in shared memory.
  __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
    uint32_t* dst = reinterpret_cast<uint32_t*>(&frag_a);
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
                 : "=r"(dst[0]), "=r"(dst[1])
                 : "r"(smem_addr));
  }
  // Converts both lanes of a float2 to fp16 (cvt.rn, round-to-nearest) and
  // packs them into one half2: f.x ends up in the low 16 bits, f.y in the high
  // 16 bits.
  inline __device__ half2 float2_to_half2(float2 f) {
    // NOTE: the 16-bit asm operands must be uint16_t, not half.
    uint16_t lo_bits, hi_bits;
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(lo_bits) : "f"(f.x));
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(hi_bits) : "f"(f.y));
    // Pack the two 16-bit patterns into a single 32-bit register.
    uint32_t packed;
    asm volatile("mov.b32 %0, {%1, %2};\n"
                 : "=r"(packed)
                 : "h"(lo_bits), "h"(hi_bits));
    return reinterpret_cast<half2&>(packed);
  }
  // Converts a signed 32-bit integer to float using an explicit
  // `cvt.rn.f32.s32` (round-to-nearest) instruction.
  inline __device__ float int32_to_float(int h) {
    float as_float;
    asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(as_float) : "r"(h));
    return as_float;
  }
  // Three-input bitwise logic operation driven by a lookup table. It is issued
  // explicitly for dequantization because the compiler does not seem to
  // recognize the pattern automatically in all cases.
  //
  // lut: 8-bit truth table, passed as an immediate. Callers in this file build
  //      it by applying the wanted expression to the constants 0xf0, 0xcc and
  //      0xaa standing for a, b and c, e.g. `(0xf0 & 0xcc) | 0xaa` for
  //      `(a & b) | c`.
  // Returns the 32-bit result of applying that function bitwise to a, b, c.
  template <int lut>
  __device__ inline int lop3(int a, int b, int c) {
    int out;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(out)
                 : "r"(a), "r"(b), "r"(c), "n"(lut));
    return out;
  }
  // Turns one int32 of packed 4-bit weights into a B-fragment of 4 int8 values
  // for weight per-channel dequantization.
  //
  // Only the high nibble of every byte of `q` is kept, so each int8 lane holds
  // its nibble in bits 7..4 and zeros in bits 3..0 (read as int8, that is the
  // nibble taken as a two's-complement 4-bit number, times 16). The kernel
  // calls this once with `q` and once with `q << 4` to cover both nibbles of
  // each byte.
  //
  // q: 32-bit word holding eight packed 4-bit weights.
  // Returns a FragB whose single register is `q` masked with 0xf0f0f0f0.
  __device__ inline FragB dequant_per_channel(int q) {
    // 0xf0f0f0f0 does not fit in a signed int, so the mask is kept unsigned
    // (matching the uint32_t masks in dequant_per_group) and the AND is done
    // in unsigned arithmetic; the resulting bit pattern is unchanged.
    static constexpr uint32_t HIGH_NIBBLES = 0xf0f0f0f0u;
    FragB frag_b;
    frag_b[0] = static_cast<uint32_t>(q) & HIGH_NIBBLES;
    return frag_b;
  }
  // Dequantizes four 4-bit weights of `q` into a B-fragment of 4 int8 values
  // for weight per-group dequantization.
  //
  // The nibbles at bits 0..3 / 16..19 and 4..7 / 20..23 of `q` are turned into
  // fp16 values in [-8, 7], multiplied by one fp16 group scale, converted to
  // int8 and packed into a single 32-bit register. The kernel calls this with
  // `q` and with `q >> 8` to cover all eight nibbles of a word.
  //
  // q:      32-bit word of packed 4-bit weights.
  // frag_s: group scales, read as two fp16 values.
  // i:      which of the two fp16 scales in `frag_s` to apply (0 or 1).
  // Returns the four int8 results; byte order is: low nibble of the lower
  // half-word, low nibble of the upper half-word, high nibble of the lower
  // half-word, high nibble of the upper half-word.
  __device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) {
    // LOP3 truth table for `(a & b) | c`; spelled out so that each extraction
    // is guaranteed to be a single LOP3.
    constexpr int AND_THEN_OR = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t NIBBLE_LO = 0x000f000f;
    static constexpr uint32_t NIBBLE_HI = 0x00f000f0;
    // fp16 1024.0 in both halves: OR-ing a nibble into the mantissa yields
    // 1024 + n (low nibble) or 1024 + 16 * n (high nibble).
    static constexpr uint32_t FP16_1024 = 0x64006400;
    uint32_t lo_pair = lop3<AND_THEN_OR>(q, NIBBLE_LO, FP16_1024);
    uint32_t hi_pair = lop3<AND_THEN_OR>(q, NIBBLE_HI, FP16_1024);
    half2& lo_h2 = *reinterpret_cast<half2*>(&lo_pair);
    half2& hi_h2 = *reinterpret_cast<half2*>(&hi_pair);

    // We want signed int4 outputs, hence the `-8` symmetric zero point is fused
    // directly into the subtrahend / addend:
    //   low nibble : (1024 + n) - 1032          = n - 8
    //   high nibble: (1024 + 16 * n) / 16 - 72  = n - 8
    static constexpr uint32_t BIAS_LO = 0x64086408;  // fp16 1032.0 (x2)
    static constexpr uint32_t INV_16 = 0x2c002c00;   // fp16 1/16   (x2)
    static constexpr uint32_t BIAS_HI = 0xd480d480;  // fp16 -72.0  (x2)
    lo_h2 = __hsub2(lo_h2, *reinterpret_cast<const half2*>(&BIAS_LO));
    hi_h2 = __hfma2(hi_h2, *reinterpret_cast<const half2*>(&INV_16),
                    *reinterpret_cast<const half2*>(&BIAS_HI));

    // Broadcast the selected fp16 scale into both halves of a half2.
    uint16_t scale_bits = reinterpret_cast<uint16_t*>(&frag_s)[i];
    uint32_t scale_pair;
    asm volatile("mov.b32 %0, {%1, %2};\n"
                 : "=r"(scale_pair)
                 : "h"(scale_bits), "h"(scale_bits));

    // Scale and convert the 4 halves to 4 uint8 in one step: adding fp16 1152.0
    // (= 1024 + 128) leaves `w * s + 128` in the low 8 bits of each half's bit
    // pattern.
    static constexpr uint32_t ROUND_BIAS = 0x64806480;
    lo_h2 = __hfma2(lo_h2, *reinterpret_cast<half2*>(&scale_pair),
                    *reinterpret_cast<const half2*>(&ROUND_BIAS));
    hi_h2 = __hfma2(hi_h2, *reinterpret_cast<half2*>(&scale_pair),
                    *reinterpret_cast<const half2*>(&ROUND_BIAS));

    // Gather bytes 0 and 2 of each register (the 4 uint8 values) into one
    // uint32, then flip every top bit to turn the +128 offset uint8s into
    // int8s.
    static constexpr uint32_t BYTE_SELECT_0246 = 0x6420;
    static constexpr uint32_t UINT8_TO_INT8_FLIP = 0x80808080;
    uint32_t packed_u8;
    asm volatile("prmt.b32 %0,%1,%2,%3;\n"
                 : "=r"(packed_u8)
                 : "r"(lo_pair), "r"(hi_pair), "n"(BYTE_SELECT_0246));
    FragB frag_b;
    frag_b[0] = (packed_u8 ^ UINT8_TO_INT8_FLIP);
    return frag_b;
  }
  163. template <const int threads, // number of threads in a threadblock
  164. const int thread_m_blocks, // number of 16x16 blocks in the m
  165. // dimension (batchsize) of the
  166. // threadblock
  167. const int thread_n_blocks, // same for n dimension (output)
  168. const int thread_k_blocks, // same for k dimension (reduction)
  169. const int stages, // number of stages for the async global->shared
  170. // fetch pipeline
  171. const int group_blocks = -1 // number of consecutive 16x16 blocks
  172. // with a separate quantization scale
  173. >
  174. __global__ void Marlin(
  175. const int4* __restrict__ A, // int8 input matrix of shape mxk
  176. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  177. int4* __restrict__ C, // int32 global_reduce buffer of shape
  178. // (max_par*16*4)xn, as int8 tensor core's output is
  179. // int32 dtype
  180. int4* __restrict__ D, // fp16 output buffer of shape mxn
  181. const float* __restrict__ s_tok, // fp32 activation per-token quantization
  182. // scales of shape mx1
  183. const int4* __restrict__ s_ch, // fp32 weight per-channel quantization
  184. // scales of shape 1xn
  185. const int4* __restrict__ s_group, // fp16 weight per-group quantization
  186. // scales of shape (k/groupsize)xn, when
  187. // group_blocks=-1, it should be nullptr
  188. int prob_m, // batch dimension m
  189. int prob_n, // output dimension n
  190. int prob_k, // reduction dimension k
  191. int* locks // extra global storage for barrier synchronization
  192. ) {
  193. // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  194. // same size, which might involve multiple column "slices" (of width 16 *
  195. // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
  196. // example:
  197. // 0 1 3
  198. // 0 2 3
  199. // 1 2 4
  200. // While this kind of partitioning makes things somewhat more complicated, it
  201. // ensures good utilization of all SMs for many kinds of shape and GPU
  202. // configurations, while requiring as few slow global cross-threadblock
  203. // reductions as possible.
  204. // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  205. // better partitioning with less reductions
  206. int parallel = 1;
  207. if (prob_m > 16 * thread_m_blocks) {
  208. parallel = prob_m / (16 * thread_m_blocks);
  209. prob_m = 16 * thread_m_blocks;
  210. }
  211. int k_tiles = prob_k / 16 / thread_k_blocks;
  212. int n_tiles = prob_n / 16 / thread_n_blocks;
  213. int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
  214. // Ensure that the number of tiles in each stripe is a multiple of the
  215. // groupsize; this avoids an annoying special case where a stripe starts in
  216. // the middle of group.
  217. if constexpr (group_blocks != -1)
  218. iters = (group_blocks / thread_k_blocks) *
  219. ceildiv(iters, (group_blocks / thread_k_blocks));
  220. int slice_row = (iters * blockIdx.x) % k_tiles;
  221. int slice_col_par = (iters * blockIdx.x) / k_tiles;
  222. int slice_col = slice_col_par;
  223. int slice_iters; // number of threadblock tiles in the current slice
  224. int slice_count =
  225. 0; // total number of active threadblocks in the current slice
  226. int slice_idx; // index of threadblock in current slice; numbered bottom to
  227. // top
  228. // We can easily implement parallel problem execution by just remapping
  229. // indices and advancing global pointers
  230. if (slice_col_par >= n_tiles) {
  231. A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16;
  232. C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4;
  233. D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
  234. s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks;
  235. locks += (slice_col_par / n_tiles) * n_tiles;
  236. slice_col = slice_col_par % n_tiles;
  237. }
  238. // Compute all information about the current slice which is required for
  239. // synchronization.
  240. auto init_slice = [&]() {
  241. slice_iters =
  242. iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
  243. if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
  244. if (slice_iters == 0) return;
  245. if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
  246. slice_count = 1;
  247. slice_idx = 0;
  248. int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
  249. if (col_first <= k_tiles * (slice_col_par + 1)) {
  250. int col_off = col_first - k_tiles * slice_col_par;
  251. slice_count = ceildiv(k_tiles - col_off, iters);
  252. if (col_off > 0) slice_count++;
  253. int delta_first = iters * blockIdx.x - col_first;
  254. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  255. slice_idx = slice_count - 1;
  256. else {
  257. slice_idx = slice_count - 1 - delta_first / iters;
  258. if (col_off > 0) slice_idx--;
  259. }
  260. }
  261. if (slice_col == n_tiles) {
  262. A += 16 * thread_m_blocks * prob_k / 16;
  263. C += 16 * thread_m_blocks * prob_n / 4;
  264. D += 16 * thread_m_blocks * prob_n / 8;
  265. s_tok += 16 * thread_m_blocks;
  266. locks += n_tiles;
  267. slice_col = 0;
  268. }
  269. };
  270. init_slice();
  271. int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory
  272. // We typically use `constexpr` to indicate that this value is a compile-time
  273. // constant
  274. constexpr int a_sh_stride =
  275. 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory
  276. constexpr int a_gl_rd_delta_o =
  277. 16 * thread_k_blocks /
  278. 16; // delta between subsequent A tiles in global memory
  279. int a_gl_rd_delta_i =
  280. a_gl_stride *
  281. (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
  282. constexpr int a_sh_wr_delta =
  283. a_sh_stride *
  284. (threads / a_gl_rd_delta_o); // between shared memory writes
  285. constexpr int a_sh_rd_delta_o =
  286. 1 * ((threads / 32) /
  287. (thread_n_blocks / 4)); // between shared memory tile reads
  288. constexpr int a_sh_rd_delta_i =
  289. a_sh_stride * 16; // within a shared memory tile
  290. constexpr int a_sh_stage =
  291. a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
  292. constexpr int a_sh_wr_iters =
  293. ceildiv(a_sh_stage,
  294. a_sh_wr_delta); // number of shared write iterations for a tile
  295. int b_gl_stride = 16 * prob_n / 32;
  296. constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
  297. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  298. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
  299. constexpr int b_sh_wr_delta = threads;
  300. constexpr int b_sh_rd_delta = threads;
  301. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  302. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
  303. constexpr int s_tok_sh_stride = 16 * thread_m_blocks;
  304. constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4;
  305. int s_group_gl_stride = prob_n / 8;
  306. constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8;
  307. constexpr int s_group_sh_stage = s_group_sh_stride;
  308. int s_group_gl_rd_delta = s_group_gl_stride;
  309. // Global A read index of current thread.
  310. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  311. (threadIdx.x % a_gl_rd_delta_o);
  312. a_gl_rd += a_gl_rd_delta_o * slice_row;
  313. // Shared write index of current thread.
  314. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
  315. (threadIdx.x % a_gl_rd_delta_o);
  316. // Shared read index.
  317. // NOTE: int8 input a only need 16 threads to load 16x16 matrix
  318. int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16);
  319. a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  320. int b_gl_rd =
  321. b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  322. b_gl_rd += b_sh_stride * slice_col;
  323. b_gl_rd += b_gl_rd_delta_o * slice_row;
  324. int b_sh_wr = threadIdx.x;
  325. int b_sh_rd = threadIdx.x;
  326. int s_tok_gl_rd = threadIdx.x;
  327. // NOTE: activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10,
  328. // 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for
  329. // thread 0, 1, 2, 3. For more details, refer to mma operand A layout as
  330. // s_tok's size is not fixed, we can not shuffle before inference we shuffle
  331. // it when fetching s_tok from global memory to shared memory, that's why
  332. // s_tok_sh_wr is like this
  333. int s_tok_sh_wr =
  334. (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8;
  335. int s_tok_sh_rd = (threadIdx.x % 32) / 4;
  336. bool s_tok_sh_wr_pred = threadIdx.x < prob_m;
  337. int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
  338. int s_ch_sh_wr = threadIdx.x;
  339. int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  340. 2 * ((threadIdx.x % 32) % 4);
  341. bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;
  342. int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd;
  343. bool s_group_sh_wr_pred;
  344. if constexpr (group_blocks != -1) {
  345. s_group_gl_rd =
  346. s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
  347. s_group_sh_stride * slice_col + threadIdx.x;
  348. s_group_sh_wr = threadIdx.x;
  349. // NOTE: s_group_sh_rd is related to mma output C
  350. s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  351. (threadIdx.x % 32) / 4;
  352. s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride;
  353. }
  354. // Precompute which thread should not read memory in which iterations; this is
  355. // needed if there are more threads than required for a certain tilesize or
  356. // when the batchsize is not a multiple of 16.
  357. bool a_sh_wr_pred[a_sh_wr_iters];
  358. #pragma unroll
  359. for (int i = 0; i < a_sh_wr_iters; i++)
  360. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  361. // To ensure that writing and reading A tiles to/from shared memory, the
  362. // latter in fragment format, is fully bank conflict free, we need to use a
  363. // rather fancy XOR-based layout. The key here is that neither reads nor
  364. // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  365. // same shared memory banks. Further, it seems (based on NSight-Compute) that
  366. // each warp must also write a consecutive memory segment?
  367. auto transform_a = [&](int i) {
  368. int row = i / a_gl_rd_delta_o;
  369. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  370. };
  371. // Since the computation of this remapping is non-trivial and, due to our main
  372. // loop unrolls, all shared memory accesses are static, we simply precompute
  373. // both transformed reads and writes.
  374. int a_sh_wr_trans[a_sh_wr_iters];
  375. #pragma unroll
  376. for (int i = 0; i < a_sh_wr_iters; i++)
  377. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  378. int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  379. #pragma unroll
  380. for (int i = 0; i < b_sh_wr_iters; i++) {
  381. #pragma unroll
  382. for (int j = 0; j < thread_m_blocks; j++)
  383. a_sh_rd_trans[i][j] =
  384. transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  385. }
  386. // Since B-accesses have non-constant stride they have to be computed at
  387. // runtime; we break dependencies between subsequent accesses with a tile by
  388. // maintaining multiple pointers (we have enough registers), a tiny
  389. // optimization.
  390. const int4* B_ptr[b_sh_wr_iters];
  391. #pragma unroll
  392. for (int i = 0; i < b_sh_wr_iters; i++)
  393. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  394. extern __shared__ int4 sh[];
  395. // Shared memory storage for global fetch pipelines.
  396. // NOTE: stages need >= 4, otherwise, sh_s_tok = sh + max(stages *
  397. // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage)
  398. int4* sh_a = sh;
  399. int4* sh_b = sh_a + (stages * a_sh_stage);
  400. int4* sh_s_tok = sh_b + (stages * b_sh_stage);
  401. int4* sh_s_ch = sh_s_tok + s_tok_sh_stride;
  402. int4* sh_s_group = sh_s_ch + s_ch_sh_stride;
  403. // Register storage for double buffer of shared memory reads.
  404. FragA frag_a[2][thread_m_blocks];
  405. I4 frag_b_quant[2];
  406. FragC frag_c[thread_m_blocks][4][2];
  407. FragS_GROUP frag_s_group[2][4];
  408. FragS_CHANNEL frag_s_tok[thread_m_blocks];
  409. FragS_CHANNEL frag_s_ch[2][4];
  410. // Zero accumulators.
  411. auto zero_accums = [&]() {
  412. #pragma unroll
  413. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  414. reinterpret_cast<int*>(frag_c)[i] = 0;
  415. };
  416. // Asynchronously fetch the next A, B and s tile from global to the next
  417. // shared memory pipeline location.
  418. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
  419. if (pred) {
  420. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  421. #pragma unroll
  422. for (int i = 0; i < a_sh_wr_iters; i++) {
  423. cp_async4_pred(
  424. &sh_a_stage[a_sh_wr_trans[i]],
  425. &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
  426. a_sh_wr_pred[i]);
  427. }
  428. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  429. #pragma unroll
  430. for (int i = 0; i < b_sh_wr_iters; i++) {
  431. cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
  432. B_ptr[i] += b_gl_rd_delta_o;
  433. }
  434. // Only fetch scales if this tile starts a new group
  435. if constexpr (group_blocks != -1) {
  436. if (pipe % (group_blocks / thread_k_blocks) == 0) {
  437. int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe;
  438. if (s_group_sh_wr_pred)
  439. cp_async4(&sh_s_group_stage[s_group_sh_wr],
  440. &s_group[s_group_gl_rd]);
  441. s_group_gl_rd += s_group_gl_rd_delta;
  442. }
  443. }
  444. }
  445. // Insert a fence even when we are winding down the pipeline to ensure that
  446. // waiting is also correct at this point.
  447. cp_async_fence();
  448. };
  449. // Wait until the next thread tile has been loaded to shared memory.
  450. auto wait_for_stage = [&]() {
  451. // We only have `stages - 2` active fetches since we are double buffering
  452. // and can only issue the next fetch when it is guaranteed that the previous
  453. // shared memory load is fully complete (as it may otherwise be
  454. // overwritten).
  455. cp_async_wait<stages - 2>();
  456. __syncthreads();
  457. };
  458. // Load the next sub-tile from the current location in the shared memory pipe
  459. // into the current register buffer.
  460. auto fetch_to_registers = [&](int k, int pipe) {
  461. // It may seem inefficient that we reload the groups for every sub-tile;
  462. // however, this does not seem to be a significant bottleneck, while some
  463. // theoretically better attempts have lead to bad instruction ordering by
  464. // the compiler and correspondingly a noticeable drop in performance.
  465. if constexpr (group_blocks != -1) {
  466. int4* sh_s_group_stage =
  467. sh_s_group +
  468. s_group_sh_stage * ((group_blocks / thread_k_blocks) *
  469. (pipe / (group_blocks / thread_k_blocks)));
  470. reinterpret_cast<int4*>(&frag_s_group[k % 2])[0] =
  471. sh_s_group_stage[s_group_sh_rd];
  472. }
  473. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  474. #pragma unroll
  475. for (int i = 0; i < thread_m_blocks; i++)
  476. ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
  477. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  478. frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
  479. &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
  480. };
  481. // Execute the actual tensor core matmul of a sub-tile.
  482. auto matmul = [&](int k) {
  483. // We have the m dimension as the inner loop in order to encourage overlapping
  484. // dequantization and matmul operations.
  485. #pragma unroll
  486. for (int j = 0; j < 4; j++) {
  487. int b_quant = frag_b_quant[k % 2][j];
  488. // int b_quant_shift = b_quant << 4;
  489. FragB frag_b0, frag_b1;
  490. // If there are no groups, we can just scale the final output once and can
  491. // avoid doing so for each weight.
  492. if constexpr (group_blocks != -1) {
  493. int b_quant_shift = b_quant >> 8;
  494. frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0);
  495. frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1);
  496. } else {
  497. int b_quant_shift = b_quant << 4;
  498. frag_b0 = dequant_per_channel(b_quant);
  499. frag_b1 = dequant_per_channel(b_quant_shift);
  500. }
  501. #pragma unroll
  502. for (int i = 0; i < thread_m_blocks; i++) {
  503. mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
  504. mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
  505. }
  506. }
  507. };
  508. // Since we slice across the k dimension of a tile in order to increase the
  509. // number of warps while keeping the n dimension of a tile reasonable, we have
  510. // multiple warps that accumulate their partial sums of the same output
  511. // location; which we have to reduce over in the end. We do in shared memory.
  512. auto thread_block_reduce = [&]() {
  513. constexpr int red_off = threads / b_sh_stride / 2;
  514. if (red_off >= 1) {
  515. int red_idx = threadIdx.x / b_sh_stride;
  516. constexpr int red_sh_stride = b_sh_stride * 4 * 2;
  517. constexpr int red_sh_delta = b_sh_stride;
  518. int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
  519. (threadIdx.x % b_sh_stride);
  520. // Parallel logarithmic shared memory reduction. We make sure to avoid any
  521. // unnecessary read or write iterations, e.g., for two warps we write only
  522. // once by warp 1 and read only once by warp 0.
  523. #pragma unroll
  524. for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  525. #pragma unroll
  526. for (int i = red_off; i > 0; i /= 2) {
  527. if (i <= red_idx && red_idx < 2 * i) {
  528. #pragma unroll
  529. for (int j = 0; j < 4 * 2; j++) {
  530. int red_sh_wr =
  531. red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
  532. if (i < red_off) {
  533. int* c_rd =
  534. reinterpret_cast<int*>(&sh[red_sh_delta * j + red_sh_rd]);
  535. int* c_wr = reinterpret_cast<int*>(&sh[red_sh_wr]);
  536. #pragma unroll
  537. for (int k = 0; k < 4; k++)
  538. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
  539. c_rd[k] + c_wr[k];
  540. }
  541. sh[red_sh_wr] =
  542. reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
  543. }
  544. }
  545. __syncthreads();
  546. }
  547. if (red_idx == 0) {
  548. #pragma unroll
  549. for (int i = 0; i < 4 * 2; i++) {
  550. int* c_rd =
  551. reinterpret_cast<int*>(&sh[red_sh_delta * i + red_sh_rd]);
  552. #pragma unroll
  553. for (int j = 0; j < 4; j++)
  554. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
  555. c_rd[j];
  556. }
  557. }
  558. __syncthreads();
  559. }
  560. }
  561. };
  562. // Since multiple threadblocks may process parts of the same column slice, we
  563. // finally have to globally reduce over the results. As the striped
  564. // partitioning minimizes the number of such reductions and our outputs are
  565. // usually rather small, we perform this reduction serially in L2 cache.
  566. // global_reduce works on INT32 elements, which are the results of INT8 GEMM.
  567. // This is why we need another INT32 matrix `C` to reduce instead of the
  568. // original half matrix `D`.
  569. auto global_reduce = [&](bool first = false, bool last = false) {
  570. // We are very careful here to reduce directly in the output buffer to
  571. // maximize L2 cache utilization in this step. To do this, we write out
  572. // results in FP16 (but still reduce with FP32 compute).
  573. constexpr int active_threads = 32 * thread_n_blocks / 4;
  574. if (threadIdx.x < active_threads) {
  575. int c_gl_stride = prob_n / 4;
  576. int c_gl_wr_delta_o = 8 * c_gl_stride;
  577. int c_gl_wr_delta_i = 8 * (active_threads / 32);
  578. int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
  579. 8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
  580. c_gl_wr += (4 * thread_n_blocks) * slice_col;
  581. constexpr int c_sh_wr_delta = active_threads * 2;
  582. int c_sh_wr = 2 * threadIdx.x;
  583. int row = (threadIdx.x % 32) / 4;
  584. if (!first) {
  585. // Interestingly, doing direct global accesses here really seems to mess up
  586. // the compiler and lead to slowdowns, hence we also use async-copies even
  587. // though these fetches are not actually asynchronous.
  588. #pragma unroll
  589. for (int i = 0; i < thread_m_blocks * 4; i++) {
  590. cp_async4_pred(
  591. &sh[c_sh_wr + c_sh_wr_delta * i],
  592. &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
  593. c_gl_wr_delta_i * (i % 2)],
  594. i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
  595. cp_async4_pred(
  596. &sh[c_sh_wr + c_sh_wr_delta * i + 1],
  597. &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
  598. c_gl_wr_delta_i * (i % 2) + 1],
  599. i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
  600. }
  601. cp_async_fence();
  602. cp_async_wait<0>();
  603. }
  604. #pragma unroll
  605. for (int i = 0; i < thread_m_blocks * 4; i++) {
  606. if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
  607. if (!first) {
  608. int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta];
  609. int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1];
  610. #pragma unroll
  611. for (int j = 0; j < 4; j++) {
  612. reinterpret_cast<int*>(
  613. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
  614. reinterpret_cast<int*>(&d_red1)[j];
  615. }
  616. #pragma unroll
  617. for (int j = 0; j < 4; j++) {
  618. reinterpret_cast<int*>(
  619. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] +=
  620. reinterpret_cast<int*>(&d_red2)[j];
  621. }
  622. }
  623. if (!last) {
  624. int4 d1, d2;
  625. #pragma unroll
  626. for (int j = 0; j < 4; j++) {
  627. reinterpret_cast<int*>(&d1)[j] = reinterpret_cast<int*>(
  628. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)];
  629. }
  630. #pragma unroll
  631. for (int j = 0; j < 4; j++) {
  632. reinterpret_cast<int*>(&d2)[j] = reinterpret_cast<int*>(
  633. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)];
  634. }
  635. C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
  636. d1;
  637. C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) +
  638. 1] = d2;
  639. }
  640. }
  641. }
  642. }
  643. };
  644. // Write out the reduce final result in the correct layout. We only actually
  645. // reshuffle matrix fragments in this step, the reduction above is performed
  646. // in fragment layout.
  647. auto write_result = [&]() {
  648. int d_gl_stride = prob_n / 8;
  649. constexpr int d_sh_stride = 2 * thread_n_blocks + 1;
  650. int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks));
  651. constexpr int d_sh_rd_delta =
  652. d_sh_stride * (threads / (2 * thread_n_blocks));
  653. int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  654. (threadIdx.x % (2 * thread_n_blocks));
  655. d_gl_wr += (2 * thread_n_blocks) * slice_col;
  656. int d_sh_wr =
  657. (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
  658. d_sh_wr += 32 * (threadIdx.x / 32);
  659. int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  660. (threadIdx.x % (2 * thread_n_blocks));
  661. int d_gl_wr_end = d_gl_stride * prob_m;
  662. // We first reorder in shared memory to guarantee the most efficient final
  663. // global write patterns
  664. auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) {
  665. float2 deq_res;
  666. deq_res.x = int32_to_float(c0) * w_s[0] * a_s;
  667. deq_res.y = int32_to_float(c1) * w_s[1] * a_s;
  668. ((half2*)sh)[idx] = float2_to_half2(deq_res);
  669. };
  670. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  671. #pragma unroll
  672. for (int i = 0; i < thread_m_blocks; i++) {
  673. #pragma unroll
  674. for (int j = 0; j < 4; j++) {
  675. int wr = d_sh_wr + 8 * j;
  676. write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0],
  677. frag_c[i][j][0][1], frag_s_tok[i][0],
  678. frag_s_ch[j / 2][2 * (j % 2) + 0]);
  679. write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2],
  680. frag_c[i][j][0][3], frag_s_tok[i][1],
  681. frag_s_ch[j / 2][2 * (j % 2) + 0]);
  682. write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0],
  683. frag_c[i][j][1][1], frag_s_tok[i][0],
  684. frag_s_ch[j / 2][2 * (j % 2) + 1]);
  685. write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2],
  686. frag_c[i][j][1][3], frag_s_tok[i][1],
  687. frag_s_ch[j / 2][2 * (j % 2) + 1]);
  688. }
  689. d_sh_wr += 16 * (4 * d_sh_stride);
  690. }
  691. }
  692. __syncthreads();
  693. #pragma unroll
  694. for (int i = 0;
  695. i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
  696. i++) {
  697. if (d_gl_wr < d_gl_wr_end) {
  698. D[d_gl_wr] = sh[d_sh_rd];
  699. d_gl_wr += d_gl_wr_delta;
  700. d_sh_rd += d_sh_rd_delta;
  701. }
  702. }
  703. };
  704. // Start global fetch and register load pipelines.
  705. auto start_pipes = [&]() {
  706. #pragma unroll
  707. for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
  708. zero_accums();
  709. wait_for_stage();
  710. fetch_to_registers(0, 0);
  711. a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  712. };
  713. start_pipes();
  714. // Main loop.
  715. while (slice_iters) {
  716. // We unroll over both the global fetch and the register load pipeline to
  717. // ensure all shared memory accesses are static. Note that both pipelines have
  718. // even length meaning that the next iteration will always start at index 0.
  719. #pragma unroll
  720. for (int pipe = 0; pipe < stages;) {
  721. #pragma unroll
  722. for (int k = 0; k < b_sh_wr_iters; k++) {
  723. fetch_to_registers(k + 1, pipe % stages);
  724. if (k == b_sh_wr_iters - 2) {
  725. fetch_to_shared((pipe + stages - 1) % stages, pipe,
  726. slice_iters >= stages);
  727. pipe++;
  728. wait_for_stage();
  729. }
  730. matmul(k);
  731. }
  732. slice_iters--;
  733. if (slice_iters == 0) break;
  734. }
  735. a_gl_rd += a_gl_rd_delta_o * stages;
  736. // Process results and, if necessary, proceed to the next column slice.
  737. // While this pattern may not be the most readable, other ways of writing
738. // the loop seemed to have noticeably worse performance after compilation.
  739. if (slice_iters == 0) {
  740. cp_async_wait<0>();
  741. bool last = slice_idx == slice_count - 1;
  742. // For per-column scales, we only fetch them here in the final step before
  743. // write-out
  744. if (last) {
  745. if (s_tok_sh_wr_pred) {
  746. cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]);
  747. }
  748. if (s_ch_sh_wr_pred) {
  749. cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]);
  750. }
  751. cp_async_fence();
  752. }
  753. thread_block_reduce();
  754. if (last) {
  755. cp_async_wait<0>();
  756. __syncthreads();
  757. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  758. #pragma unroll
  759. for (int i = 0; i < thread_m_blocks; i++) {
  760. frag_s_tok[i][0] =
  761. *reinterpret_cast<float*>(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]);
  762. frag_s_tok[i][1] = *reinterpret_cast<float*>(
  763. &sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]);
  764. }
  765. reinterpret_cast<int4*>(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0];
  766. reinterpret_cast<int4*>(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1];
  767. reinterpret_cast<int4*>(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8];
  768. reinterpret_cast<int4*>(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9];
  769. }
  770. }
  771. if (slice_count > 1) { // only globally reduce if there is more than one
  772. // block in a slice
  773. barrier_acquire(&locks[slice_col], slice_idx);
  774. global_reduce(slice_idx == 0, last);
  775. barrier_release(&locks[slice_col], last);
  776. }
  777. if (last) // only the last block in a slice actually writes the result
  778. write_result();
  779. slice_row = 0;
  780. slice_col_par++;
  781. slice_col++;
  782. init_slice();
  783. if (slice_iters) {
  784. a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  785. (threadIdx.x % a_gl_rd_delta_o);
  786. #pragma unroll
  787. for (int i = 0; i < b_sh_wr_iters; i++)
  788. B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
  789. if (slice_col == 0) {
  790. #pragma unroll
  791. for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
  792. }
  793. s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x;
  794. s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
  795. start_pipes();
  796. }
  797. }
  798. }
  799. }
  800. #else
  801. template <const int threads, // number of threads in a threadblock
  802. const int thread_m_blocks, // number of 16x16 blocks in the m
  803. // dimension (batchsize) of the
  804. // threadblock
  805. const int thread_n_blocks, // same for n dimension (output)
  806. const int thread_k_blocks, // same for k dimension (reduction)
  807. const int stages, // number of stages for the async global->shared
  808. // fetch pipeline
  809. const int group_blocks = -1 // number of consecutive 16x16 blocks
  810. // with a separate quantization scale
  811. >
  812. __global__ void Marlin(
  813. const int4* __restrict__ A, // int8 input matrix of shape mxk
  814. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  815. int4* __restrict__ C, // int32 global_reduce buffer of shape
  816. // (max_par*16*4)xn, as int8 tensor core's output is
  817. // int32 dtype
  818. int4* __restrict__ D, // fp16 output buffer of shape mxn
  819. const float* __restrict__ s_tok, // fp32 activation per-token quantization
  820. // scales of shape mx1
  821. const int4* __restrict__ s_ch, // fp32 weight per-channel quantization
  822. // scales of shape 1xn
  823. const int4* __restrict__ s_group, // fp16 weight per-group quantization
  824. // scales of shape (k/groupsize)xn, when
  825. // group_blocks=-1, it should be nullptr
  826. int prob_m, // batch dimension m
  827. int prob_n, // output dimension n
  828. int prob_k, // reduction dimension k
  829. int* locks // extra global storage for barrier synchronization
  830. ) {
  831. // Marlin is not implemented yet for SM < 8.0
  832. assert(false);
  833. return;
  834. }
  835. #endif
  836. // 8 warps are a good choice since every SM has 4 schedulers and having more
  837. // than 1 warp per schedule allows some more latency hiding. At the same time,
  838. // we want relatively few warps to have many registers per warp and small tiles.
  839. const int USER_THREADS =
  840. 256; // Note: This is only used with user-provided thread_k/n
  841. const int STAGES = 4; // 4 pipeline stages fit into shared memory
  842. static constexpr int min_thread_n = 64;
  843. static constexpr int min_thread_k = 64;
  844. static constexpr int tile_size = 16;
  845. static constexpr int max_par = 16;
  846. static constexpr int pack_factor_4bit =
  847. 8; // We have 8 4-bit vals inside a 32 bit
  848. #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
  849. GROUP_BLOCKS, NUM_THREADS) \
  850. else if (thread_m_blocks == THREAD_M_BLOCKS && \
  851. thread_n_blocks == THREAD_N_BLOCKS && \
  852. thread_k_blocks == THREAD_K_BLOCKS && \
  853. group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
  854. cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
  855. THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
  856. cudaFuncAttributeMaxDynamicSharedMemorySize, \
  857. max_shared_mem); \
  858. Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
  859. STAGES, GROUP_BLOCKS> \
  860. <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
  861. A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \
  862. prob_m, prob_n, prob_k, locks); \
  863. }
  864. typedef struct {
  865. int thread_k;
  866. int thread_n;
  867. int num_threads;
  868. } thread_config_t;
  869. thread_config_t small_batch_thread_configs[] = {
  870. // Ordered by priority
  871. // thread_k, thread_n, num_threads
  872. {128, 128, 256}, // Default
  873. {128, 64, 128}, // Reduce N 2X, same K
  874. {64, 256, 256}, // Reduce K 2X, increase N 2X
  875. {64, 128, 128}, // Reduce K 2X, same N
  876. };
  877. thread_config_t large_batch_thread_configs[] = {
  878. // Ordered by priority
  879. // thread_k, thread_n, num_threads
  880. {64, 256, 256}, // Default
  881. {128, 128, 256}, // Reduce N 2X, increase K 2X
  882. {64, 128, 128}, // Reduce N 2X, same K
  883. {128, 64, 128}, // Reduce N 4X, increase K 2X
  884. };
  885. bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
  886. int prob_k) {
  887. // Sanity
  888. if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
  889. th_config.num_threads == -1) {
  890. return false;
  891. }
  892. // Verify K/N are divisible by thread K/N
  893. if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
  894. return false;
  895. }
  896. // thread_k can be only 128 or 64 (because it must be less than groupsize
  897. // which is 128)
  898. if (th_config.thread_k != 128 && th_config.thread_k != 64) {
  899. return false;
  900. }
  901. // Verify min for thread K/N
  902. if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
  903. return false;
  904. }
  905. // num_threads must be at least 128 (= 4 warps)
  906. if (th_config.num_threads < 128) {
  907. return false;
  908. }
  909. return true;
  910. }
  911. thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
  912. if (prob_m <= 16) {
  913. for (auto th_config : small_batch_thread_configs) {
  914. if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
  915. return th_config;
  916. }
  917. }
  918. } else {
  919. for (auto th_config : large_batch_thread_configs) {
  920. if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
  921. return th_config;
  922. }
  923. }
  924. }
  925. return thread_config_t{-1, -1, -1};
  926. }
  927. #define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
  928. __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  929. __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
  930. __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  931. __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
  932. __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  933. __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
  934. __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  935. __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
  936. __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  937. __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
  938. void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D,
  939. void* s_tok, void* s_ch, void* s_group, int prob_m,
  940. int prob_n, int prob_k, void* workspace,
  941. int groupsize = -1, int dev = 0, cudaStream_t stream = 0,
  942. int thread_k = -1, int thread_n = -1, int sms = -1,
  943. int max_par = 16) {
  944. int tot_m = prob_m;
  945. int tot_m_blocks = ceildiv(tot_m, 16);
  946. int pad = 16 * tot_m_blocks - tot_m;
  947. if (sms == -1)
  948. cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  949. int max_shared_mem = 0;
  950. cudaDeviceGetAttribute(&max_shared_mem,
  951. cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  952. TORCH_CHECK(max_shared_mem > 0);
  953. // Set thread config
  954. thread_config_t th_config;
  955. if (thread_k != -1 && thread_n != -1) {
  956. // User-defined config
  957. th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
  958. } else {
  959. // Auto config
  960. th_config = determine_thread_config(prob_m, prob_n, prob_k);
  961. }
  962. if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
  963. throw std::runtime_error(
  964. "Invalid thread config: thread_k = " + str(th_config.thread_k) +
  965. ", thread_n = " + str(th_config.thread_n) +
  966. ", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
  967. str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
  968. }
  969. int num_threads = th_config.num_threads;
  970. thread_k = th_config.thread_k;
  971. thread_n = th_config.thread_n;
  972. int thread_k_blocks = thread_k / 16;
  973. int thread_n_blocks = thread_n / 16;
  974. int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
  975. int blocks = sms;
  976. if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
  977. return;
  978. }
  979. TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
  980. " is not divisible by thread_n = ", thread_n);
  981. TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
  982. " is not divisible by thread_k = ", thread_k);
  983. if (group_blocks != -1) {
  984. TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
  985. " is not divisible by group_blocks = ", group_blocks);
  986. }
  987. const int4* A_ptr = (const int4*)A;
  988. const int4* B_ptr = (const int4*)B;
  989. int4* C_ptr = (int4*)C;
  990. int4* D_ptr = (int4*)D;
  991. const float* s_tok_ptr = (const float*)s_tok;
  992. const int4* s_ch_ptr = (const int4*)s_ch;
  993. const int4* s_group_ptr = (const int4*)s_group;
  994. int* locks = (int*)workspace;
  995. for (int i = 0; i < tot_m_blocks; i += 4) {
  996. int thread_m_blocks = tot_m_blocks - i;
  997. prob_m = tot_m - 16 * i;
  998. int par = 1;
  999. if (thread_m_blocks > 4) {
  1000. // Note that parallel > 1 currently only works for inputs without any
  1001. // padding
  1002. par = (16 * thread_m_blocks - pad) / 64;
  1003. if (par > max_par) par = max_par;
  1004. prob_m = 64 * par;
  1005. i += 4 * (par - 1);
  1006. thread_m_blocks = 4;
  1007. }
  1008. // For compilation speed, we only define the kernel configurations that have
  1009. // seemed useful (in terms of performance) in our testing, however many more
  1010. // are, in principle, possible.
  1011. if (false) {
  1012. }
  1013. CALL_IF(8, 8, 256)
  1014. CALL_IF(16, 4, 256)
  1015. CALL_IF(8, 4, 128)
  1016. CALL_IF(4, 8, 128)
  1017. else {
  1018. throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
  1019. ", " + str(prob_k) + ", " + str(prob_n) + "]" +
  1020. ", groupsize = " + str(groupsize) +
  1021. ", thread_m_blocks = " + str(thread_m_blocks) +
  1022. ", thread_n_blocks = " + str(thread_n_blocks) +
  1023. ", thread_k_blocks = " + str(thread_k_blocks));
  1024. }
  1025. A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par;
  1026. D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  1027. s_tok_ptr += 16 * thread_m_blocks * par;
  1028. }
  1029. }
  1030. } // anonymous namespace
  1031. torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
  1032. torch::Tensor const& b_q_weight,
  1033. torch::Tensor const& s_tok,
  1034. torch::Tensor const& s_ch,
  1035. torch::Tensor const& s_group,
  1036. torch::Tensor& workspace, int64_t size_m,
  1037. int64_t size_n, int64_t size_k) {
  1038. // Verify M
  1039. TORCH_CHECK(size_m == a.size(0),
  1040. "Shape mismatch: a.size(0) = " + str(a.size(0)) +
  1041. ", size_m = " + str(size_m));
  1042. TORCH_CHECK(size_m == s_tok.numel(),
  1043. "Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) +
  1044. ", size_m = " + str(size_m));
  1045. // Verify K
  1046. TORCH_CHECK(size_k == a.size(1),
  1047. "Shape mismatch: a.size(1) = " + str(a.size(1)) +
  1048. ", size_k = " + str(size_k));
  1049. TORCH_CHECK(size_k % tile_size == 0,
  1050. "size_k = " + str(size_k) +
  1051. " is not divisible by tile_size = " + str(tile_size));
  1052. TORCH_CHECK(
  1053. (size_k / tile_size) == b_q_weight.size(0),
  1054. "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) +
  1055. ", size_k = " + str(size_k) + ", tile_size = " + str(tile_size));
  1056. int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0);
  1057. // Verify groupsize
  1058. TORCH_CHECK(groupsize == -1 || groupsize == 128,
  1059. "Unexpected groupsize = " + str(groupsize));
  1060. // Verify N
  1061. TORCH_CHECK(s_ch.numel() == size_n,
  1062. "Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) +
  1063. ", size_n = " + str(size_n));
  1064. TORCH_CHECK(b_q_weight.size(1) % tile_size == 0,
  1065. "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
  1066. " is not divisible by tile_size = " + str(tile_size));
  1067. if (groupsize != -1) {
  1068. TORCH_CHECK(s_group.size(1) == size_n,
  1069. "Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) +
  1070. ", size_n = " + str(size_n));
  1071. TORCH_CHECK(
  1072. size_k % s_group.size(0) == 0,
  1073. "size_k = " + str(size_k) +
  1074. ", is not divisible by s_group.size(0) = " + str(s_group.size(0)));
  1075. }
  1076. int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit;
  1077. TORCH_CHECK(size_n == actual_size_n,
  1078. "Shape mismatch: size_n = " + str(size_n) +
  1079. ", actual_size_n = " + str(actual_size_n));
  1080. // Verify A device and strides
  1081. TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  1082. TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
  1083. // Verify B device and strides
  1084. TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  1085. TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  1086. // Verify s_tok device, strides and dtype
  1087. TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU");
  1088. TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous");
  1089. TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32");
  1090. // Verify s_ch device, strides and dtype
  1091. TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU");
  1092. TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous");
  1093. TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32");
  1094. // Verify s_group device, strides and dtype
  1095. TORCH_CHECK(s_group.device().is_cuda(), "s_group is not on GPU");
  1096. TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous");
  1097. TORCH_CHECK(s_group.dtype() == torch::kFloat16,
  1098. "s_group's dtype is not float16");
  1099. // Verify workspace size
  1100. TORCH_CHECK(size_n % min_thread_n == 0,
  1101. "size_n = " + str(size_n) +
  1102. ", is not divisible by min_thread_n = " + str(min_thread_n));
  1103. int min_workspace_size = (size_n / min_thread_n) * max_par;
  1104. TORCH_CHECK(workspace.numel() >= min_workspace_size,
  1105. "workspace.numel = " + str(workspace.numel()) +
  1106. " is below min_workspace_size = " + str(min_workspace_size));
  1107. // Alloc C matrix
  1108. const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  1109. auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device());
  1110. torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c);
  1111. // Alloc D matrix
  1112. auto options_d =
  1113. torch::TensorOptions().dtype(torch::kFloat16).device(a.device());
  1114. torch::Tensor d = torch::empty({size_m, size_n}, options_d);
  1115. // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  1116. // auto -1)
  1117. int thread_k = -1;
  1118. // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  1119. // auto -1)
  1120. int thread_n = -1;
  1121. // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  1122. int sms = -1;
  1123. int dev = a.get_device();
  1124. marlin_qqq_cuda(
  1125. a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(),
  1126. s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n,
  1127. size_k, workspace.data_ptr(), groupsize, dev,
  1128. at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par);
  1129. return d;
  1130. }