marlin_cuda_kernel.cu 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138
  1. /*
  2. * Modified by Neural Magic
  3. * Copyright (C) Marlin.2024 Elias Frantar
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #include <torch/all.h>
  18. #include <ATen/cuda/CUDAContext.h>
  19. #include <c10/cuda/CUDAGuard.h>
  20. #include <cuda.h>
  21. #include <cuda_fp16.h>
  22. #include <cuda_runtime.h>
  23. #include <iostream>
  // Convert a value to its decimal string form via std::to_string.
  template <typename T>
  inline std::string str(T x) {
    const std::string text = std::to_string(x);
    return text;
  }
  28. namespace marlin_dense {
  // Integer ceiling division: returns ceil(a / b) for any signs of `a` and `b`
  // (precondition: b != 0). Computed as the truncating quotient plus one when
  // a non-zero remainder means the exact quotient is positive and was rounded
  // down. Unlike `(a + b - 1) / b`, this cannot overflow for `a` near INT_MAX
  // and rounds correctly for negative operands; for a >= 0, b > 0 (all current
  // callers) the results are unchanged. It stays constexpr because it is used
  // both to size compile-time tile constants and at runtime in the kernel.
  constexpr int ceildiv(int a, int b) {
    return a / b + (((a % b) != 0 && ((a < 0) == (b < 0))) ? 1 : 0);
  }
  30. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // `Vec` groups a fixed number of values that are meant to live in
  // >>registers<<, e.g. the operands of tensor core instructions. For the
  // elements to stay in registers, every index used with a `Vec` must be a
  // compile-time constant, which is why the kernel code relies heavily on
  // `#pragma unroll`.
  // The struct holds nothing but the `elems` array, so its size is exactly
  // n * sizeof(T); the kernel depends on this when it reinterpret_casts
  // fragments to and from raw 32-bit / int4 words.
  template <typename T, int n>
  struct Vec {
    T elems[n];
    __device__ T& operator[](int i) { return elems[i]; }
    // Read-only access so that a `const Vec&` can be indexed too.
    __device__ const T& operator[](int i) const { return elems[i]; }
  };
  using I4 = Vec<int, 4>;
  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;  // quantization scales
  // Predicated asynchronous 16-byte global->shared copy (cp.async.cg). Used for
  // the A operand, where predication handles batch sizes that are not
  // multiples of 16. When `pred` is false no copy is issued at all, so the
  // shared-memory destination keeps whatever it held before.
  __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                        bool pred = true) {
    constexpr int kCopyBytes = 16;
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    const int guard = static_cast<int>(pred);
    asm volatile(
        "{\n"
        " .reg .pred p;\n"
        " setp.ne.b32 p, %0, 0;\n"
        " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
        "}\n" ::"r"(guard),
        "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
  }
  // Unpredicated asynchronous 16-byte global->shared copy (cp.async.cg). The
  // copy only becomes part of a trackable group once cp_async_fence() commits
  // it.
  __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
    constexpr int kCopyBytes = 16;
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
        "{\n"
        " cp.async.cg.shared.global [%0], [%1], %2;\n"
        "}\n" ::"r"(smem_addr),
        "l"(glob_ptr), "n"(kCopyBytes));
  }
  // Commit all cp.async copies issued so far by this thread as one group, so
  // that cp_async_wait<n>() can later wait on it.
  __device__ inline void cp_async_fence() {
    // No operands: the instruction only closes the current copy group.
    asm volatile("cp.async.commit_group;\n" : :);
  }
  // Block until at most `kPending` of this thread's most recently committed
  // cp.async groups are still in flight; all older groups are then complete.
  template <int kPending>
  __device__ inline void cp_async_wait() {
    asm volatile("cp.async.wait_group %0;\n" : : "n"(kPending));
  }
  // m16n8k16 tensor core MMA with fp16 inputs and fp32 accumulation:
  // frag_c += a_frag * frag_b. The same four accumulator registers are used as
  // both the C input and the D output of the instruction.
  __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                             FragC& frag_c) {
    const uint32_t* const a_regs = reinterpret_cast<const uint32_t*>(&a_frag);
    const uint32_t* const b_regs = reinterpret_cast<const uint32_t*>(&frag_b);
    float* const acc = reinterpret_cast<float*>(&frag_c);
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
          "r"(b_regs[0]), "r"(b_regs[1]),
          "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]));
  }
  // Load a full 16x16 fp16 fragment of operand A from shared memory directly
  // into tensor core layout with a single ldmatrix (.x4, i.e. four 8x8
  // matrices, one 32-bit register each).
  __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
    uint32_t* const regs = reinterpret_cast<uint32_t*>(&frag_a);
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
        : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3])
        : "r"(smem_addr));
  }
  // Arbitrary 3-input bitwise function selected by the 8-bit truth table
  // `kLut` (PTX lop3.b32). Emitted explicitly because the compiler does not
  // reliably fuse such expressions into a LOP3 on its own during
  // dequantization.
  template <int kLut>
  __device__ inline int lop3(int a, int b, int c) {
    int out;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(out)
                 : "r"(a), "r"(b), "r"(c), "n"(kLut));
    return out;
  }
  // Dequantize four 4-bit weights packed in `q` into a B fragment of four fp16
  // values (two half2). Nibbles at bits 0-3 and 16-19 go to element 0; nibbles
  // at bits 4-7 and 20-23 go to element 1. Each nibble v is mapped to the
  // signed value v - 8. The approach follows FasterTransformer's interleaved
  // numeric conversion, with small changes:
  // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
  __device__ inline FragB dequant(int q) {
    // Truth table for `(a & b) | c` (a -> 0xf0, b -> 0xcc, c -> 0xaa), so each
    // mask-and-bias step is guaranteed to be a single LOP3.
    constexpr int kAndThenOr = (0xf0 & 0xcc) | 0xaa;
    const int kLowNibbles = 0x000f000f;   // bits 0-3 of each 16-bit half
    const int kHighNibbles = 0x00f000f0;  // bits 4-7 of each 16-bit half
    const int kExponent = 0x64006400;     // fp16 1024.0 in both halves
    // OR-ing the masked bits into the mantissa of 1024.0 yields 1024 + v for
    // a low nibble and 1024 + 16 * v for a high nibble.
    int lo_bits = lop3<kAndThenOr>(q, kLowNibbles, kExponent);
    int hi_bits = lop3<kAndThenOr>(q, kHighNibbles, kExponent);
    // The symmetric zero point of 8 is folded into the constants below:
    //   low:  (1024 + v) - 1032         = v - 8
    //   high: (1024 + 16 * v) / 16 - 72 = v - 8
    const int kSub = 0x64086408;  // fp16 1032.0 in both halves
    const int kMul = 0x2c002c00;  // fp16 1/16 in both halves
    const int kAdd = 0xd480d480;  // fp16 -72.0 in both halves
    FragB out;
    out[0] = __hsub2(*reinterpret_cast<half2*>(&lo_bits),
                     *reinterpret_cast<const half2*>(&kSub));
    out[1] = __hfma2(*reinterpret_cast<half2*>(&hi_bits),
                     *reinterpret_cast<const half2*>(&kMul),
                     *reinterpret_cast<const half2*>(&kAdd));
    return out;
  }
  // Multiply both half2 entries of a dequantized B fragment by one fp16 scale:
  // the `i`-th half stored in `frag_s`. Only used for grouped quantization;
  // with per-column scales the kernel applies the scale when writing the
  // output instead.
  __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
    const __half s_half = reinterpret_cast<__half*>(&frag_s)[i];
    const half2 s_pair = __half2half2(s_half);
    #pragma unroll
    for (int e = 0; e < 2; ++e) frag_b[e] = __hmul2(frag_b[e], s_pair);
  }
  // Spin until the global counter at `lock` equals `count`, then let the whole
  // threadblock continue. Only thread 0 polls; the trailing __syncthreads()
  // holds every other thread of the block until that poll has succeeded. The
  // counter itself is not modified here.
  __device__ inline void barrier_acquire(int* lock, int count) {
    if (threadIdx.x == 0) {
      for (;;) {
        int observed;
        // GPU-scope acquire load: later memory accesses by this thread cannot
        // be reordered before it, pairing with the fence that
        // barrier_release issues before incrementing the counter.
        asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                     : "=r"(observed)
                     : "l"(lock));
        if (observed == count) break;
      }
    }
    __syncthreads();
  }
  // Release the barrier at `lock` once every thread of the block has arrived.
  // Thread 0 then either atomically increments the counter by one (the normal
  // case), or, when `reset` is true, stores 0 to it with a plain store and no
  // fence.
  __device__ inline void barrier_release(int* lock, bool reset = false) {
    __syncthreads();
    if (threadIdx.x != 0) return;
    if (reset) {
      lock[0] = 0;
      return;
    }
    const int increment = 1;
    // Make every write since the barrier was acquired globally visible before
    // bumping the counter that other threadblocks poll.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(increment));
  }
  178. template <const int threads, // number of threads in a threadblock
  179. const int thread_m_blocks, // number of 16x16 blocks in the m
  180. // dimension (batchsize) of the
  181. // threadblock
  182. const int thread_n_blocks, // same for n dimension (output)
  183. const int thread_k_blocks, // same for k dimension (reduction)
  184. const int stages, // number of stages for the async global->shared
  185. // fetch pipeline
  186. const int group_blocks = -1 // number of consecutive 16x16 blocks
  187. // with a separate quantization scale
  188. >
  189. __global__ void Marlin(
  190. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  191. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  192. int4* __restrict__ C, // fp16 output buffer of shape mxn
  193. const int4* __restrict__ s, // fp16 quantization scales of shape
  194. // (k/groupsize)xn
  195. int prob_m, // batch dimension m
  196. int prob_n, // output dimension n
  197. int prob_k, // reduction dimension k
  198. int* locks // extra global storage for barrier synchronization
  199. ) {
  200. // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  201. // same size, which might involve multiple column "slices" (of width 16 *
  202. // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
  203. // example:
  204. // 0 1 3
  205. // 0 2 3
  206. // 1 2 4
  207. // While this kind of partitioning makes things somewhat more complicated, it
  208. // ensures good utilization of all SMs for many kinds of shape and GPU
  209. // configurations, while requiring as few slow global cross-threadblock
  210. // reductions as possible.
  211. // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  212. // better partitioning with less reductions
  213. int parallel = 1;
  214. if (prob_m > 16 * thread_m_blocks) {
  215. parallel = prob_m / (16 * thread_m_blocks);
  216. prob_m = 16 * thread_m_blocks;
  217. }
  218. int k_tiles = prob_k / 16 / thread_k_blocks;
  219. int n_tiles = prob_n / 16 / thread_n_blocks;
  220. int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
  221. // Ensure that the number of tiles in each stripe is a multiple of the
  222. // groupsize; this avoids an annoying special case where a stripe starts in
  223. // the middle of group.
  224. if (group_blocks != -1)
  225. iters = (group_blocks / thread_k_blocks) *
  226. ceildiv(iters, (group_blocks / thread_k_blocks));
  227. int slice_row = (iters * blockIdx.x) % k_tiles;
  228. int slice_col_par = (iters * blockIdx.x) / k_tiles;
  229. int slice_col = slice_col_par;
  230. int slice_iters; // number of threadblock tiles in the current slice
  231. int slice_count =
  232. 0; // total number of active threadblocks in the current slice
  233. int slice_idx; // index of threadblock in current slice; numbered bottom to
  234. // top
  235. // We can easily implement parallel problem execution by just remapping
  236. // indices and advancing global pointers
  237. if (slice_col_par >= n_tiles) {
  238. A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
  239. C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
  240. locks += (slice_col_par / n_tiles) * n_tiles;
  241. slice_col = slice_col_par % n_tiles;
  242. }
  243. // Compute all information about the current slice which is required for
  244. // synchronization.
  245. auto init_slice = [&]() {
  246. slice_iters =
  247. iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
  248. if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
  249. if (slice_iters == 0) return;
  250. if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
  251. slice_count = 1;
  252. slice_idx = 0;
  253. int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
  254. if (col_first <= k_tiles * (slice_col_par + 1)) {
  255. int col_off = col_first - k_tiles * slice_col_par;
  256. slice_count = ceildiv(k_tiles - col_off, iters);
  257. if (col_off > 0) slice_count++;
  258. int delta_first = iters * blockIdx.x - col_first;
  259. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  260. slice_idx = slice_count - 1;
  261. else {
  262. slice_idx = slice_count - 1 - delta_first / iters;
  263. if (col_off > 0) slice_idx--;
  264. }
  265. }
  266. if (slice_col == n_tiles) {
  267. A += 16 * thread_m_blocks * prob_k / 8;
  268. C += 16 * thread_m_blocks * prob_n / 8;
  269. locks += n_tiles;
  270. slice_col = 0;
  271. }
  272. };
  273. init_slice();
  274. int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
  275. // We typically use `constexpr` to indicate that this value is a compile-time
  276. // constant
  277. constexpr int a_sh_stride =
  278. 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
  279. constexpr int a_gl_rd_delta_o =
  280. 16 * thread_k_blocks /
  281. 8; // delta between subsequent A tiles in global memory
  282. int a_gl_rd_delta_i =
  283. a_gl_stride *
  284. (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
  285. constexpr int a_sh_wr_delta =
  286. a_sh_stride *
  287. (threads / a_gl_rd_delta_o); // between shared memory writes
  288. constexpr int a_sh_rd_delta_o =
  289. 2 * ((threads / 32) /
  290. (thread_n_blocks / 4)); // between shared memory tile reads
  291. constexpr int a_sh_rd_delta_i =
  292. a_sh_stride * 16; // within a shared memory tile
  293. constexpr int a_sh_stage =
  294. a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
  295. constexpr int a_sh_wr_iters =
  296. ceildiv(a_sh_stage,
  297. a_sh_wr_delta); // number of shared write iterations for a tile
  298. int b_gl_stride = 16 * prob_n / 32;
  299. constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
  300. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  301. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
  302. constexpr int b_sh_wr_delta = threads;
  303. constexpr int b_sh_rd_delta = threads;
  304. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  305. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
  306. int s_gl_stride = prob_n / 8;
  307. constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  308. constexpr int s_sh_stage = s_sh_stride;
  309. int s_gl_rd_delta = s_gl_stride;
  310. // Global A read index of current thread.
  311. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  312. (threadIdx.x % a_gl_rd_delta_o);
  313. a_gl_rd += a_gl_rd_delta_o * slice_row;
  314. // Shared write index of current thread.
  315. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
  316. (threadIdx.x % a_gl_rd_delta_o);
  317. // Shared read index.
  318. int a_sh_rd =
  319. a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  320. a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  321. int b_gl_rd =
  322. b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  323. b_gl_rd += b_sh_stride * slice_col;
  324. b_gl_rd += b_gl_rd_delta_o * slice_row;
  325. int b_sh_wr = threadIdx.x;
  326. int b_sh_rd = threadIdx.x;
  327. int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
  328. s_sh_stride * slice_col + threadIdx.x;
  329. int s_sh_wr = threadIdx.x;
  330. int s_sh_rd;
  331. // We use a different scale layout for grouped and column-wise quantization as
  332. // we scale a `half2` tile in column-major layout in the former and in
  333. // row-major in the latter case.
  334. if (group_blocks != -1)
  335. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  336. (threadIdx.x % 32) / 4;
  337. else
  338. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  339. (threadIdx.x % 32) % 4;
  340. // Precompute which thread should not read memory in which iterations; this is
  341. // needed if there are more threads than required for a certain tilesize or
  342. // when the batchsize is not a multiple of 16.
  343. bool a_sh_wr_pred[a_sh_wr_iters];
  344. #pragma unroll
  345. for (int i = 0; i < a_sh_wr_iters; i++)
  346. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  347. bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
  348. // To ensure that writing and reading A tiles to/from shared memory, the
  349. // latter in fragment format, is fully bank conflict free, we need to use a
  350. // rather fancy XOR-based layout. The key here is that neither reads nor
  351. // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  352. // same shared memory banks. Further, it seems (based on NSight-Compute) that
  353. // each warp must also write a consecutive memory segment?
  354. auto transform_a = [&](int i) {
  355. int row = i / a_gl_rd_delta_o;
  356. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  357. };
  358. // Since the computation of this remapping is non-trivial and, due to our main
  359. // loop unrolls, all shared memory accesses are static, we simply precompute
  360. // both transformed reads and writes.
  361. int a_sh_wr_trans[a_sh_wr_iters];
  362. #pragma unroll
  363. for (int i = 0; i < a_sh_wr_iters; i++)
  364. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  365. int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  366. #pragma unroll
  367. for (int i = 0; i < b_sh_wr_iters; i++) {
  368. #pragma unroll
  369. for (int j = 0; j < thread_m_blocks; j++)
  370. a_sh_rd_trans[i][j] =
  371. transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  372. }
  373. // Since B-accesses have non-constant stride they have to be computed at
  374. // runtime; we break dependencies between subsequent accesses with a tile by
  375. // maintining multiple pointers (we have enough registers), a tiny
  376. // optimization.
  377. const int4* B_ptr[b_sh_wr_iters];
  378. #pragma unroll
  379. for (int i = 0; i < b_sh_wr_iters; i++)
  380. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  381. extern __shared__ int4 sh[];
  382. // Shared memory storage for global fetch pipelines.
  383. int4* sh_a = sh;
  384. int4* sh_b = sh_a + (stages * a_sh_stage);
  385. int4* sh_s = sh_b + (stages * b_sh_stage);
  386. // Register storage for double buffer of shared memory reads.
  387. FragA frag_a[2][thread_m_blocks];
  388. I4 frag_b_quant[2];
  389. FragC frag_c[thread_m_blocks][4][2];
  390. FragS frag_s[2][4];
  391. // Zero accumulators.
  392. auto zero_accums = [&]() {
  393. #pragma unroll
  394. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  395. reinterpret_cast<float*>(frag_c)[i] = 0;
  396. };
  397. // Asynchronously fetch the next A, B and s tile from global to the next
  398. // shared memory pipeline location.
  399. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
  400. if (pred) {
  401. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  402. #pragma unroll
  403. for (int i = 0; i < a_sh_wr_iters; i++) {
  404. cp_async4_pred(
  405. &sh_a_stage[a_sh_wr_trans[i]],
  406. &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
  407. a_sh_wr_pred[i]);
  408. }
  409. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  410. #pragma unroll
  411. for (int i = 0; i < b_sh_wr_iters; i++) {
  412. cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
  413. B_ptr[i] += b_gl_rd_delta_o;
  414. }
  415. // Only fetch scales if this tile starts a new group
  416. if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
  417. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  418. if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
  419. s_gl_rd += s_gl_rd_delta;
  420. }
  421. }
  422. // Insert a fence even when we are winding down the pipeline to ensure that
  423. // waiting is also correct at this point.
  424. cp_async_fence();
  425. };
  426. // Wait until the next thread tile has been loaded to shared memory.
  427. auto wait_for_stage = [&]() {
  428. // We only have `stages - 2` active fetches since we are double buffering
  429. // and can only issue the next fetch when it is guaranteed that the previous
  430. // shared memory load is fully complete (as it may otherwise be
  431. // overwritten).
  432. cp_async_wait<stages - 2>();
  433. __syncthreads();
  434. };
  435. // Load the next sub-tile from the current location in the shared memory pipe
  436. // into the current register buffer.
  437. auto fetch_to_registers = [&](int k, int pipe) {
  438. // It may seem inefficient that we reload the groups for every sub-tile;
  439. // however, this does not seem to be a significant bottleneck, while some
  440. // theoretically better attempts have lead to bad instruction ordering by
  441. // the compiler and correspondingly a noticeable drop in performance.
  442. if (group_blocks != -1) {
  443. int4* sh_s_stage =
  444. sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
  445. (pipe / (group_blocks / thread_k_blocks)));
  446. reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
  447. }
  448. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  449. #pragma unroll
  450. for (int i = 0; i < thread_m_blocks; i++)
  451. ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
  452. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  453. frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
  454. &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
  455. };
  456. // Execute the actual tensor core matmul of a sub-tile.
  457. auto matmul = [&](int k) {
  458. // We have the m dimension as the inner loop in order to encourage overlapping
  459. // dequantization and matmul operations.
  460. #pragma unroll
  461. for (int j = 0; j < 4; j++) {
  462. int b_quant = frag_b_quant[k % 2][j];
  463. int b_quant_shift = b_quant >> 8;
  464. FragB frag_b0 = dequant(b_quant);
  465. // If there are no groups, we can just scale the final output once and can
  466. // avoid doing so for each weight.
  467. if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
  468. FragB frag_b1 = dequant(b_quant_shift);
  469. if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
  470. #pragma unroll
  471. for (int i = 0; i < thread_m_blocks; i++) {
  472. mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
  473. mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
  474. }
  475. }
  476. };
  477. // Since we slice across the k dimension of a tile in order to increase the
  478. // number of warps while keeping the n dimension of a tile reasonable, we have
  479. // multiple warps that accumulate their partial sums of the same output
  480. // location; which we have to reduce over in the end. We do in shared memory.
  481. auto thread_block_reduce = [&]() {
  482. constexpr int red_off = threads / b_sh_stride / 2;
  483. if (red_off >= 1) {
  484. int red_idx = threadIdx.x / b_sh_stride;
  485. constexpr int red_sh_stride = b_sh_stride * 4 * 2;
  486. constexpr int red_sh_delta = b_sh_stride;
  487. int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
  488. (threadIdx.x % b_sh_stride);
  489. // Parallel logarithmic shared memory reduction. We make sure to avoid any
  490. // unnecessary read or write iterations, e.g., for two warps we write only
  491. // once by warp 1 and read only once by warp 0.
  492. #pragma unroll
  493. for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  494. #pragma unroll
  495. for (int i = red_off; i > 0; i /= 2) {
  496. if (i <= red_idx && red_idx < 2 * i) {
  497. #pragma unroll
  498. for (int j = 0; j < 4 * 2; j++) {
  499. int red_sh_wr =
  500. red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
  501. if (i < red_off) {
  502. float* c_rd =
  503. reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
  504. float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  505. #pragma unroll
  506. for (int k = 0; k < 4; k++)
  507. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
  508. c_rd[k] + c_wr[k];
  509. }
  510. sh[red_sh_wr] =
  511. reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
  512. }
  513. }
  514. __syncthreads();
  515. }
  516. if (red_idx == 0) {
  517. #pragma unroll
  518. for (int i = 0; i < 4 * 2; i++) {
  519. float* c_rd =
  520. reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  521. #pragma unroll
  522. for (int j = 0; j < 4; j++)
  523. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
  524. c_rd[j];
  525. }
  526. }
  527. __syncthreads();
  528. }
  529. }
  530. };
  531. // Since multiple threadblocks may process parts of the same column slice, we
  532. // finally have to globally reduce over the results. As the striped
  533. // partitioning minimizes the number of such reductions and our outputs are
  534. // usually rather small, we perform this reduction serially in L2 cache.
  535. auto global_reduce = [&](bool first = false, bool last = false) {
  536. // We are very careful here to reduce directly in the output buffer to
  537. // maximize L2 cache utilization in this step. To do this, we write out
  538. // results in FP16 (but still reduce with FP32 compute).
  539. constexpr int active_threads = 32 * thread_n_blocks / 4;
  540. if (threadIdx.x < active_threads) {
  541. int c_gl_stride = prob_n / 8;
  542. int c_gl_wr_delta_o = 8 * c_gl_stride;
  543. int c_gl_wr_delta_i = 4 * (active_threads / 32);
  544. int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
  545. 4 * (threadIdx.x / 32) + threadIdx.x % 4;
  546. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  547. constexpr int c_sh_wr_delta = active_threads;
  548. int c_sh_wr = threadIdx.x;
  549. int row = (threadIdx.x % 32) / 4;
  550. if (!first) {
  551. // Interestingly, doing direct global accesses here really seems to mess up
  552. // the compiler and lead to slowdowns, hence we also use async-copies even
  553. // though these fetches are not actually asynchronous.
  554. #pragma unroll
  555. for (int i = 0; i < thread_m_blocks * 4; i++) {
  556. cp_async4_pred(
  557. &sh[c_sh_wr + c_sh_wr_delta * i],
  558. &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
  559. c_gl_wr_delta_i * (i % 2)],
  560. i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
  561. }
  562. cp_async_fence();
  563. cp_async_wait<0>();
  564. }
  565. #pragma unroll
  566. for (int i = 0; i < thread_m_blocks * 4; i++) {
  567. if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
  568. if (!first) {
  569. int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  570. #pragma unroll
  571. for (int j = 0; j < 2 * 4; j++) {
  572. reinterpret_cast<float*>(
  573. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
  574. __half2float(reinterpret_cast<__half*>(&c_red)[j]);
  575. }
  576. }
  577. if (!last) {
  578. int4 c;
  579. #pragma unroll
  580. for (int j = 0; j < 2 * 4; j++) {
  581. reinterpret_cast<__half*>(&c)[j] =
  582. __float2half(reinterpret_cast<float*>(
  583. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
  584. }
  585. C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
  586. c;
  587. }
  588. }
  589. }
  590. }
  591. };
  592. // Write out the reduce final result in the correct layout. We only actually
  593. // reshuffle matrix fragments in this step, the reduction above is performed
  594. // in fragment layout.
  595. auto write_result = [&]() {
  596. int c_gl_stride = prob_n / 8;
  597. constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
  598. int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
  599. constexpr int c_sh_rd_delta =
  600. c_sh_stride * (threads / (2 * thread_n_blocks));
  601. int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  602. (threadIdx.x % (2 * thread_n_blocks));
  603. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  604. int c_sh_wr =
  605. (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
  606. c_sh_wr += 32 * (threadIdx.x / 32);
  607. int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  608. (threadIdx.x % (2 * thread_n_blocks));
  609. int c_gl_wr_end = c_gl_stride * prob_m;
  610. // We first reorder in shared memory to guarantee the most efficient final
  611. // global write patterns
  612. auto write = [&](int idx, float c0, float c1, FragS& s) {
  613. half2 res = __halves2half2(__float2half(c0), __float2half(c1));
  614. if (group_blocks ==
  615. -1) // for per-column quantization we finally apply the scale here
  616. res = __hmul2(res, s[0]);
  617. ((half2*)sh)[idx] = res;
  618. };
  619. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  620. #pragma unroll
  621. for (int i = 0; i < thread_m_blocks; i++) {
  622. #pragma unroll
  623. for (int j = 0; j < 4; j++) {
  624. int wr = c_sh_wr + 8 * j;
  625. write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
  626. frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
  627. write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
  628. frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
  629. write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
  630. frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
  631. write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
  632. frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
  633. }
  634. c_sh_wr += 16 * (4 * c_sh_stride);
  635. }
  636. }
  637. __syncthreads();
  638. #pragma unroll
  639. for (int i = 0;
  640. i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
  641. i++) {
  642. if (c_gl_wr < c_gl_wr_end) {
  643. C[c_gl_wr] = sh[c_sh_rd];
  644. c_gl_wr += c_gl_wr_delta;
  645. c_sh_rd += c_sh_rd_delta;
  646. }
  647. }
  648. };
  649. // Start global fetch and register load pipelines.
  650. auto start_pipes = [&]() {
  651. #pragma unroll
  652. for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
  653. zero_accums();
  654. wait_for_stage();
  655. fetch_to_registers(0, 0);
  656. a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  657. };
  658. start_pipes();
  659. // Main loop.
  660. while (slice_iters) {
  661. // We unroll over both the global fetch and the register load pipeline to
  662. // ensure all shared memory accesses are static. Note that both pipelines have
  663. // even length meaning that the next iteration will always start at index 0.
  664. #pragma unroll
  665. for (int pipe = 0; pipe < stages;) {
  666. #pragma unroll
  667. for (int k = 0; k < b_sh_wr_iters; k++) {
  668. fetch_to_registers(k + 1, pipe % stages);
  669. if (k == b_sh_wr_iters - 2) {
  670. fetch_to_shared((pipe + stages - 1) % stages, pipe,
  671. slice_iters >= stages);
  672. pipe++;
  673. wait_for_stage();
  674. }
  675. matmul(k);
  676. }
  677. slice_iters--;
  678. if (slice_iters == 0) break;
  679. }
  680. a_gl_rd += a_gl_rd_delta_o * stages;
  681. // Process results and, if necessary, proceed to the next column slice.
  682. // While this pattern may not be the most readable, other ways of writing
  683. // the loop seemed to noticeably worse performance after compilation.
  684. if (slice_iters == 0) {
  685. cp_async_wait<0>();
  686. bool last = slice_idx == slice_count - 1;
  687. // For per-column scales, we only fetch them here in the final step before
  688. // write-out
  689. if (group_blocks == -1 && last) {
  690. if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
  691. cp_async_fence();
  692. }
  693. thread_block_reduce();
  694. if (group_blocks == -1 && last) {
  695. cp_async_wait<0>();
  696. __syncthreads();
  697. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  698. reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
  699. reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
  700. }
  701. }
  702. if (slice_count > 1) { // only globally reduce if there is more than one
  703. // block in a slice
  704. barrier_acquire(&locks[slice_col], slice_idx);
  705. global_reduce(slice_idx == 0, last);
  706. barrier_release(&locks[slice_col], last);
  707. }
  708. if (last) // only the last block in a slice actually writes the result
  709. write_result();
  710. slice_row = 0;
  711. slice_col_par++;
  712. slice_col++;
  713. init_slice();
  714. if (slice_iters) {
  715. a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  716. (threadIdx.x % a_gl_rd_delta_o);
  717. #pragma unroll
  718. for (int i = 0; i < b_sh_wr_iters; i++)
  719. B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
  720. if (slice_col == 0) {
  721. #pragma unroll
  722. for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
  723. }
  724. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  725. start_pipes();
  726. }
  727. }
  728. }
  729. }
  730. #else
  // Stub definition of the Marlin kernel compiled in the #else branch of the
  // surrounding preprocessor conditional: Marlin is not implemented yet for
  // SM < 8.0.
  template <const int threads,          // number of threads in a threadblock
            const int thread_m_blocks,  // number of 16x16 blocks in the m
                                        // dimension (batchsize) of the
                                        // threadblock
            const int thread_n_blocks,  // same for n dimension (output)
            const int thread_k_blocks,  // same for k dimension (reduction)
            const int stages,  // number of stages for the async global->shared
                               // fetch pipeline
            const int group_blocks = -1  // number of consecutive 16x16 blocks
                                         // with a separate quantization scale
            >
  __global__ void Marlin(
      const int4* __restrict__ A,  // fp16 input matrix of shape mxk
      const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
      int4* __restrict__ C,        // fp16 output buffer of shape mxn
      const int4* __restrict__ s,  // fp16 quantization scales of shape
                                   // (k/groupsize)xn
      int prob_m,                  // batch dimension m
      int prob_n,                  // output dimension n
      int prob_k,                  // reduction dimension k
      int* locks  // extra global storage for barrier synchronization
  ) {
    // Marlin is not implemented yet for SM < 8.0.
    // assert() reports the failure location in debug builds. __trap() makes
    // the launch fail even when NDEBUG compiles the assert away, instead of
    // returning silently with C never written.
    assert(false);
    __trap();
  }
  757. #endif
  // Launch configuration constants.
  //
  // 8 warps are a good choice: every SM has 4 schedulers, and having more than
  // one warp per scheduler allows some extra latency hiding. At the same time
  // we want relatively few warps, so that each warp keeps many registers and
  // tiles stay small.
  static constexpr int USER_THREADS = 256;  // only used with user-provided
                                            // thread_k/n
  static constexpr int STAGES = 4;  // 4 pipeline stages fit into shared memory
  // Dynamic shared memory requested per block. It is sized to fit the
  // per-block limit of compute capability 8.6, which is smaller than that of
  // compute capability 8.0.
  static constexpr int SHARED_MEM = 96 * 1024;
  static constexpr int min_thread_n = 64;  // smallest accepted thread_n
  static constexpr int min_thread_k = 64;  // smallest accepted thread_k
  static constexpr int tile_size = 16;     // edge length of a 16x16 tile
  // Default cap on the number of 64-row batch slices handled by one launch;
  // also used to size the lock workspace.
  static constexpr int max_par = 16;
  static constexpr int pack_factor_4bit = 8;  // 4-bit values per 32-bit word
  // Expands to one `else if` arm of the kernel dispatch chain in marlin_cuda().
  // When the runtime tile configuration equals the compile-time template
  // arguments, it does two things:
  //   1. raises the kernel's dynamic shared-memory limit to SHARED_MEM;
  //   2. launches the kernel on `stream` with `blocks` thread blocks of
  //      NUM_THREADS threads.
  // Expects marlin_cuda()'s locals (thread_m/n/k_blocks, group_blocks,
  // num_threads, blocks, stream, A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n,
  // prob_k, locks) to be in scope. The attribute update and the launch are
  // both error-checked. A failure (e.g. SHARED_MEM above the device's
  // per-block limit) therefore raises instead of leaving C silently unwritten.
  #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,           \
                    GROUP_BLOCKS, NUM_THREADS)                                   \
    else if (thread_m_blocks == THREAD_M_BLOCKS &&                               \
             thread_n_blocks == THREAD_N_BLOCKS &&                               \
             thread_k_blocks == THREAD_K_BLOCKS &&                               \
             group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {       \
      cudaError_t marlin_err_ = cudaFuncSetAttribute(                            \
          Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,                  \
                 THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>,                         \
          cudaFuncAttributeMaxDynamicSharedMemorySize, SHARED_MEM);              \
      TORCH_CHECK(marlin_err_ == cudaSuccess,                                    \
                  "Marlin: cudaFuncSetAttribute failed: ",                       \
                  cudaGetErrorString(marlin_err_));                              \
      Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,     \
             STAGES, GROUP_BLOCKS><<<blocks, NUM_THREADS, SHARED_MEM, stream>>>( \
          A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks);            \
      marlin_err_ = cudaGetLastError();                                          \
      TORCH_CHECK(marlin_err_ == cudaSuccess,                                    \
                  "Marlin: kernel launch failed: ",                              \
                  cudaGetErrorString(marlin_err_));                              \
    }
  // Tile configuration for one Marlin launch. A field value of -1 means
  // "unset"; is_valid_config() rejects such configurations.
  struct thread_config_t {
    int thread_k;     // K extent of a threadblock tile
    int thread_n;     // N extent of a threadblock tile
    int num_threads;  // threads per threadblock
  };
  // Candidate tile configurations for small batches (prob_m <= 16), ordered by
  // priority: determine_thread_config() returns the first entry accepted by
  // is_valid_config(). The table is only read, so it is const with internal
  // linkage to prevent accidental mutation and cross-TU name clashes.
  static const thread_config_t small_batch_thread_configs[] = {
      // thread_k, thread_n, num_threads
      {128, 128, 256},  // Default
      {128, 64, 128},   // Reduce N 2X, same K
      {64, 256, 256},   // Reduce K 2X, increase N 2X
      {64, 128, 128},   // Reduce K 2X, same N
  };
  // Candidate tile configurations for larger batches (prob_m > 16), ordered by
  // priority: determine_thread_config() returns the first entry accepted by
  // is_valid_config(). The table is only read, so it is const with internal
  // linkage to prevent accidental mutation and cross-TU name clashes.
  static const thread_config_t large_batch_thread_configs[] = {
      // thread_k, thread_n, num_threads
      {64, 256, 256},   // Default
      {128, 128, 256},  // Reduce N 2X, increase K 2X
      {64, 128, 128},   // Reduce N 2X, same K
      {128, 64, 128},   // Reduce N 4X, increase K 2X
  };
  // Returns true when `th_config` can be used for a prob_m x prob_n x prob_k
  // problem. prob_m is currently unused.
  //
  // The range checks run before the divisibility checks. A user-supplied
  // thread_k/thread_n of 0 (or a negative value) is therefore rejected instead
  // of being used as a `%` divisor, which is undefined behaviour.
  bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
                       int prob_k) {
    // Sanity: -1 marks an unset field
    if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
        th_config.num_threads == -1) {
      return false;
    }
    // thread_k can be only 128 or 64 (because it must not exceed groupsize,
    // which is 128)
    if (th_config.thread_k != 128 && th_config.thread_k != 64) {
      return false;
    }
    // Verify min for thread K/N (guarantees positive divisors below)
    if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
      return false;
    }
    // Verify K/N are divisible by thread K/N
    if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
      return false;
    }
    // num_threads must be at least 128 (= 4 warps)
    if (th_config.num_threads < 128) {
      return false;
    }
    return true;
  }
  // Automatically picks a tile configuration. Problems with at most 16 rows
  // search small_batch_thread_configs; larger ones search
  // large_batch_thread_configs. The first entry, in priority order, that
  // is_valid_config() accepts is returned, or {-1, -1, -1} when none fits.
  thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
    const thread_config_t* candidates = large_batch_thread_configs;
    int num_candidates = static_cast<int>(sizeof(large_batch_thread_configs) /
                                          sizeof(large_batch_thread_configs[0]));
    if (prob_m <= 16) {
      candidates = small_batch_thread_configs;
      num_candidates = static_cast<int>(sizeof(small_batch_thread_configs) /
                                        sizeof(small_batch_thread_configs[0]));
    }
    for (int idx = 0; idx < num_candidates; ++idx) {
      if (is_valid_config(candidates[idx], prob_m, prob_n, prob_k)) {
        return candidates[idx];
      }
    }
    return thread_config_t{-1, -1, -1};
  }
  // Expands to the __CALL_IF dispatch arms for one (thread_n_blocks,
  // thread_k_blocks, num_threads) tile shape. It covers thread_m_blocks 1..4,
  // each with per-column scales (group_blocks == -1) and grouped scales with
  // group_blocks == 8 (groupsize 128). Each combination appears exactly once;
  // the earlier duplicated thread_m_blocks == 1 arms were unreachable
  // `else if` branches.
  #define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS)    \
    __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
    __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
    __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
    __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
    __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
    __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
    __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
    __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
  // Host-side launcher for the Marlin fp16 x 4-bit GEMM kernel.
  //
  // Parameters:
  //   A          fp16 input matrix, prob_m x prob_k
  //   B          4-bit quantized weight matrix, prob_k x prob_n
  //   C          fp16 output buffer, prob_m x prob_n
  //   s          fp16 scales, (prob_k / groupsize) x prob_n
  //   workspace  int buffer passed to the kernel as `locks` (barrier storage)
  //   groupsize  -1 for per-column scales, otherwise the number of K rows
  //              sharing one scale; must be a positive multiple of 16
  //   dev        device queried for its SM count when sms == -1
  //   stream     stream every launch is enqueued on
  //   thread_k/thread_n
  //              tile sizes; unless both are given (!= -1) they are chosen
  //              by determine_thread_config()
  //   sms        grid size of every launch; -1 uses the device's SM count
  //   max_par    maximum number of 64-row batch slices per launch
  //
  // M is processed in chunks: one launch per chunk of at most 4 x 16 rows,
  // or par x 64 rows when more than 4 row tiles remain.
  //
  // Errors: throws std::runtime_error for an invalid thread config or a tile
  // shape with no compiled kernel. Other invalid arguments fail a
  // TORCH_CHECK.
  void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
                   int prob_n, int prob_k, void* workspace, int groupsize = -1,
                   int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
                   int thread_n = -1, int sms = -1, int max_par = 16) {
    int tot_m = prob_m;
    int tot_m_blocks = ceildiv(tot_m, 16);
    // Number of padding rows in the last (partial) 16-row tile.
    int pad = 16 * tot_m_blocks - tot_m;
    if (sms == -1) {
      // Otherwise a failed query would go unnoticed and every launch below
      // would use a bogus grid size.
      cudaError_t err =
          cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
      TORCH_CHECK(err == cudaSuccess,
                  "cudaDeviceGetAttribute(cudaDevAttrMultiProcessorCount) "
                  "failed: ",
                  cudaGetErrorString(err));
    }
    // Set thread config
    thread_config_t th_config;
    if (thread_k != -1 && thread_n != -1) {
      // User-defined config
      th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
    } else {
      // Auto config
      th_config = determine_thread_config(prob_m, prob_n, prob_k);
    }
    if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
      throw std::runtime_error(
          "Invalid thread config: thread_k = " + str(th_config.thread_k) +
          ", thread_n = " + str(th_config.thread_n) +
          ", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
          str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
    }
    // Uncomment for debug
    // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) +
    //                  ", thread_n = " + str(th_config.thread_n) +
    //                  ", num_threads = " + str(th_config.num_threads) + " for
    //                  MKN = [" + str(prob_m) +
    //                  ", " + str(prob_k) + ", " + str(prob_n) + "]\n";
    int num_threads = th_config.num_threads;
    thread_k = th_config.thread_k;
    thread_n = th_config.thread_n;
    int thread_k_blocks = thread_k / 16;
    int thread_n_blocks = thread_n / 16;
    int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
    int blocks = sms;
    if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
      return;
    }
    TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
                " is not divisible by thread_n = ", thread_n);
    TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
                " is not divisible by thread_k = ", thread_k);
    if (groupsize != -1) {
      // group_blocks = groupsize / 16 truncates. Any other groupsize would be
      // silently rounded, and e.g. -16 / 16 == -1 would even select
      // per-column mode.
      TORCH_CHECK(groupsize > 0 && groupsize % 16 == 0, "groupsize = ",
                  groupsize, " must be -1 or a positive multiple of 16");
      // Each quantization group spans groupsize rows of K, so K must split
      // into whole groups. Checking against group_blocks (groupsize / 16)
      // would be too weak.
      TORCH_CHECK(prob_k % groupsize == 0, "prob_k = ", prob_k,
                  " is not divisible by groupsize = ", groupsize);
    }
    const int4* A_ptr = (const int4*)A;
    const int4* B_ptr = (const int4*)B;
    int4* C_ptr = (int4*)C;
    const int4* s_ptr = (const int4*)s;
    int* locks = (int*)workspace;
    for (int i = 0; i < tot_m_blocks; i += 4) {
      int thread_m_blocks = tot_m_blocks - i;
      prob_m = tot_m - 16 * i;
      int par = 1;
      if (thread_m_blocks > 4) {
        // A non-positive max_par would clamp par to <= 0. `i` would then never
        // advance, and the loop would launch forever.
        TORCH_CHECK(max_par >= 1, "max_par = ", max_par, " must be >= 1");
        // Note that parallel > 1 currently only works for inputs without any
        // padding
        par = (16 * thread_m_blocks - pad) / 64;
        if (par > max_par) par = max_par;
        prob_m = 64 * par;
        i += 4 * (par - 1);
        thread_m_blocks = 4;
      }
      // For compilation speed, we only define the kernel configurations that
      // have seemed useful (in terms of performance) in our testing, however
      // many more are, in principle, possible.
      if (false) {
      }
      CALL_IF(8, 8, 256)
      CALL_IF(16, 4, 256)
      CALL_IF(8, 4, 128)
      CALL_IF(4, 8, 128)
      else {
        throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                                 ", " + str(prob_k) + ", " + str(prob_n) + "]" +
                                 ", groupsize = " + str(groupsize) +
                                 ", thread_m_blocks = " + str(thread_m_blocks) +
                                 ", thread_n_blocks = " + str(thread_n_blocks) +
                                 ", thread_k_blocks = " + str(thread_k_blocks));
      }
      // Advance A and C past the rows handled by this launch.
      A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
      C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
    }
  }
  948. } // namespace marlin_dense
  // PyTorch entry point for the Marlin GEMM. Allocates and returns
  // c (size_m x size_n, same dtype and device as a), filled by
  // marlin_dense::marlin_cuda() on the current CUDA stream of a's device.
  //
  // Parameters:
  //   a           fp16 contiguous CUDA tensor of shape (size_m, size_k)
  //   b_q_weight  int32 contiguous CUDA tensor of shape
  //               (size_k / 16, 2 * size_n), holding 8 4-bit values per
  //               32-bit word
  //   b_scales    fp16 contiguous CUDA tensor of shape (1, size_n) for
  //               per-column scales, or (size_k / 128, size_n) for
  //               groupsize 128
  //   workspace   CUDA tensor with at least (size_n / 64) * max_par elements,
  //               used by the kernel as barrier locks
  //   size_m/size_n/size_k
  //               problem dimensions; must agree with the tensor shapes
  //
  // Errors: every precondition violation fails a TORCH_CHECK.
  torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                            torch::Tensor& b_scales, torch::Tensor& workspace,
                            int64_t size_m, int64_t size_n, int64_t size_k) {
    // Verify M
    TORCH_CHECK(size_m == a.size(0),
                "Shape mismatch: a.size(0) = " + str(a.size(0)) +
                    ", size_m = " + str(size_m));
    // Verify K
    TORCH_CHECK(size_k == a.size(1),
                "Shape mismatch: a.size(1) = " + str(a.size(1)) +
                    ", size_k = " + str(size_k));
    TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
                "size_k = " + str(size_k) + " is not divisible by tile_size = " +
                    str(marlin_dense::tile_size));
    TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
                "Shape mismatch: b_q_weight.size(0) = " +
                    str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
                    ", tile_size = " + str(marlin_dense::tile_size));
    // Verify N
    TORCH_CHECK(b_scales.size(1) == size_n,
                "b_scales.size(1) = " + str(b_scales.size(1)) +
                    ", size_n = " + str(size_n));
    TORCH_CHECK(
        b_q_weight.size(1) % marlin_dense::tile_size == 0,
        "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
            " is not divisible by tile_size = " + str(marlin_dense::tile_size));
    int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
                        marlin_dense::pack_factor_4bit;
    TORCH_CHECK(
        size_n == actual_size_n,
        "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
    // Verify A device, strides and dtype (the kernel reads A as fp16)
    TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
    TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
    TORCH_CHECK(a.scalar_type() == at::kHalf, "A must be fp16, got ",
                a.scalar_type());
    // Verify B device, strides and dtype (8 4-bit values per 32-bit word)
    TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
    TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
    TORCH_CHECK(b_q_weight.scalar_type() == at::kInt,
                "b_q_weight must be int32, got ", b_q_weight.scalar_type());
    // Verify scales device, strides and dtype (the kernel reads them as fp16)
    TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
    TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
    TORCH_CHECK(b_scales.scalar_type() == at::kHalf,
                "b_scales must be fp16, got ", b_scales.scalar_type());
    // Verify workspace device (its pointer is dereferenced by the kernel)
    TORCH_CHECK(workspace.device().is_cuda(), "workspace is not on GPU");
    // Alloc C matrix
    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
    auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
    torch::Tensor c = torch::empty({size_m, size_n}, options);
    // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
    // auto -1)
    int thread_k = -1;
    // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
    // auto -1)
    int thread_n = -1;
    // sms: number of SMs to use for the kernel (can usually be left as auto -1)
    int sms = -1;
    // Detect groupsize. A zero-row scales tensor would otherwise be used as a
    // `%` / `/` divisor below.
    TORCH_CHECK(b_scales.size(0) > 0,
                "b_scales.size(0) = " + str(b_scales.size(0)) +
                    " must be positive");
    if (b_scales.size(0) != 1) {
      TORCH_CHECK(size_k % b_scales.size(0) == 0,
                  "size_k = " + str(size_k) +
                      ", is not divisible by b_scales.size(0) = " +
                      str(b_scales.size(0)));
    }
    int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0);
    // Verify groupsize
    TORCH_CHECK(groupsize == -1 || groupsize == 128,
                "Unexpected groupsize = " + str(groupsize));
    // Verify workspace size
    TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
                "size_n = " + str(size_n) +
                    ", is not divisible by min_thread_n = " +
                    str(marlin_dense::min_thread_n));
    int min_workspace_size =
        (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
    TORCH_CHECK(workspace.numel() >= min_workspace_size,
                "workspace.numel = " + str(workspace.numel()) +
                    " is below min_workspace_size = " + str(min_workspace_size));
    int dev = a.get_device();
    marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
                              b_scales.data_ptr(), size_m, size_n, size_k,
                              workspace.data_ptr(), groupsize, dev,
                              at::cuda::getCurrentCUDAStream(dev), thread_k,
                              thread_n, sms, marlin_dense::max_par);
    return c;
  }