marlin_cuda_kernel.cu 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144
  1. /*
  2. * Modified by Neural Magic
  3. * Copyright (C) Marlin.2024 Elias Frantar
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #include <torch/extension.h>
  18. #include <ATen/cuda/CUDAContext.h>
  19. #include <c10/cuda/CUDAGuard.h>
  20. #include <cuda.h>
  21. #include <cuda_fp16.h>
  22. #include <cuda_runtime.h>
  23. #include <iostream>
  // Convert `x` to its std::string representation using std::to_string.
  template <typename T>
  inline std::string str(T x) {
    std::string text = std::to_string(x);
    return text;
  }
  25. namespace marlin {
  // Integer division a / b rounded up; usable in constant expressions (e.g. to
  // size compile-time tile loops). Assumes b > 0.
  constexpr int ceildiv(int a, int b) {
    const int round_up = b - 1;
    return (a + round_up) / b;
  }
  27. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Instances of `Vec` are used to organize groups of >>registers<<, as needed
  // for instance as inputs to tensor core operations. Consequently, all
  // corresponding index accesses must be compile-time constants, which is why
  // the kernel code extensively uses `#pragma unroll` to guarantee this.
  //
  // `Vec` is a plain aggregate wrapping `T[n]` (no constructors), which is what
  // allows it to be reinterpret_cast to raw 32-bit words for inline PTX.
  template <typename T, int n> struct Vec {
    T elems[n];
    // Mutable element access.
    __device__ T &operator[](int i) { return elems[i]; }
    // Read-only element access, so that `const Vec &` (e.g. a `const FragA &`
    // parameter) can be indexed as well; previously only non-const access
    // existed.
    __device__ const T &operator[](int i) const { return elems[i]; }
  };
  // Four packed 32-bit integers; the kernel uses it to hold one 16-byte
  // shared-memory load of quantized B weights (`frag_b_quant`).
  typedef Vec<int, 4> I4;
  // Per-thread register fragments for the m16n8k16 tensor core instruction;
  // their precise layout is documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  typedef Vec<half2, 4> FragA;  // A operand: 8 fp16 values per thread
  typedef Vec<half2, 2> FragB;  // B operand: 4 fp16 values per thread
  typedef Vec<float, 4> FragC;  // accumulator: 4 fp32 values per thread
  typedef Vec<half2, 1> FragS;  // quantization scales (2 fp16 values)
  // Predicated asynchronous 16-byte global->shared copy (cp.async.cg, cached in
  // L2 only). Used for inputs A, where predication handles batch sizes that are
  // not multiples of 16. When `pred` is false no copy is issued at all and the
  // shared-memory destination keeps whatever it previously contained.
  __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
                                        bool pred = true) {
    constexpr int kCopyBytes = 16;
    const int pred_flag = pred ? 1 : 0;
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile("{\n"
                 " .reg .pred p;\n"
                 " setp.ne.b32 p, %0, 0;\n"
                 " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
                 "}\n" ::"r"(pred_flag),
                 "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
  }
  // Unconditional asynchronous 16-byte global->shared copy (cp.async.cg, i.e.
  // cached in L2 only, bypassing L1). In this file it is used both for the
  // quantized weights B and for the quantization scales.
  //
  // No L2 eviction-priority hint is attached (the instruction has no
  // `.L2::cache_hint` operand), so these loads use the default L2 policy; they
  // are NOT marked evict-first.
  __device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
    const int BYTES = 16;
    uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile("{\n"
                 " cp.async.cg.shared.global [%0], [%1], %2;\n"
                 "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES));
  }
  // Commit all cp.async operations issued so far by this thread into a single
  // async-copy group, which `cp_async_wait` can later wait on.
  __device__ inline void cp_async_fence() {
    asm volatile("cp.async.commit_group;\n");
  }
  // Block the calling thread until at most `pending` of its most recently
  // committed async-copy groups are still in flight.
  template <int pending>
  __device__ inline void cp_async_wait() {
    asm volatile("cp.async.wait_group %0;\n" : : "n"(pending));
  }
  // Tensor core m16n8k16 matrix multiply-accumulate (A row-major, B
  // column-major) with fp16 inputs and fp32 accumulation, updating `frag_c` in
  // place: frag_c += a_frag * frag_b. Being `mma.sync.aligned`, it must be
  // executed convergently by all threads of a warp.
  __device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
                             FragC &frag_c) {
    // The PTX instruction consumes the fragments as raw 32-bit registers.
    const uint32_t *a_regs = reinterpret_cast<const uint32_t *>(&a_frag);
    const uint32_t *b_regs = reinterpret_cast<const uint32_t *>(&frag_b);
    float *acc = reinterpret_cast<float *>(&frag_c);
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                 "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
                 : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
                 : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]),
                   "r"(a_regs[3]), "r"(b_regs[0]), "r"(b_regs[1]),
                   "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]));
  }
  // Load a full 16x16 fp16 A-operand fragment from shared memory directly in
  // tensor core register layout, using one warp-wide `ldmatrix ... .x4`
  // (four 8x8 b16 matrices). Each thread supplies one row address; see the PTX
  // `ldmatrix` documentation for the lane-to-row mapping.
  __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    uint32_t *regs = reinterpret_cast<uint32_t *>(&frag_a);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3])
                 : "r"(smem_addr));
  }
  // Arbitrary three-input bitwise function via the PTX `lop3` instruction.
  // `table` is the 8-bit truth table obtained by evaluating the desired
  // expression on the constants a = 0xf0, b = 0xcc, c = 0xaa; for example
  // (a & b) | c is lop3<(0xf0 & 0xcc) | 0xaa>. Used explicitly during
  // dequantization because the compiler does not seem to form LOP3s from such
  // expressions automatically in all cases.
  template <int table>
  __device__ inline int lop3(int a, int b, int c) {
    int out;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(out)
                 : "r"(a), "r"(b), "r"(c), "n"(table));
    return out;
  }
  // Efficiently dequantize the 4-bit values packed in an int32 into a full
  // B-fragment of 4 fp16 values. We mostly follow the strategy in the link
  // below, with some small changes:
  // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
  //
  // Only bits 0-7 and 16-23 of `q` are read. With n(b) denoting the 4-bit
  // field starting at bit b, the result is
  //   result[0] = half2(n(0) - 8, n(16) - 8)
  //   result[1] = half2(n(4) - 8, n(20) - 8)
  // Trick: OR-ing a nibble into the mantissa of fp16 1024.0 (0x6400) gives
  // exactly 1024 + n (or 1024 + 16 * n for bits 4-7 of a 16-bit half), which a
  // single half2 subtract / fused multiply-add maps to the signed value n - 8.
  __device__ inline FragB dequant(int q) {
    const int low_mask = 0x000f000f;   // bits 0-3 of each 16-bit half
    const int high_mask = 0x00f000f0;  // bits 4-7 of each 16-bit half
    const int exponent = 0x64006400;   // half2(1024, 1024)
    // Truth table of (a & b) | c; guarantees the masking is a single LOP3.
    constexpr int and_or = (0xf0 & 0xcc) | 0xaa;
    int lo_bits = lop3<and_or>(q, low_mask, exponent);   // 1024 + n
    int hi_bits = lop3<and_or>(q, high_mask, exponent);  // 1024 + 16 * n

    // The symmetric zero point of -8 (signed int4 outputs) is folded directly
    // into the subtrahend and the addend.
    const int sub_bits = 0x64086408;  // half2(1032, 1032) = 1024 + 8
    const int mul_bits = 0x2c002c00;  // half2(1/16, 1/16)
    const int add_bits = 0xd480d480;  // half2(-72, -72) = -(1024 / 16 + 8)

    const half2 lo = *reinterpret_cast<half2 *>(&lo_bits);
    const half2 hi = *reinterpret_cast<half2 *>(&hi_bits);
    const half2 sub = *reinterpret_cast<const half2 *>(&sub_bits);
    const half2 mul = *reinterpret_cast<const half2 *>(&mul_bits);
    const half2 add = *reinterpret_cast<const half2 *>(&add_bits);

    FragB result;
    result[0] = __hsub2(lo, sub);       // (1024 + n) - 1032 = n - 8
    result[1] = __hfma2(hi, mul, add);  // (1024 + 16n) / 16 - 72 = n - 8
    return result;
  }
  // Multiply both half2 entries of a dequantized B-fragment by one fp16 scale
  // taken from `frag_s` (index `i` selects one of its two halves) and broadcast
  // to a half2. Only used for grouped quantization; with per-column scales the
  // kernel applies the scale to the final output instead.
  __device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
    __half *scales = reinterpret_cast<__half *>(&frag_s);
    const half2 factor = __half2half2(scales[i]);
    frag_b[0] = __hmul2(frag_b[0], factor);
    frag_b[1] = __hmul2(frag_b[1], factor);
  }
  // Spin until the global counter at `lock` equals `count`, then let the whole
  // threadblock continue. Only thread 0 polls, using a GPU-scope acquire load
  // that pairs with the fence + atomic add in barrier_release; the trailing
  // __syncthreads() holds the remaining threads until thread 0 has observed
  // the expected value.
  __device__ inline void barrier_acquire(int *lock, int count) {
    if (threadIdx.x == 0) {
      int observed = -1;
      while (true) {
        asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                     : "=r"(observed)
                     : "l"(lock));
        if (observed == count)
          break;
      }
    }
    __syncthreads();
  }
  // Release the barrier at `lock` on behalf of the whole threadblock. After a
  // block-wide __syncthreads(), thread 0 either:
  //  - `reset == true`: writes 0 to the counter with a plain store, or
  //  - otherwise: issues a GPU-scope acq_rel fence, so that all writes made
  //    since acquiring the barrier are visible globally, and then atomically
  //    increments the counter by one (relaxed reduction).
  __device__ inline void barrier_release(int *lock, bool reset = false) {
    __syncthreads();
    if (threadIdx.x != 0)
      return;
    if (reset) {
      lock[0] = 0;
    } else {
      int one = 1;
      asm volatile("fence.acq_rel.gpu;\n");
      asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                   :
                   : "l"(lock), "r"(one));
    }
  }
  172. template <const int threads, // number of threads in a threadblock
  173. const int thread_m_blocks, // number of 16x16 blocks in the m
  174. // dimension (batchsize) of the threadblock
  175. const int thread_n_blocks, // same for n dimension (output)
  176. const int thread_k_blocks, // same for k dimension (reduction)
  177. const int stages, // number of stages for the async global->shared
  178. // fetch pipeline
  179. const int group_blocks = -1 // number of consecutive 16x16 blocks with
  180. // a separate quantization scale
  181. >
  182. __global__ void
  183. Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
  184. const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
  185. int4 *__restrict__ C, // fp16 output buffer of shape mxn
  186. const int4
  187. *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
  188. int prob_m, // batch dimension m
  189. int prob_n, // output dimension n
  190. int prob_k, // reduction dimension k
  191. int *locks // extra global storage for barrier synchronization
  192. ) {
  193. // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  194. // same size, which might involve multiple column "slices" (of width 16 *
  195. // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
  196. // example:
  197. // 0 1 3
  198. // 0 2 3
  199. // 1 2 4
  200. // While this kind of partitioning makes things somewhat more complicated, it
  201. // ensures good utilization of all SMs for many kinds of shape and GPU
  202. // configurations, while requiring as few slow global cross-threadblock
  203. // reductions as possible.
  204. // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  205. // better partitioning with less reductions
  206. int parallel = 1;
  207. if (prob_m > 16 * thread_m_blocks) {
  208. parallel = prob_m / (16 * thread_m_blocks);
  209. prob_m = 16 * thread_m_blocks;
  210. }
  211. int k_tiles = prob_k / 16 / thread_k_blocks;
  212. int n_tiles = prob_n / 16 / thread_n_blocks;
  213. int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
  214. // Ensure that the number of tiles in each stripe is a multiple of the
  215. // groupsize; this avoids an annoying special case where a stripe starts in
  216. // the middle of group.
  217. if (group_blocks != -1)
  218. iters = (group_blocks / thread_k_blocks) *
  219. ceildiv(iters, (group_blocks / thread_k_blocks));
  220. int slice_row = (iters * blockIdx.x) % k_tiles;
  221. int slice_col_par = (iters * blockIdx.x) / k_tiles;
  222. int slice_col = slice_col_par;
  223. int slice_iters; // number of threadblock tiles in the current slice
  224. int slice_count =
  225. 0; // total number of active threadblocks in the current slice
  226. int slice_idx; // index of threadblock in current slice; numbered bottom to
  227. // top
  228. // We can easily implement parallel problem execution by just remapping
  229. // indices and advancing global pointers
  230. if (slice_col_par >= n_tiles) {
  231. A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
  232. C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
  233. locks += (slice_col_par / n_tiles) * n_tiles;
  234. slice_col = slice_col_par % n_tiles;
  235. }
  236. // Compute all information about the current slice which is required for
  237. // synchronization.
  238. auto init_slice = [&]() {
  239. slice_iters =
  240. iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
  241. if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
  242. slice_iters = 0;
  243. if (slice_iters == 0)
  244. return;
  245. if (slice_row + slice_iters > k_tiles)
  246. slice_iters = k_tiles - slice_row;
  247. slice_count = 1;
  248. slice_idx = 0;
  249. int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
  250. if (col_first <= k_tiles * (slice_col_par + 1)) {
  251. int col_off = col_first - k_tiles * slice_col_par;
  252. slice_count = ceildiv(k_tiles - col_off, iters);
  253. if (col_off > 0)
  254. slice_count++;
  255. int delta_first = iters * blockIdx.x - col_first;
  256. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  257. slice_idx = slice_count - 1;
  258. else {
  259. slice_idx = slice_count - 1 - delta_first / iters;
  260. if (col_off > 0)
  261. slice_idx--;
  262. }
  263. }
  264. if (slice_col == n_tiles) {
  265. A += 16 * thread_m_blocks * prob_k / 8;
  266. C += 16 * thread_m_blocks * prob_n / 8;
  267. locks += n_tiles;
  268. slice_col = 0;
  269. }
  270. };
  271. init_slice();
  272. int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
  273. // We typically use `constexpr` to indicate that this value is a compile-time
  274. // constant
  275. constexpr int a_sh_stride =
  276. 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
  277. constexpr int a_gl_rd_delta_o =
  278. 16 * thread_k_blocks /
  279. 8; // delta between subsequent A tiles in global memory
  280. int a_gl_rd_delta_i =
  281. a_gl_stride *
  282. (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
  283. constexpr int a_sh_wr_delta =
  284. a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
  285. constexpr int a_sh_rd_delta_o =
  286. 2 * ((threads / 32) /
  287. (thread_n_blocks / 4)); // between shared memory tile reads
  288. constexpr int a_sh_rd_delta_i =
  289. a_sh_stride * 16; // within a shared memory tile
  290. constexpr int a_sh_stage =
  291. a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
  292. constexpr int a_sh_wr_iters =
  293. ceildiv(a_sh_stage,
  294. a_sh_wr_delta); // number of shared write iterations for a tile
  295. int b_gl_stride = 16 * prob_n / 32;
  296. constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
  297. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  298. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
  299. constexpr int b_sh_wr_delta = threads;
  300. constexpr int b_sh_rd_delta = threads;
  301. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  302. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
  303. int s_gl_stride = prob_n / 8;
  304. constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  305. constexpr int s_sh_stage = s_sh_stride;
  306. int s_gl_rd_delta = s_gl_stride;
  307. // Global A read index of current thread.
  308. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  309. (threadIdx.x % a_gl_rd_delta_o);
  310. a_gl_rd += a_gl_rd_delta_o * slice_row;
  311. // Shared write index of current thread.
  312. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
  313. (threadIdx.x % a_gl_rd_delta_o);
  314. // Shared read index.
  315. int a_sh_rd =
  316. a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  317. a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  318. int b_gl_rd =
  319. b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  320. b_gl_rd += b_sh_stride * slice_col;
  321. b_gl_rd += b_gl_rd_delta_o * slice_row;
  322. int b_sh_wr = threadIdx.x;
  323. int b_sh_rd = threadIdx.x;
  324. int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
  325. s_sh_stride * slice_col + threadIdx.x;
  326. int s_sh_wr = threadIdx.x;
  327. int s_sh_rd;
  328. // We use a different scale layout for grouped and column-wise quantization as
  329. // we scale a `half2` tile in column-major layout in the former and in
  330. // row-major in the latter case.
  331. if (group_blocks != -1)
  332. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  333. (threadIdx.x % 32) / 4;
  334. else
  335. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  336. (threadIdx.x % 32) % 4;
  337. // Precompute which thread should not read memory in which iterations; this is
  338. // needed if there are more threads than required for a certain tilesize or
  339. // when the batchsize is not a multiple of 16.
  340. bool a_sh_wr_pred[a_sh_wr_iters];
  341. #pragma unroll
  342. for (int i = 0; i < a_sh_wr_iters; i++)
  343. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  344. bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
  345. // To ensure that writing and reading A tiles to/from shared memory, the
  346. // latter in fragment format, is fully bank conflict free, we need to use a
  347. // rather fancy XOR-based layout. The key here is that neither reads nor
  348. // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  349. // same shared memory banks. Further, it seems (based on NSight-Compute) that
  350. // each warp must also write a consecutive memory segment?
  351. auto transform_a = [&](int i) {
  352. int row = i / a_gl_rd_delta_o;
  353. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  354. };
  355. // Since the computation of this remapping is non-trivial and, due to our main
  356. // loop unrolls, all shared memory accesses are static, we simply precompute
  357. // both transformed reads and writes.
  358. int a_sh_wr_trans[a_sh_wr_iters];
  359. #pragma unroll
  360. for (int i = 0; i < a_sh_wr_iters; i++)
  361. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  362. int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  363. #pragma unroll
  364. for (int i = 0; i < b_sh_wr_iters; i++) {
  365. #pragma unroll
  366. for (int j = 0; j < thread_m_blocks; j++)
  367. a_sh_rd_trans[i][j] =
  368. transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  369. }
  370. // Since B-accesses have non-constant stride they have to be computed at
  371. // runtime; we break dependencies between subsequent accesses with a tile by
  372. // maintining multiple pointers (we have enough registers), a tiny
  373. // optimization.
  374. const int4 *B_ptr[b_sh_wr_iters];
  375. #pragma unroll
  376. for (int i = 0; i < b_sh_wr_iters; i++)
  377. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  378. extern __shared__ int4 sh[];
  379. // Shared memory storage for global fetch pipelines.
  380. int4 *sh_a = sh;
  381. int4 *sh_b = sh_a + (stages * a_sh_stage);
  382. int4 *sh_s = sh_b + (stages * b_sh_stage);
  383. // Register storage for double buffer of shared memory reads.
  384. FragA frag_a[2][thread_m_blocks];
  385. I4 frag_b_quant[2];
  386. FragC frag_c[thread_m_blocks][4][2];
  387. FragS frag_s[2][4];
  388. // Zero accumulators.
  389. auto zero_accums = [&]() {
  390. #pragma unroll
  391. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  392. reinterpret_cast<float *>(frag_c)[i] = 0;
  393. };
  394. // Asynchronously fetch the next A, B and s tile from global to the next
  395. // shared memory pipeline location.
  396. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
  397. if (pred) {
  398. int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
  399. #pragma unroll
  400. for (int i = 0; i < a_sh_wr_iters; i++) {
  401. cp_async4_pred(
  402. &sh_a_stage[a_sh_wr_trans[i]],
  403. &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
  404. a_sh_wr_pred[i]);
  405. }
  406. int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
  407. #pragma unroll
  408. for (int i = 0; i < b_sh_wr_iters; i++) {
  409. cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
  410. B_ptr[i] += b_gl_rd_delta_o;
  411. }
  412. // Only fetch scales if this tile starts a new group
  413. if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
  414. int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
  415. if (s_sh_wr_pred)
  416. cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
  417. s_gl_rd += s_gl_rd_delta;
  418. }
  419. }
  420. // Insert a fence even when we are winding down the pipeline to ensure that
  421. // waiting is also correct at this point.
  422. cp_async_fence();
  423. };
  424. // Wait until the next thread tile has been loaded to shared memory.
  425. auto wait_for_stage = [&]() {
  426. // We only have `stages - 2` active fetches since we are double buffering
  427. // and can only issue the next fetch when it is guaranteed that the previous
  428. // shared memory load is fully complete (as it may otherwise be
  429. // overwritten).
  430. cp_async_wait<stages - 2>();
  431. __syncthreads();
  432. };
  433. // Load the next sub-tile from the current location in the shared memory pipe
  434. // into the current register buffer.
  435. auto fetch_to_registers = [&](int k, int pipe) {
  436. // It may seem inefficient that we reload the groups for every sub-tile;
  437. // however, this does not seem to be a significant bottleneck, while some
  438. // theoretically better attempts have lead to bad instruction ordering by
  439. // the compiler and correspondingly a noticeable drop in performance.
  440. if (group_blocks != -1) {
  441. int4 *sh_s_stage =
  442. sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
  443. (pipe / (group_blocks / thread_k_blocks)));
  444. reinterpret_cast<int4 *>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
  445. }
  446. int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
  447. #pragma unroll
  448. for (int i = 0; i < thread_m_blocks; i++)
  449. ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
  450. int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
  451. frag_b_quant[k % 2] = *reinterpret_cast<I4 *>(
  452. &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
  453. };
  454. // Execute the actual tensor core matmul of a sub-tile.
  455. auto matmul = [&](int k) {
  456. // We have the m dimension as the inner loop in order to encourage overlapping
  457. // dequantization and matmul operations.
  458. #pragma unroll
  459. for (int j = 0; j < 4; j++) {
  460. int b_quant = frag_b_quant[k % 2][j];
  461. int b_quant_shift = b_quant >> 8;
  462. FragB frag_b0 = dequant(b_quant);
  463. // If there are no groups, we can just scale the final output once and can
  464. // avoid doing so for each weight.
  465. if (group_blocks != -1)
  466. scale(frag_b0, frag_s[k % 2][j], 0);
  467. FragB frag_b1 = dequant(b_quant_shift);
  468. if (group_blocks != -1)
  469. scale(frag_b1, frag_s[k % 2][j], 1);
  470. #pragma unroll
  471. for (int i = 0; i < thread_m_blocks; i++) {
  472. mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
  473. mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
  474. }
  475. }
  476. };
  477. // Since we slice across the k dimension of a tile in order to increase the
  478. // number of warps while keeping the n dimension of a tile reasonable, we have
  479. // multiple warps that accumulate their partial sums of the same output
  480. // location; which we have to reduce over in the end. We do in shared memory.
  481. auto thread_block_reduce = [&]() {
  482. constexpr int red_off = threads / b_sh_stride / 2;
  483. if (red_off >= 1) {
  484. int red_idx = threadIdx.x / b_sh_stride;
  485. constexpr int red_sh_stride = b_sh_stride * 4 * 2;
  486. constexpr int red_sh_delta = b_sh_stride;
  487. int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
  488. (threadIdx.x % b_sh_stride);
  489. // Parallel logarithmic shared memory reduction. We make sure to avoid any
  490. // unnecessary read or write iterations, e.g., for two warps we write only
  491. // once by warp 1 and read only once by warp 0.
  492. #pragma unroll
  493. for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  494. #pragma unroll
  495. for (int i = red_off; i > 0; i /= 2) {
  496. if (i <= red_idx && red_idx < 2 * i) {
  497. #pragma unroll
  498. for (int j = 0; j < 4 * 2; j++) {
  499. int red_sh_wr =
  500. red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
  501. if (i < red_off) {
  502. float *c_rd = reinterpret_cast<float *>(
  503. &sh[red_sh_delta * j + red_sh_rd]);
  504. float *c_wr = reinterpret_cast<float *>(&sh[red_sh_wr]);
  505. #pragma unroll
  506. for (int k = 0; k < 4; k++)
  507. reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + j][k] +=
  508. c_rd[k] + c_wr[k];
  509. }
  510. sh[red_sh_wr] =
  511. reinterpret_cast<int4 *>(&frag_c)[4 * 2 * m_block + j];
  512. }
  513. }
  514. __syncthreads();
  515. }
  516. if (red_idx == 0) {
  517. #pragma unroll
  518. for (int i = 0; i < 4 * 2; i++) {
  519. float *c_rd =
  520. reinterpret_cast<float *>(&sh[red_sh_delta * i + red_sh_rd]);
  521. #pragma unroll
  522. for (int j = 0; j < 4; j++)
  523. reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + i][j] +=
  524. c_rd[j];
  525. }
  526. }
  527. __syncthreads();
  528. }
  529. }
  530. };
  531. // Since multiple threadblocks may process parts of the same column slice, we
  532. // finally have to globally reduce over the results. As the striped partitioning
  533. // minimizes the number of such reductions and our outputs are usually rather
  534. // small, we perform this reduction serially in L2 cache.
  535. auto global_reduce = [&](bool first = false, bool last = false) {
  536. // We are very careful here to reduce directly in the output buffer to
  537. // maximize L2 cache utilization in this step. To do this, we write out
  538. // results in FP16 (but still reduce with FP32 compute).
  539. constexpr int active_threads = 32 * thread_n_blocks / 4;
  540. if (threadIdx.x < active_threads) {
  541. int c_gl_stride = prob_n / 8;
  542. int c_gl_wr_delta_o = 8 * c_gl_stride;
  543. int c_gl_wr_delta_i = 4 * (active_threads / 32);
  544. int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
  545. 4 * (threadIdx.x / 32) + threadIdx.x % 4;
  546. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  547. constexpr int c_sh_wr_delta = active_threads;
  548. int c_sh_wr = threadIdx.x;
  549. int row = (threadIdx.x % 32) / 4;
  550. if (!first) {
  551. // Interestingly, doing direct global accesses here really seems to mess up the
  552. // compiler and lead to slowdowns, hence we also use async-copies even though
  553. // these fetches are not actually asynchronous.
  554. #pragma unroll
  555. for (int i = 0; i < thread_m_blocks * 4; i++) {
  556. cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
  557. &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
  558. c_gl_wr_delta_i * (i % 2)],
  559. i < (thread_m_blocks - 1) * 4 ||
  560. 8 * (i / 2) + row < prob_m);
  561. }
  562. cp_async_fence();
  563. cp_async_wait<0>();
  564. }
  565. #pragma unroll
  566. for (int i = 0; i < thread_m_blocks * 4; i++) {
  567. if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
  568. if (!first) {
  569. int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  570. #pragma unroll
  571. for (int j = 0; j < 2 * 4; j++) {
  572. reinterpret_cast<float *>(
  573. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
  574. __half2float(reinterpret_cast<__half *>(&c_red)[j]);
  575. }
  576. }
  577. if (!last) {
  578. int4 c;
  579. #pragma unroll
  580. for (int j = 0; j < 2 * 4; j++) {
  581. reinterpret_cast<__half *>(&c)[j] =
  582. __float2half(reinterpret_cast<float *>(
  583. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
  584. }
  585. C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
  586. c;
  587. }
  588. }
  589. }
  590. }
  591. };
  592. // Write out the reduced final result in the correct layout. Only matrix
  593. // fragments are reshuffled in this step; the reduction above is performed
  594. // in fragment layout.
  // Runs in two phases:
  //   1. Warps with threadIdx.x / 32 < thread_n_blocks / 4 convert their
  //      frag_c accumulators to half (multiplying by the per-column scale when
  //      group_blocks == -1) and store them row-major into a staging tile in
  //      `sh`.
  //   2. After __syncthreads(), every thread copies int4 chunks (8 halves)
  //      from `sh` to C, skipping rows at or beyond prob_m.
  // Units: C and the reads from `sh` are indexed in int4s, but write() below
  // addresses `sh` as half2 (4 half2 per int4) -- hence the factors of 4 in
  // the c_sh_wr arithmetic.
  595. auto write_result = [&]() {
  // Row stride of C in int4s (prob_n halves / 8 halves per int4).
  596. int c_gl_stride = prob_n / 8;
  // Row stride of the staging tile in int4s: 2 * thread_n_blocks int4s hold
  // the 16 * thread_n_blocks output columns, and the extra +1 pads each row
  // (NOTE(review): presumably to avoid shared-memory bank conflicts).
  597. constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
  // One pass of the copy loop below moves threads / (2 * thread_n_blocks)
  // rows, so both cursors advance by that many rows per pass.
  598. int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
  599. constexpr int c_sh_rd_delta =
  600. c_sh_stride * (threads / (2 * thread_n_blocks));
  // Phase-2 position in C: row threadIdx.x / (2 * thread_n_blocks), int4
  // column threadIdx.x % (2 * thread_n_blocks), shifted to this slice's
  // columns.
  601. int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  602. (threadIdx.x % (2 * thread_n_blocks));
  603. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  // Phase-1 position in `sh` (half2 units): lane / 4 selects the row and
  // lane % 4 the half2 column; each warp owns a 32-half2 (64-column) strip.
  604. int c_sh_wr =
  605. (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
  606. c_sh_wr += 32 * (threadIdx.x / 32);
  // Phase-2 read position in `sh`, same row/column split as c_gl_wr.
  607. int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  608. (threadIdx.x % (2 * thread_n_blocks));
  // Offsets at or beyond this lie in rows >= prob_m and are not written.
  609. int c_gl_wr_end = c_gl_stride * prob_m;
  610. // We first reorder in shared memory to guarantee the most efficient final
  611. // global write patterns.
  // write(): packs two float accumulator values into one half2, optionally
  // scales it, and stores it at half2 index `idx` of `sh`.
  612. auto write = [&](int idx, float c0, float c1, FragS &s) {
  613. half2 res = __halves2half2(__float2half(c0), __float2half(c1));
  614. if (group_blocks ==
  615. -1) // per-column quantization: the scale is applied here, at write-out
  616. res = __hmul2(res, s[0]);
  617. ((half2 *)sh)[idx] = res;
  618. };
  619. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  // For 16-row m-block i and 16-column n-block j (8 half2 apart):
  // frag_c[i][j][0] fills columns 0-7 and frag_c[i][j][1] columns 8-15
  // (+4 half2); elements [2..3] of each go 8 rows below elements [0..1].
  620. #pragma unroll
  621. for (int i = 0; i < thread_m_blocks; i++) {
  622. #pragma unroll
  623. for (int j = 0; j < 4; j++) {
  624. int wr = c_sh_wr + 8 * j;
  625. write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
  626. frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
  627. write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
  628. frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
  629. write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
  630. frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
  631. write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
  632. frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
  633. }
  // Advance 16 rows to the next m-block.
  634. c_sh_wr += 16 * (4 * c_sh_stride);
  635. }
  636. }
  // The staging tile must be complete before any thread copies it out.
  637. __syncthreads();
  // Enough passes to cover all 16 * thread_m_blocks rows; rows past prob_m
  // are skipped by the bound check.
  638. #pragma unroll
  639. for (int i = 0;
  640. i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
  641. i++) {
  642. if (c_gl_wr < c_gl_wr_end) {
  643. C[c_gl_wr] = sh[c_sh_rd];
  644. c_gl_wr += c_gl_wr_delta;
  645. c_sh_rd += c_sh_rd_delta;
  646. }
  647. }
  648. };
  // Prime the global->shared fetch pipeline and the register-load pipeline.
  // Steps:
  //   1. Issue fetches for the first (stages - 1) stages, each guarded by the
  //      predicate `stage < slice_iters`.
  //   2. Clear the accumulators.
  //   3. Call wait_for_stage() and load the first register fragment with
  //      fetch_to_registers(0, 0).
  //   4. Move the A read offset forward by one a_gl_rd_delta_o per stage that
  //      was just requested.
  auto start_pipes = [&]() {
#pragma unroll
    for (int stage = 0; stage < stages - 1; ++stage) {
      const bool within_slice = stage < slice_iters;
      fetch_to_shared(stage, stage, within_slice);
    }
    zero_accums();
    wait_for_stage();
    fetch_to_registers(0, 0);
    a_gl_rd += (stages - 1) * a_gl_rd_delta_o;
  };
  659. start_pipes();
  660. // Main loop.
  661. while (slice_iters) {
  662. // We unroll over both the global fetch and the register load pipeline to ensure
  663. // all shared memory accesses are static. Note that both pipelines have even
  664. // length meaning that the next iteration will always start at index 0.
  665. #pragma unroll
  666. for (int pipe = 0; pipe < stages;) {
  667. #pragma unroll
  668. for (int k = 0; k < b_sh_wr_iters; k++) {
  669. fetch_to_registers(k + 1, pipe % stages);
  670. if (k == b_sh_wr_iters - 2) {
  671. fetch_to_shared((pipe + stages - 1) % stages, pipe,
  672. slice_iters >= stages);
  673. pipe++;
  674. wait_for_stage();
  675. }
  676. matmul(k);
  677. }
  678. slice_iters--;
  679. if (slice_iters == 0)
  680. break;
  681. }
  682. a_gl_rd += a_gl_rd_delta_o * stages;
  683. // Process results and, if necessary, proceed to the next column slice.
  684. // While this pattern may not be the most readable, other ways of writing
  685. // the loop seemed to noticeably worse performance after compilation.
  686. if (slice_iters == 0) {
  687. cp_async_wait<0>();
  688. bool last = slice_idx == slice_count - 1;
  689. // For per-column scales, we only fetch them here in the final step before
  690. // write-out
  691. if (group_blocks == -1 && last) {
  692. if (s_sh_wr_pred)
  693. cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
  694. cp_async_fence();
  695. }
  696. thread_block_reduce();
  697. if (group_blocks == -1 && last) {
  698. cp_async_wait<0>();
  699. __syncthreads();
  700. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  701. reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
  702. reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
  703. }
  704. }
  705. if (slice_count > 1) { // only globally reduce if there is more than one
  706. // block in a slice
  707. barrier_acquire(&locks[slice_col], slice_idx);
  708. global_reduce(slice_idx == 0, last);
  709. barrier_release(&locks[slice_col], last);
  710. }
  711. if (last) // only the last block in a slice actually writes the result
  712. write_result();
  713. slice_row = 0;
  714. slice_col_par++;
  715. slice_col++;
  716. init_slice();
  717. if (slice_iters) {
  718. a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  719. (threadIdx.x % a_gl_rd_delta_o);
  720. #pragma unroll
  721. for (int i = 0; i < b_sh_wr_iters; i++)
  722. B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
  723. if (slice_col == 0) {
  724. #pragma unroll
  725. for (int i = 0; i < b_sh_wr_iters; i++)
  726. B_ptr[i] -= b_gl_stride;
  727. }
  728. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  729. start_pipes();
  730. }
  731. }
  732. }
  733. }
  734. #else
  // Fallback definition of the Marlin kernel, compiled in the preprocessor
  // branch where the real implementation above is unavailable. Marlin is not
  // implemented for compute capability below 8.0.
  // NOTE(review): assert() is compiled out under NDEBUG, in which case this
  // kernel simply returns without writing C.
  template <const int threads,          // number of threads in a threadblock
            const int thread_m_blocks,  // number of 16x16 blocks in the m
                                        // dimension (batchsize) of the
                                        // threadblock
            const int thread_n_blocks,  // same for n dimension (output)
            const int thread_k_blocks,  // same for k dimension (reduction)
            const int stages,  // number of stages for the async
                               // global->shared fetch pipeline
            const int group_blocks = -1  // number of consecutive 16x16 blocks
                                         // with a separate quantization scale
            >
  __global__ void Marlin(
      const int4 *__restrict__ A,  // fp16 input matrix of shape mxk
      const int4 *__restrict__ B,  // 4bit quantized weight matrix of shape kxn
      int4 *__restrict__ C,        // fp16 output buffer of shape mxn
      const int4 *__restrict__ s,  // fp16 quantization scales of shape
                                   // (k/groupsize)xn
      int prob_m,                  // batch dimension m
      int prob_n,                  // output dimension n
      int prob_k,                  // reduction dimension k
      int *locks  // extra global storage for barrier synchronization
  ) {
    // Unsupported architecture: fail the device assertion (debug builds).
    assert(false);
  }
  760. #endif
  // Launch-configuration constants shared by the host-side dispatch code.
  //
  // 8 warps (256 threads) are a good choice: every SM has 4 warp schedulers,
  // and having more than one warp per scheduler allows some extra latency
  // hiding. At the same time, we want relatively few warps so that each warp
  // can keep many registers and work on small tiles.
  static constexpr int USER_THREADS =
      256;  // Note: only used with a user-provided thread_k/thread_n
  static constexpr int STAGES = 4;  // 4 pipeline stages fit into shared memory
  // Dynamic shared memory requested per thread block. 96 KiB fits within the
  // per-block limit of compute capability 8.6, which is lower than that of
  // compute capability 8.0.
  static constexpr int SHARED_MEM = 96 * 1024;
  static constexpr int min_thread_n = 64;
  static constexpr int min_thread_k = 64;
  static constexpr int tile_size = 16;
  static constexpr int max_par = 16;
  static constexpr int pack_factor_4bit =
      8;  // 8 4-bit values are packed into each 32-bit int
  // One `else if` arm of the kernel dispatch chain in marlin_cuda.
  //
  // When the runtime configuration (thread_m_blocks, thread_n_blocks,
  // thread_k_blocks, group_blocks, num_threads) equals the macro arguments,
  // this arm:
  //   - raises the dynamic shared-memory limit of that Marlin instantiation
  //     to SHARED_MEM, and
  //   - launches it with `blocks` thread blocks on `stream`.
  //
  // Expects those names, plus A_ptr/B_ptr/C_ptr/s_ptr, prob_m/prob_n/prob_k
  // and locks, to be in scope wherever the macro is expanded.
  #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,        \
                    GROUP_BLOCKS, NUM_THREADS)                               \
    else if (thread_m_blocks == THREAD_M_BLOCKS &&                           \
             thread_n_blocks == THREAD_N_BLOCKS &&                           \
             thread_k_blocks == THREAD_K_BLOCKS &&                           \
             group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {   \
      auto kernel_fn = Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
                              THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>;        \
      cudaFuncSetAttribute(kernel_fn,                                        \
                           cudaFuncAttributeMaxDynamicSharedMemorySize,      \
                           SHARED_MEM);                                      \
      kernel_fn<<<blocks, NUM_THREADS, SHARED_MEM, stream>>>(                \
          A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks);        \
    }
  // Candidate thread-tile shape for a kernel launch. A value of -1 in any
  // field marks an unset / "no config found" entry.
  struct thread_config_t {
    int thread_k;     // size of a thread tile along K
    int thread_n;     // size of a thread tile along N
    int num_threads;  // threads per thread block
  };

  // Candidates for small batches, highest priority first.
  // Each entry is {thread_k, thread_n, num_threads}.
  thread_config_t small_batch_thread_configs[] = {
      {128, 128, 256},  // default
      {128, 64, 128},   // N halved, K unchanged
      {64, 256, 256},   // K halved, N doubled
      {64, 128, 128},   // K halved, N unchanged
  };

  // Candidates for large batches, highest priority first.
  // Each entry is {thread_k, thread_n, num_threads}.
  thread_config_t large_batch_thread_configs[] = {
      {64, 256, 256},   // default
      {128, 128, 256},  // N halved, K doubled
      {64, 128, 128},   // N halved, K unchanged
      {128, 64, 128},   // N quartered, K doubled
  };
  // Returns true when `th_config` can be used for a prob_m x prob_n x prob_k
  // problem. All conditions must hold:
  //   - no field is the -1 sentinel;
  //   - thread_k is 64 or 128;
  //   - thread_n >= min_thread_n and thread_k >= min_thread_k;
  //   - num_threads >= 128;
  //   - prob_k and prob_n are divisible by thread_k and thread_n.
  // The divisibility test runs last. The earlier checks guarantee both divisors
  // are positive, so a user-supplied thread_k/thread_n of 0 (or any other
  // non-positive value) is rejected instead of causing a modulo by zero.
  // prob_m is currently unused.
  bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
                       int prob_k) {
    // Sanity: -1 marks an unset field.
    if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
        th_config.num_threads == -1) {
      return false;
    }

    // thread_k can only be 128 or 64 (it must not exceed the groupsize, which
    // is 128).
    if (th_config.thread_k != 128 && th_config.thread_k != 64) {
      return false;
    }

    // Verify min for thread K/N (this also rules out a zero/negative thread_n).
    if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
      return false;
    }

    // num_threads must be at least 128 (= 4 warps)
    if (th_config.num_threads < 128) {
      return false;
    }

    // Verify K/N are divisible by thread K/N; safe now that both are positive.
    if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
      return false;
    }

    return true;
  }
  // Picks the first entry of a priority-ordered candidate table that
  // is_valid_config accepts for this problem shape:
  //   - small_batch_thread_configs when prob_m <= 16,
  //   - large_batch_thread_configs otherwise.
  // Returns {-1, -1, -1} when no candidate fits.
  thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
    auto first_valid = [&](auto const &candidates) -> thread_config_t {
      for (auto const &candidate : candidates) {
        if (is_valid_config(candidate, prob_m, prob_n, prob_k)) {
          return candidate;
        }
      }
      return thread_config_t{-1, -1, -1};
    };

    return prob_m <= 16 ? first_valid(small_batch_thread_configs)
                        : first_valid(large_batch_thread_configs);
  }
  // Dispatch arms for one (N_BLOCKS, K_BLOCKS, NUM_THREADS) thread-tile shape.
  // Covers thread_m_blocks 1-4, each with both scale layouts:
  //   - group_blocks == -1: per-column scales;
  //   - group_blocks == 8: groupsize / 16 == 8, i.e. groupsize 128.
  // Each (thread_m_blocks, group_blocks) pair is listed exactly once. A repeated
  // pair would only add an unreachable `else if` branch, because the first
  // matching arm always wins.
  #define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS)    \
    __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
    __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
    __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
    __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
    __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
    __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
    __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
    __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
  // Host-side launcher for the Marlin fp16 x 4-bit GEMM: C = A * dequant(B, s).
  //
  // M is processed in chunks. Up to 4 remaining 16-row tiles are handled by one
  // launch. Larger remainders are handled `par` 64-row chunks per launch, with
  // par <= max_par. For each chunk, a compile-time kernel instantiation is
  // selected through CALL_IF and launched with `sms` thread blocks on `stream`.
  //
  // Parameters:
  //   A          device pointer to the fp16 input, shape prob_m x prob_k
  //   B          device pointer to the packed 4-bit weights
  //   C          device pointer to the fp16 output, shape prob_m x prob_n
  //   s          device pointer to the fp16 quantization scales
  //   workspace  int buffer passed to the kernel as its per-column-slice locks
  //   groupsize  quantization group size along K, or -1 for per-column scales
  //   dev        device whose SM count is queried when sms == -1
  //   stream     stream on which every kernel is launched
  //   thread_k, thread_n
  //              thread-tile shape; used only when both are != -1, otherwise
  //              chosen by determine_thread_config
  //   sms        thread blocks per launch (-1 = SM count of `dev`)
  //   max_par    maximum number of 64-row chunks handled by a single launch
  //
  // Throws std::runtime_error for an invalid/unsupported configuration or a
  // CUDA runtime error. Throws c10::Error (via TORCH_CHECK) for shapes not
  // divisible by the tile or group size.
  void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
                   int prob_n, int prob_k, void *workspace, int groupsize = -1,
                   int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
                   int thread_n = -1, int sms = -1, int max_par = 16) {
    int tot_m = prob_m;
    int tot_m_blocks = ceildiv(tot_m, 16);
    // Number of rows by which M falls short of a whole number of 16-row tiles.
    int pad = 16 * tot_m_blocks - tot_m;

    if (sms == -1) {
      cudaError_t attr_err =
          cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
      if (attr_err != cudaSuccess) {
        throw std::runtime_error(
            std::string("Failed to query the SM count of device ") + str(dev) +
            ": " + cudaGetErrorString(attr_err));
      }
    }

    // Set thread config
    thread_config_t th_config;
    if (thread_k != -1 && thread_n != -1) {
      // User-defined config
      th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
    } else {
      // Auto config
      th_config = determine_thread_config(prob_m, prob_n, prob_k);
    }

    if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
      throw std::runtime_error(
          "Invalid thread config: thread_k = " + str(th_config.thread_k) +
          ", thread_n = " + str(th_config.thread_n) +
          ", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
          str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
    }

    // Uncomment for debug
    // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) +
    //                  ", thread_n = " + str(th_config.thread_n) +
    //                  ", num_threads = " + str(th_config.num_threads) +
    //                  " for MKN = [" + str(prob_m) + ", " + str(prob_k) +
    //                  ", " + str(prob_n) + "]\n";

    int num_threads = th_config.num_threads;
    thread_k = th_config.thread_k;
    thread_n = th_config.thread_n;

    int thread_k_blocks = thread_k / 16;
    int thread_n_blocks = thread_n / 16;
    int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
    int blocks = sms;

    if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
      return;
    }

    TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
                " is not divisible by thread_n = ", thread_n);
    TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
                " is not divisible by thread_k = ", thread_k);
    if (groupsize != -1) {
      // The kernel works in 16-row tiles (group_blocks = groupsize / 16), so
      // the group size must be a positive multiple of 16. Otherwise it would
      // be silently truncated, or would divide by zero below.
      TORCH_CHECK(groupsize > 0 && groupsize % 16 == 0, "groupsize = ",
                  groupsize, " must be -1 or a positive multiple of 16");
      // Compare prob_k against the group size in elements (not in 16-row
      // tiles): every group along K must be complete.
      TORCH_CHECK(prob_k % groupsize == 0, "prob_k = ", prob_k,
                  " is not divisible by groupsize = ", groupsize);
    }

    const int4 *A_ptr = (const int4 *)A;
    const int4 *B_ptr = (const int4 *)B;
    int4 *C_ptr = (int4 *)C;
    const int4 *s_ptr = (const int4 *)s;

    int *locks = (int *)workspace;

    for (int i = 0; i < tot_m_blocks; i += 4) {
      int thread_m_blocks = tot_m_blocks - i;
      prob_m = tot_m - 16 * i;
      int par = 1;
      if (thread_m_blocks > 4) {
        // Note that parallel > 1 currently only works for inputs without any
        // padding
        par = (16 * thread_m_blocks - pad) / 64;
        if (par > max_par)
          par = max_par;
        prob_m = 64 * par;
        i += 4 * (par - 1);
        thread_m_blocks = 4;
      }

      // For compilation speed, we only define the kernel configurations that
      // have seemed useful (in terms of performance) in our testing, however
      // many more are, in principle, possible.
      if (false) {
      }
      CALL_IF(8, 8, 256)
      CALL_IF(16, 4, 256)
      CALL_IF(8, 4, 128)
      CALL_IF(4, 8, 128)
      else {
        throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                                 ", " + str(prob_k) + ", " + str(prob_n) + "]" +
                                 ", groupsize = " + str(groupsize) +
                                 ", thread_m_blocks = " + str(thread_m_blocks) +
                                 ", thread_n_blocks = " + str(thread_n_blocks) +
                                 ", thread_k_blocks = " + str(thread_k_blocks));
      }

      // Kernel launches return no status, and the cudaFuncSetAttribute result
      // is not checked inside __CALL_IF. Query the runtime's last error so a
      // failed launch is reported instead of leaving C unwritten.
      // NOTE(review): this also surfaces an error left pending by an earlier,
      // unrelated runtime call on this host thread.
      cudaError_t launch_err = cudaGetLastError();
      if (launch_err != cudaSuccess) {
        throw std::runtime_error(std::string("Marlin kernel launch failed: ") +
                                 cudaGetErrorString(launch_err));
      }

      A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
      C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
    }
  }
  952. } // namespace marlin
  // PyTorch entry point for the Marlin fp16 x 4-bit GEMM:
  // returns c = a @ dequant(b_q_weight, b_scales).
  //
  // Expected inputs (all contiguous CUDA tensors):
  //   a           float16, shape (size_m, size_k)
  //   b_q_weight  int32 packed 4-bit weights, shape
  //               (size_k / tile_size, (size_n / pack_factor_4bit) * tile_size)
  //   b_scales    float16, shape (num_groups, size_n); num_groups == 1 selects
  //               per-column scales, otherwise size_k / num_groups must be 128
  //   workspace   int32, at least (size_n / min_thread_n) * max_par elements;
  //               passed to the kernel as its lock array
  // Returns a new (size_m, size_n) tensor with a's dtype on a's device.
  // Throws c10::Error (TORCH_CHECK) on any violated precondition.
  torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                            torch::Tensor &b_scales, torch::Tensor &workspace,
                            int64_t size_m, int64_t size_n, int64_t size_k) {
    // Verify M
    TORCH_CHECK(size_m == a.size(0),
                "Shape mismatch: a.size(0) = " + str(a.size(0)) +
                    ", size_m = " + str(size_m));

    // Verify K
    TORCH_CHECK(size_k == a.size(1),
                "Shape mismatch: a.size(1) = " + str(a.size(1)) +
                    ", size_k = " + str(size_k));
    TORCH_CHECK(size_k % marlin::tile_size == 0,
                "size_k = " + str(size_k) +
                    " is not divisible by tile_size = " + str(marlin::tile_size));
    TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
                "Shape mismatch: b_q_weight.size(0) = " +
                    str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
                    ", tile_size = " + str(marlin::tile_size));

    // Verify N
    TORCH_CHECK(b_scales.size(1) == size_n,
                "b_scales.size(1) = " + str(b_scales.size(1)) +
                    ", size_n = " + str(size_n));
    TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
                "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
                    " is not divisible by tile_size = " + str(marlin::tile_size));

    int actual_size_n =
        (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
    TORCH_CHECK(size_n == actual_size_n,
                "size_n = " + str(size_n) +
                    ", actual_size_n = " + str(actual_size_n));

    // Verify A device and strides
    TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
    TORCH_CHECK(a.is_contiguous(), "A is not contiguous");

    // Verify B device and strides
    TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
    TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

    // Verify scales device and strides
    TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
    TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

    // Verify dtypes. The raw pointers are reinterpreted by the kernel:
    //   - A, C and the scales are read/written as fp16;
    //   - the weights are read as 32-bit words holding 8 4-bit values;
    //   - the workspace is read as int locks.
    // Any other dtype would be silently misread.
    TORCH_CHECK(a.scalar_type() == at::kHalf, "A must be float16, got ",
                a.scalar_type());
    TORCH_CHECK(b_q_weight.scalar_type() == at::kInt,
                "b_q_weight must be int32, got ", b_q_weight.scalar_type());
    TORCH_CHECK(b_scales.scalar_type() == at::kHalf,
                "b_scales must be float16, got ", b_scales.scalar_type());

    // Alloc C matrix
    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
    auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
    torch::Tensor c = torch::empty({size_m, size_n}, options);

    // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
    // auto -1)
    int thread_k = -1;
    // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
    // auto -1)
    int thread_n = -1;
    // sms: number of SMs to use for the kernel (can usually be left as auto -1)
    int sms = -1;

    // Detect groupsize: a single row of scales means per-column quantization;
    // otherwise each row of scales covers size_k / b_scales.size(0) rows of K.
    // An empty scale tensor is rejected first, so the modulo below cannot
    // divide by zero.
    TORCH_CHECK(b_scales.size(0) > 0,
                "b_scales.size(0) = " + str(b_scales.size(0)) +
                    " must be positive");
    if (b_scales.size(0) != 1) {
      TORCH_CHECK(size_k % b_scales.size(0) == 0,
                  "size_k = " + str(size_k) +
                      ", is not divisible by b_scales.size(0) = " +
                      str(b_scales.size(0)));
    }
    int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0);

    // Verify groupsize
    TORCH_CHECK(groupsize == -1 || groupsize == 128,
                "Unexpected groupsize = " + str(groupsize));

    // Verify workspace device, dtype and size
    TORCH_CHECK(workspace.device().is_cuda(), "workspace is not on GPU");
    TORCH_CHECK(workspace.scalar_type() == at::kInt,
                "workspace must be int32, got ", workspace.scalar_type());
    TORCH_CHECK(
        size_n % marlin::min_thread_n == 0,
        "size_n = " + str(size_n) +
            ", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
    int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
    TORCH_CHECK(workspace.numel() >= min_workspace_size,
                "workspace.numel = " + str(workspace.numel()) +
                    " is below min_workspace_size = " + str(min_workspace_size));

    int dev = a.get_device();
    marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
                        b_scales.data_ptr(), size_m, size_n, size_k,
                        workspace.data_ptr(), groupsize, dev,
                        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
                        sms, marlin::max_par);

    return c;
  }