marlin_cuda_kernel.cu 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068
  1. /*
  2. * Modified by Neural Magic
  3. * Copyright (C) Marlin.2024 Elias Frantar
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #include <torch/all.h>
  18. #include <ATen/cuda/CUDAContext.h>
  19. #include <c10/cuda/CUDAGuard.h>
  20. #include <cuda.h>
  21. #include <cuda_fp16.h>
  22. #include <cuda_runtime.h>
  23. #include <iostream>
  24. #include "common/base.cuh"
  25. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  26. #include "common/mem.cuh"
  27. #endif
  // Convert an arithmetic value to its decimal string form (thin wrapper over
  // std::to_string, used for building diagnostic messages).
  template <typename T>
  inline std::string str(T x) {
    const std::string converted = std::to_string(x);
    return converted;
  }
  32. namespace marlin_dense {
  33. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Four packed 32-bit integers; used to hold one thread's quantized B data.
  typedef Vec<int, 4> I4;
  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  typedef Vec<half2, 4> FragA;  // A operand of mma (4 x half2 registers)
  typedef Vec<half2, 2> FragB;  // B operand of mma (2 x half2 registers)
  typedef Vec<float, 4> FragC;  // fp32 accumulator of mma (4 floats)
  typedef Vec<half2, 1> FragS;  // quantization scales
  // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
  // output/accumulation: frag_c += a_frag * frag_b.
  //
  // The accumulator registers are passed both as outputs ("=f") and as inputs
  // ("f"), so the instruction reads the previous partial sums and overwrites
  // them with the updated ones.
  __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                             FragC& frag_c) {
    float* acc = reinterpret_cast<float*>(&frag_c);
    const uint32_t* b_regs = reinterpret_cast<const uint32_t*>(&frag_b);
    const uint32_t* a_regs = reinterpret_cast<const uint32_t*>(&a_frag);
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
          "r"(b_regs[0]), "r"(b_regs[1]),
          "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]));
  }
  // Load a full 16x16 fp16 tile of operand A from shared memory straight into
  // tensor-core fragment layout via a single ldmatrix.x4.
  //
  // smem_ptr: generic pointer into shared memory; it is converted to a
  // shared-window address because ldmatrix takes a .shared operand.
  __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    uint32_t* dst = reinterpret_cast<uint32_t*>(&frag_a);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(dst[0]), "=r"(dst[1]), "=r"(dst[2]), "=r"(dst[3])
                 : "r"(smem_addr));
  }
  // Lookup-table based 3-input logical operation (PTX lop3.b32). The template
  // argument `lut` is the 8-bit truth table, built from the canonical operand
  // patterns a = 0xf0, b = 0xcc, c = 0xaa. It is used explicitly during
  // dequantization because the compiler does not seem to recognize such
  // expressions as a single LOP3 in all cases.
  template <int lut>
  __device__ inline int lop3(int a, int b, int c) {
    int out;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(out)
                 : "r"(a), "r"(b), "r"(c), "n"(lut));
    return out;
  }
  // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
  // values. Only the nibbles at bits 0-3, 4-7, 16-19 and 20-23 of `q` are
  // used; callers pass `q >> 8` to reach the remaining nibbles. We mostly
  // follow the strategy in the link below, with some small changes:
  // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
  __device__ inline FragB dequant(int q) {
    // Truth table for `(a & b) | c`, so the masking below is a single LOP3.
    constexpr int kAndOrLut = (0xf0 & 0xcc) | 0xaa;
    // Masks selecting the low / high nibble of each 16-bit half of `q`.
    const int lo_mask = 0x000f000f;
    const int hi_mask = 0x00f000f0;
    // 0x6400 is fp16 1024. OR-ing a nibble into its mantissa yields the fp16
    // values 1024 + n (low nibble) and 1024 + 16 * n (high nibble).
    const int exponent = 0x64006400;
    int lo_bits = lop3<kAndOrLut>(q, lo_mask, exponent);
    int hi_bits = lop3<kAndOrLut>(q, hi_mask, exponent);
    // We want signed int4 outputs, hence the `-8` symmetric zero point is
    // fused directly into `sub` and `add`:
    //   low:  (1024 + n) - 1032                 = n - 8   (0x6408 = 1032)
    //   high: (1024 + 16 * n) * (1 / 16) - 72   = n - 8   (0x2c00 = 1/16,
    //                                                       0xd480 = -72)
    const int sub = 0x64086408;
    const int mul = 0x2c002c00;
    const int add = 0xd480d480;
    const half2 lo_vals = *reinterpret_cast<half2*>(&lo_bits);
    const half2 hi_vals = *reinterpret_cast<half2*>(&hi_bits);
    FragB result;
    result[0] = __hsub2(lo_vals, *reinterpret_cast<const half2*>(&sub));
    result[1] = __hfma2(hi_vals, *reinterpret_cast<const half2*>(&mul),
                        *reinterpret_cast<const half2*>(&add));
    return result;
  }
  // Multiply dequantized values by the corresponding quantization scale; used
  // only for grouped quantization.
  //
  // frag_s holds packed fp16 scales; `i` selects which individual __half of
  // frag_s is broadcast to both lanes and applied to both half2 entries of
  // frag_b (modified in place).
  __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
    const __half* scales = reinterpret_cast<__half*>(&frag_s);
    const half2 factor = __half2half2(scales[i]);
  #pragma unroll
    for (int part = 0; part < 2; part++) {
      frag_b[part] = __hmul2(frag_b[part], factor);
    }
  }
  // Marlin fp16 x 4-bit GEMM kernel: C (m x n, fp16) = A (m x k, fp16) times
  // the 4-bit weights B (k x n), dequantized on the fly by `dequant` (with a
  // fused -8 symmetric zero point) and scaled by `s`:
  //  - group_blocks != -1: per-group scales, applied to the dequantized
  //    weights inside `matmul`;
  //  - group_blocks == -1: per-column scales, applied once to the final
  //    results inside `write_result`.
  //
  // Launch / resource requirements visible in this code:
  //  - Compiled only for __CUDA_ARCH__ >= 800 (the enclosing #if). It uses the
  //    mma.sync.m16n8k16 / ldmatrix helpers and the cp.async helpers.
  //  - 1-D grid: blockIdx.x / gridDim.x determine each block's stripe.
  //  - 1-D block: all thread indexing uses threadIdx.x against the template
  //    parameter `threads`. NOTE(review): assumes blockDim.x == threads; this
  //    is not checked here, so confirm at the launch site.
  //  - Dynamic shared memory `sh` holds `stages` A tiles, `stages` B tiles and
  //    the scale stages. The front of `sh` is reused for the intra-block
  //    reduction and for the output reshuffle. The byte size is chosen by the
  //    host and is not visible here.
  //  - `locks` provides one counter per column slice of every parallel
  //    sub-problem (n_tiles * parallel entries). It is used via
  //    barrier_acquire / barrier_release, which are defined elsewhere.
  //    Presumably the host zero-initializes it (TODO confirm).
  //  - NOTE(review): the integer divisions below assume prob_k is a multiple
  //    of 16 * thread_k_blocks and prob_n a multiple of 16 * thread_n_blocks.
  //    When prob_m > 16 * thread_m_blocks, only
  //    prob_m / (16 * thread_m_blocks) full sub-problems are processed by this
  //    launch; presumably the host splits any remainder.
  //  - All A/B/C/s strides and offsets below are in units of int4 (16 bytes,
  //    i.e. 8 fp16 values or 32 4-bit values).
  template <const int threads,          // number of threads in a threadblock
            const int thread_m_blocks,  // number of 16x16 blocks in the m
                                        // dimension (batchsize) of the
                                        // threadblock
            const int thread_n_blocks,  // same for n dimension (output)
            const int thread_k_blocks,  // same for k dimension (reduction)
            const int stages,  // number of stages for the async global->shared
                               // fetch pipeline
            const int group_blocks = -1  // number of consecutive 16x16 blocks
                                         // with a separate quantization scale
            >
  __global__ void Marlin(
      const int4* __restrict__ A,  // fp16 input matrix of shape mxk
      const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
      int4* __restrict__ C,        // fp16 output buffer of shape mxn
      const int4* __restrict__ s,  // fp16 quantization scales of shape
                                   // (k/groupsize)xn
      int prob_m,                  // batch dimension m
      int prob_n,                  // output dimension n
      int prob_k,                  // reduction dimension k
      int* locks  // extra global storage for barrier synchronization
  ) {
    // Each threadblock processes one "stripe" of the B matrix with (roughly)
    // the same size, which might involve multiple column "slices" (of width
    // 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix
    // 5 SM example:
    //   0 1 3
    //   0 2 3
    //   1 2 4
    // While this kind of partitioning makes things somewhat more complicated,
    // it ensures good utilization of all SMs for many kinds of shape and GPU
    // configurations, while requiring as few slow global cross-threadblock
    // reductions as possible.

    // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
    // better partitioning with fewer reductions. Each parallel sub-problem is
    // 16 * thread_m_blocks rows of A / C and is treated as extra columns of
    // the stripe partitioning (see `slice_col_par`).
    int parallel = 1;
    if (prob_m > 16 * thread_m_blocks) {
      parallel = prob_m / (16 * thread_m_blocks);
      prob_m = 16 * thread_m_blocks;
    }
    // Tile counts along k (rows of a slice) and n (slices); `iters` is the
    // number of tiles in each block's stripe.
    int k_tiles = prob_k / 16 / thread_k_blocks;
    int n_tiles = prob_n / 16 / thread_n_blocks;
    int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
    // Ensure that the number of tiles in each stripe is a multiple of the
    // groupsize; this avoids an annoying special case where a stripe starts in
    // the middle of a group.
    if (group_blocks != -1)
      iters = (group_blocks / thread_k_blocks) *
              ceildiv(iters, (group_blocks / thread_k_blocks));
    // Linearized start position of this block's stripe, in tiles
    // (column-major over k tiles).
    int slice_row = (iters * blockIdx.x) % k_tiles;
    int slice_col_par = (iters * blockIdx.x) / k_tiles;
    int slice_col = slice_col_par;
    int slice_iters;  // number of threadblock tiles in the current slice
    int slice_count = 0;  // total number of active threadblocks in the current
                          // slice
    int slice_idx;  // index of threadblock in current slice; numbered bottom to
                    // top
    // We can easily implement parallel problem execution by just remapping
    // indices and advancing global pointers.
    if (slice_col_par >= n_tiles) {
      A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
      C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
      locks += (slice_col_par / n_tiles) * n_tiles;
      slice_col = slice_col_par % n_tiles;
    }
    // Compute all information about the current slice which is required for
    // synchronization.
    auto init_slice = [&]() {
      // Tiles from the current position to the end of this block's stripe,
      // clipped below to the end of the current column.
      slice_iters =
          iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
      if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
      if (slice_iters == 0) return;
      if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
      slice_count = 1;
      slice_idx = 0;
      // First stripe boundary (a multiple of `iters`) at or after the start of
      // this column, i.e. where the first block that *starts* inside this
      // column begins.
      int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
      if (col_first <= k_tiles * (slice_col_par + 1)) {
        // col_off > 0 means a block whose stripe began in an earlier column
        // also covers the top of this column.
        int col_off = col_first - k_tiles * slice_col_par;
        slice_count = ceildiv(k_tiles - col_off, iters);
        if (col_off > 0) slice_count++;
        int delta_first = iters * blockIdx.x - col_first;
        // The block covering the top rows of the column gets the highest
        // index (slice_count - 1); later-starting blocks get lower indices.
        if (delta_first < 0 || (col_off == 0 && delta_first == 0))
          slice_idx = slice_count - 1;
        else {
          slice_idx = slice_count - 1 - delta_first / iters;
          if (col_off > 0) slice_idx--;
        }
      }
      // Stepping past the last column moves on to the next parallel
      // sub-problem: advance A, C and locks and wrap back to column 0.
      if (slice_col == n_tiles) {
        A += 16 * thread_m_blocks * prob_k / 8;
        C += 16 * thread_m_blocks * prob_n / 8;
        locks += n_tiles;
        slice_col = 0;
      }
    };
    init_slice();
    int a_gl_stride = prob_k / 8;  // stride of the A matrix in global memory
    // We typically use `constexpr` to indicate that this value is a
    // compile-time constant.
    // stride of an A matrix tile in shared memory
    constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
    // delta between subsequent A tiles in global memory
    constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
    // between subsequent accesses within a tile
    int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
    // between shared memory writes
    constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
    // between shared memory tile reads
    constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
    // within a shared memory tile
    constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
    // overall size of a tile
    constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
    // number of shared write iterations for a tile
    constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
    // One B "row" of b_gl_stride int4 covers a 16-row block of k across all
    // prob_n columns (16 * prob_n 4-bit values / 32 values per int4).
    int b_gl_stride = 16 * prob_n / 32;
    constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
    int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
    int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
    constexpr int b_sh_wr_delta = threads;
    constexpr int b_sh_rd_delta = threads;
    constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
    // Number of B copies per stage; it is also the number of register
    // sub-tiles (the `k` loop) consumed per pipeline stage.
    constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
    // One row of fp16 scales, and the scales needed for one tile width.
    int s_gl_stride = prob_n / 8;
    constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
    constexpr int s_sh_stage = s_sh_stride;
    int s_gl_rd_delta = s_gl_stride;
    // Global A read index of current thread.
    int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                  (threadIdx.x % a_gl_rd_delta_o);
    a_gl_rd += a_gl_rd_delta_o * slice_row;
    // Shared write index of current thread.
    int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
                  (threadIdx.x % a_gl_rd_delta_o);
    // Shared read index.
    int a_sh_rd =
        a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
    a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
    int b_gl_rd =
        b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
    b_gl_rd += b_sh_stride * slice_col;
    b_gl_rd += b_gl_rd_delta_o * slice_row;
    int b_sh_wr = threadIdx.x;
    int b_sh_rd = threadIdx.x;
    // NOTE(review): with group_blocks == -1 the first term divides by -1.
    // Per-column scales are read through s_gl_rd only in the final write-out
    // step of the last block of a slice, which appears to always start at
    // slice_row == 0. Confirm if the partitioning changes.
    int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
                  s_sh_stride * slice_col + threadIdx.x;
    int s_sh_wr = threadIdx.x;
    int s_sh_rd;
    // We use a different scale layout for grouped and column-wise quantization
    // as we scale a `half2` tile in column-major layout in the former and in
    // row-major in the latter case.
    if (group_blocks != -1)
      s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
                (threadIdx.x % 32) / 4;
    else
      s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
                (threadIdx.x % 32) % 4;
    // Precompute which thread should not read memory in which iterations; this
    // is needed if there are more threads than required for a certain tilesize
    // or when the batchsize is not a multiple of 16.
    bool a_sh_wr_pred[a_sh_wr_iters];
  #pragma unroll
    for (int i = 0; i < a_sh_wr_iters; i++)
      a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
    bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
    // To ensure that writing and reading A tiles to/from shared memory, the
    // latter in fragment format, is fully bank conflict free, we need to use a
    // rather fancy XOR-based layout. The key here is that neither reads nor
    // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
    // same shared memory banks. Further, it seems (based on NSight-Compute)
    // that each warp must also write a consecutive memory segment?
    auto transform_a = [&](int i) {
      int row = i / a_gl_rd_delta_o;
      // `+` binds tighter than `^`, so this evaluates as
      // (a_gl_rd_delta_o * row + i % a_gl_rd_delta_o) ^ row.
      return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
    };
    // Since the computation of this remapping is non-trivial and, due to our
    // main loop unrolls, all shared memory accesses are static, we simply
    // precompute both transformed reads and writes.
    int a_sh_wr_trans[a_sh_wr_iters];
  #pragma unroll
    for (int i = 0; i < a_sh_wr_iters; i++)
      a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
    int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  #pragma unroll
    for (int i = 0; i < b_sh_wr_iters; i++) {
  #pragma unroll
      for (int j = 0; j < thread_m_blocks; j++)
        a_sh_rd_trans[i][j] =
            transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
    }
    // Since B-accesses have non-constant stride they have to be computed at
    // runtime; we break dependencies between subsequent accesses within a tile
    // by maintaining multiple pointers (we have enough registers), a tiny
    // optimization.
    const int4* B_ptr[b_sh_wr_iters];
  #pragma unroll
    for (int i = 0; i < b_sh_wr_iters; i++)
      B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
    extern __shared__ int4 sh[];
    // Shared memory storage for global fetch pipelines:
    // [stages A tiles][stages B tiles][scale stages].
    int4* sh_a = sh;
    int4* sh_b = sh_a + (stages * a_sh_stage);
    int4* sh_s = sh_b + (stages * b_sh_stage);
    // Register storage for double buffer of shared memory reads (index k % 2).
    FragA frag_a[2][thread_m_blocks];
    I4 frag_b_quant[2];
    // Accumulators: per m block, 4 n sub-blocks, 2 halves (from the low and
    // the `>> 8` shifted quantized values in `matmul`).
    FragC frag_c[thread_m_blocks][4][2];
    FragS frag_s[2][4];
    // Zero accumulators.
    auto zero_accums = [&]() {
  #pragma unroll
      for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
        reinterpret_cast<float*>(frag_c)[i] = 0;
    };
    // Asynchronously fetch the next A, B and s tile from global to the next
    // shared memory pipeline location.
    //   pipe:  shared memory stage to fill.
    //   a_off: A tile offset (in tiles) relative to the current a_gl_rd.
    //   pred:  when false, nothing is loaded but a fence is still committed.
    // B_ptr (and s_gl_rd for grouped scales) advance implicitly on each loading
    // call.
    auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
      if (pred) {
        int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  #pragma unroll
        for (int i = 0; i < a_sh_wr_iters; i++) {
          cp_async4_pred(
              &sh_a_stage[a_sh_wr_trans[i]],
              &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
              a_sh_wr_pred[i]);
        }
        int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  #pragma unroll
        for (int i = 0; i < b_sh_wr_iters; i++) {
          cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
          B_ptr[i] += b_gl_rd_delta_o;
        }
        // Only fetch scales if this tile starts a new group
        if constexpr (group_blocks != -1) {
          // This assumes group_blocks >= thread_k_blocks
          // and would need to be modified to support smaller groups.
          static_assert(group_blocks >= thread_k_blocks);
          if (pipe % (group_blocks / thread_k_blocks) == 0) {
            int4* sh_s_stage = sh_s + s_sh_stage * pipe;
            if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
            s_gl_rd += s_gl_rd_delta;
          }
        }
      }
      // Insert a fence even when we are winding down the pipeline to ensure
      // that waiting is also correct at this point.
      cp_async_fence();
    };
    // Wait until the next thread tile has been loaded to shared memory.
    auto wait_for_stage = [&]() {
      // We only have `stages - 2` active fetches since we are double buffering
      // and can only issue the next fetch when it is guaranteed that the
      // previous shared memory load is fully complete (as it may otherwise be
      // overwritten).
      cp_async_wait<stages - 2>();
      __syncthreads();
    };
    // Load the next sub-tile from the current location in the shared memory
    // pipe into the current register buffer (k % 2).
    auto fetch_to_registers = [&](int k, int pipe) {
      // It may seem inefficient that we reload the groups for every sub-tile;
      // however, this does not seem to be a significant bottleneck, while some
      // theoretically better attempts have led to bad instruction ordering by
      // the compiler and correspondingly a noticeable drop in performance.
      if constexpr (group_blocks != -1) {
        // This assumes group_blocks >= thread_k_blocks
        // and would need to be modified to support smaller groups.
        static_assert(group_blocks >= thread_k_blocks);
        // Round `pipe` down to the stage into which `fetch_to_shared` wrote
        // this group's scales.
        int4* sh_s_stage =
            sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
                                 (pipe / (group_blocks / thread_k_blocks)));
        reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
      }
      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++)
        ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
      frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
          &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
    };
    // Execute the actual tensor core matmul of a sub-tile.
    auto matmul = [&](int k) {
      // We have the m dimension as the inner loop in order to encourage
      // overlapping dequantization and matmul operations.
  #pragma unroll
      for (int j = 0; j < 4; j++) {
        int b_quant = frag_b_quant[k % 2][j];
        // The arithmetic shift sign-extends into bits 24-31, but `dequant`
        // only keeps bits 0-7 and 16-23 of its input, so those bits are
        // masked away.
        int b_quant_shift = b_quant >> 8;
        FragB frag_b0 = dequant(b_quant);
        // If there are no groups, we can just scale the final output once and
        // can avoid doing so for each weight.
        if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
        FragB frag_b1 = dequant(b_quant_shift);
        if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
  #pragma unroll
        for (int i = 0; i < thread_m_blocks; i++) {
          mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
          mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
        }
      }
    };
    // Since we slice across the k dimension of a tile in order to increase the
    // number of warps while keeping the n dimension of a tile reasonable, we
    // have multiple warps that accumulate their partial sums of the same output
    // location; which we have to reduce over in the end. We do this in shared
    // memory. The scratch space reuses the front of `sh`; the main loop has
    // already waited for all outstanding pipeline copies (cp_async_wait<0>)
    // before calling this.
    auto thread_block_reduce = [&]() {
      // red_off is a compile-time constant, so every thread takes the same
      // branch and the __syncthreads() calls below are block-uniform.
      constexpr int red_off = threads / b_sh_stride / 2;
      if (red_off >= 1) {
        // Which group of b_sh_stride threads (k sub-slice) this thread is in.
        int red_idx = threadIdx.x / b_sh_stride;
        constexpr int red_sh_stride = b_sh_stride * 4 * 2;
        constexpr int red_sh_delta = b_sh_stride;
        int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
                        (threadIdx.x % b_sh_stride);
        // Parallel logarithmic shared memory reduction. We make sure to avoid
        // any unnecessary read or write iterations, e.g., for two warps we
        // write only once by warp 1 and read only once by warp 0.
  #pragma unroll
        for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  #pragma unroll
          for (int i = red_off; i > 0; i /= 2) {
            if (i <= red_idx && red_idx < 2 * i) {
  #pragma unroll
              for (int j = 0; j < 4 * 2; j++) {
                int red_sh_wr =
                    red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
                // Except in the first round, fold in the partial sums that
                // earlier rounds left in both slots before writing.
                if (i < red_off) {
                  float* c_rd =
                      reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
                  float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  #pragma unroll
                  for (int k = 0; k < 4; k++)
                    reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                        c_rd[k] + c_wr[k];
                }
                sh[red_sh_wr] =
                    reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
              }
            }
            __syncthreads();
          }
          // Group 0 adds the fully reduced partial sums into its registers.
          if (red_idx == 0) {
  #pragma unroll
            for (int i = 0; i < 4 * 2; i++) {
              float* c_rd =
                  reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  #pragma unroll
              for (int j = 0; j < 4; j++)
                reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
                    c_rd[j];
            }
          }
          __syncthreads();
        }
      }
    };
    // Since multiple threadblocks may process parts of the same column slice,
    // we finally have to globally reduce over the results. As the striped
    // partitioning minimizes the number of such reductions and our outputs are
    // usually rather small, we perform this reduction serially in L2 cache.
    //   first: no earlier partial sums exist in C, so skip reading them.
    //   last:  do not write partial sums back; `write_result` writes the final
    //          output instead.
    auto global_reduce = [&](bool first = false, bool last = false) {
      // We are very careful here to reduce directly in the output buffer to
      // maximize L2 cache utilization in this step. To do this, we write out
      // results in FP16 (but still reduce with FP32 compute).
      constexpr int active_threads = 32 * thread_n_blocks / 4;
      if (threadIdx.x < active_threads) {
        int c_gl_stride = prob_n / 8;
        int c_gl_wr_delta_o = 8 * c_gl_stride;
        int c_gl_wr_delta_i = 4 * (active_threads / 32);
        int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
                      4 * (threadIdx.x / 32) + threadIdx.x % 4;
        c_gl_wr += (2 * thread_n_blocks) * slice_col;
        constexpr int c_sh_wr_delta = active_threads;
        int c_sh_wr = threadIdx.x;
        int row = (threadIdx.x % 32) / 4;
        // In the predicates below, only the last m block can contain rows
        // >= prob_m, so earlier m blocks are always in range.
        if (!first) {
          // Interestingly, doing direct global accesses here really seems to
          // mess up the compiler and lead to slowdowns, hence we also use
          // async-copies even though these fetches are not actually
          // asynchronous.
  #pragma unroll
          for (int i = 0; i < thread_m_blocks * 4; i++) {
            cp_async4_pred(
                &sh[c_sh_wr + c_sh_wr_delta * i],
                &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
                   c_gl_wr_delta_i * (i % 2)],
                i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
          }
          cp_async_fence();
          cp_async_wait<0>();
        }
  #pragma unroll
        for (int i = 0; i < thread_m_blocks * 4; i++) {
          if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
            if (!first) {
              int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  #pragma unroll
              for (int j = 0; j < 2 * 4; j++) {
                reinterpret_cast<float*>(
                    &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
                    __half2float(reinterpret_cast<__half*>(&c_red)[j]);
              }
            }
            if (!last) {
              int4 c;
  #pragma unroll
              for (int j = 0; j < 2 * 4; j++) {
                reinterpret_cast<__half*>(&c)[j] =
                    __float2half(reinterpret_cast<float*>(
                        &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
              }
              C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
                  c;
            }
          }
        }
      }
    };
    // Write out the reduced final result in the correct layout. We only
    // actually reshuffle matrix fragments in this step, the reduction above is
    // performed in fragment layout.
    auto write_result = [&]() {
      int c_gl_stride = prob_n / 8;
      // One int4 of padding per shared row (presumably to avoid bank
      // conflicts during the reshuffle).
      constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
      int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
      constexpr int c_sh_rd_delta =
          c_sh_stride * (threads / (2 * thread_n_blocks));
      int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                    (threadIdx.x % (2 * thread_n_blocks));
      c_gl_wr += (2 * thread_n_blocks) * slice_col;
      // Shared write indices are in half2 units (4 per int4, hence
      // 4 * c_sh_stride per row); shared read indices are in int4 units.
      int c_sh_wr =
          (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
      c_sh_wr += 32 * (threadIdx.x / 32);
      int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                    (threadIdx.x % (2 * thread_n_blocks));
      // Rows at or beyond prob_m are not written.
      int c_gl_wr_end = c_gl_stride * prob_m;
      // We first reorder in shared memory to guarantee the most efficient final
      // global write patterns. Note: the parameter `s` shadows the kernel's
      // scale pointer `s`.
      auto write = [&](int idx, float c0, float c1, FragS& s) {
        half2 res = __halves2half2(__float2half(c0), __float2half(c1));
        // for per-column quantization we finally apply the scale here
        if (group_blocks == -1)
          res = __hmul2(res, s[0]);
        ((half2*)sh)[idx] = res;
      };
      if (threadIdx.x / 32 < thread_n_blocks / 4) {
  #pragma unroll
        for (int i = 0; i < thread_m_blocks; i++) {
  #pragma unroll
          for (int j = 0; j < 4; j++) {
            int wr = c_sh_wr + 8 * j;
            write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
                  frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
            write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
                  frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
            write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
                  frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
            write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
                  frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
          }
          c_sh_wr += 16 * (4 * c_sh_stride);
        }
      }
      __syncthreads();
  #pragma unroll
      for (int i = 0;
           i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
           i++) {
        if (c_gl_wr < c_gl_wr_end) {
          C[c_gl_wr] = sh[c_sh_rd];
          c_gl_wr += c_gl_wr_delta;
          c_sh_rd += c_sh_rd_delta;
        }
      }
    };
    // Start global fetch and register load pipelines: prefetch the first
    // `stages - 1` tiles (only those that exist in this slice), reset the
    // accumulators and load the first register sub-tile.
    auto start_pipes = [&]() {
  #pragma unroll
      for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
      zero_accums();
      wait_for_stage();
      fetch_to_registers(0, 0);
      a_gl_rd += a_gl_rd_delta_o * (stages - 1);
    };
    start_pipes();
    // Main loop.
    while (slice_iters) {
      // We unroll over both the global fetch and the register load pipeline to
      // ensure all shared memory accesses are static. Note that both pipelines
      // have even length meaning that the next iteration will always start at
      // index 0.
  #pragma unroll
      for (int pipe = 0; pipe < stages;) {
  #pragma unroll
        for (int k = 0; k < b_sh_wr_iters; k++) {
          fetch_to_registers(k + 1, pipe % stages);
          // Two sub-tiles before the end of this stage: refill the stage that
          // was consumed last and move on to the next stage.
          if (k == b_sh_wr_iters - 2) {
            fetch_to_shared((pipe + stages - 1) % stages, pipe,
                            slice_iters >= stages);
            pipe++;
            wait_for_stage();
          }
          matmul(k);
        }
        slice_iters--;
        if (slice_iters == 0) break;
      }
      // Advance the A read offset by one full pipeline's worth of tiles.
      a_gl_rd += a_gl_rd_delta_o * stages;
      // Process results and, if necessary, proceed to the next column slice.
      // While this pattern may not be the most readable, other ways of writing
      // the loop seemed to result in noticeably worse performance after
      // compilation.
      if (slice_iters == 0) {
        cp_async_wait<0>();
        // `last` depends only on block-uniform values, so the __syncthreads()
        // inside the branches below is reached by all threads or by none.
        bool last = slice_idx == slice_count - 1;
        // For per-column scales, we only fetch them here in the final step
        // before write-out
        if (group_blocks == -1 && last) {
          if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
          cp_async_fence();
        }
        thread_block_reduce();
        if (group_blocks == -1 && last) {
          cp_async_wait<0>();
          __syncthreads();
          if (threadIdx.x / 32 < thread_n_blocks / 4) {
            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
          }
        }
        // Only globally reduce if there is more than one block in a slice;
        // blocks take their turn in slice_idx order via the per-column lock.
        if (slice_count > 1) {
          barrier_acquire(&locks[slice_col], slice_idx);
          global_reduce(slice_idx == 0, last);
          barrier_release(&locks[slice_col], last);
        }
        if (last)  // only the last block in a slice actually writes the result
          write_result();
        // Move to the next column slice (row 0) of this block's stripe.
        slice_row = 0;
        slice_col_par++;
        slice_col++;
        init_slice();
        if (slice_iters) {
          a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                    (threadIdx.x % a_gl_rd_delta_o);
          // Rewind B_ptr from the end of the finished column (presumably
          // k_tiles tiles down) to k-tile 0 and step one slice to the right.
  #pragma unroll
          for (int i = 0; i < b_sh_wr_iters; i++)
            B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
          // After wrapping to a new parallel sub-problem, B is shared, so step
          // back by a full row (n_tiles * b_sh_stride == b_gl_stride) to
          // column 0.
          if (slice_col == 0) {
  #pragma unroll
            for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
          }
          s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
          start_pipes();
        }
      }
    }
  }
  667. #else
  668. template <const int threads, // number of threads in a threadblock
  669. const int thread_m_blocks, // number of 16x16 blocks in the m
  670. // dimension (batchsize) of the
  671. // threadblock
  672. const int thread_n_blocks, // same for n dimension (output)
  673. const int thread_k_blocks, // same for k dimension (reduction)
  674. const int stages, // number of stages for the async global->shared
  675. // fetch pipeline
  676. const int group_blocks = -1 // number of consecutive 16x16 blocks
  677. // with a separate quantization scale
  678. >
  679. __global__ void Marlin(
  680. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  681. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  682. int4* __restrict__ C, // fp16 output buffer of shape mxn
  683. const int4* __restrict__ s, // fp16 quantization scales of shape
  684. // (k/groupsize)xn
  685. int prob_m, // batch dimension m
  686. int prob_n, // output dimension n
  687. int prob_k, // reduction dimension k
  688. int* locks // extra global storage for barrier synchronization
  689. ) {
  690. // Marlin is not implemented yet for SM < 8.0
  691. assert(false);
  692. return;
  693. }
  694. #endif
  695. // 8 warps are a good choice since every SM has 4 schedulers and having more
  696. // than 1 warp per schedule allows some more latency hiding. At the same time,
  697. // we want relatively few warps to have many registers per warp and small tiles.
  698. const int USER_THREADS =
  699. 256; // Note: This is only used with user-provided thread_k/n
  700. const int STAGES = 4; // 4 pipeline stages fit into shared memory
  701. const int SHARED_MEM =
  702. 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
  703. static constexpr int min_thread_n = 64;
  704. static constexpr int min_thread_k = 64;
  705. static constexpr int tile_size = 16;
  706. static constexpr int max_par = 16;
  707. static constexpr int pack_factor_4bit =
  708. 8; // We have 8 4-bit vals inside a 32 bit
  709. #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
  710. GROUP_BLOCKS, NUM_THREADS) \
  711. else if (thread_m_blocks == THREAD_M_BLOCKS && \
  712. thread_n_blocks == THREAD_N_BLOCKS && \
  713. thread_k_blocks == THREAD_K_BLOCKS && \
  714. group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
  715. cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
  716. THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
  717. cudaFuncAttributeMaxDynamicSharedMemorySize, \
  718. SHARED_MEM); \
  719. Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
  720. STAGES, GROUP_BLOCKS><<<blocks, NUM_THREADS, SHARED_MEM, stream>>>( \
  721. A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \
  722. }
  723. typedef struct {
  724. int thread_k;
  725. int thread_n;
  726. int num_threads;
  727. } thread_config_t;
  728. thread_config_t small_batch_thread_configs[] = {
  729. // Ordered by priority
  730. // thread_k, thread_n, num_threads
  731. {128, 128, 256}, // Default
  732. {128, 64, 128}, // Reduce N 2X, same K
  733. {64, 256, 256}, // Reduce K 2X, increase N 2X
  734. {64, 128, 128}, // Reduce K 2X, same N
  735. };
  736. thread_config_t large_batch_thread_configs[] = {
  737. // Ordered by priority
  738. // thread_k, thread_n, num_threads
  739. {64, 256, 256}, // Default
  740. {128, 128, 256}, // Reduce N 2X, increase K 2X
  741. {64, 128, 128}, // Reduce N 2X, same K
  742. {128, 64, 128}, // Reduce N 4X, increase K 2X
  743. };
  744. bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
  745. int prob_k) {
  746. // Sanity
  747. if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
  748. th_config.num_threads == -1) {
  749. return false;
  750. }
  751. // Verify K/N are divisible by thread K/N
  752. if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
  753. return false;
  754. }
  755. // thread_k can be only 128 or 64 (because it must be less than groupsize
  756. // which is 128)
  757. if (th_config.thread_k != 128 && th_config.thread_k != 64) {
  758. return false;
  759. }
  760. // Verify min for thread K/N
  761. if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
  762. return false;
  763. }
  764. // num_threads must be at least 128 (= 4 warps)
  765. if (th_config.num_threads < 128) {
  766. return false;
  767. }
  768. return true;
  769. }
  770. thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
  771. if (prob_m <= 16) {
  772. for (auto th_config : small_batch_thread_configs) {
  773. if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
  774. return th_config;
  775. }
  776. }
  777. } else {
  778. for (auto th_config : large_batch_thread_configs) {
  779. if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
  780. return th_config;
  781. }
  782. }
  783. }
  784. return thread_config_t{-1, -1, -1};
  785. }
  786. #define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
  787. __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  788. __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
  789. __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  790. __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
  791. __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  792. __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
  793. __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  794. __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
  795. __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  796. __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
  797. void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
  798. int prob_n, int prob_k, void* workspace, int groupsize = -1,
  799. int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
  800. int thread_n = -1, int sms = -1, int max_par = 16) {
  801. int tot_m = prob_m;
  802. int tot_m_blocks = ceildiv(tot_m, 16);
  803. int pad = 16 * tot_m_blocks - tot_m;
  804. if (sms == -1)
  805. cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  806. // Set thread config
  807. thread_config_t th_config;
  808. if (thread_k != -1 && thread_n != -1) {
  809. // User-defined config
  810. th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
  811. } else {
  812. // Auto config
  813. th_config = determine_thread_config(prob_m, prob_n, prob_k);
  814. }
  815. if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
  816. throw std::runtime_error(
  817. "Invalid thread config: thread_k = " + str(th_config.thread_k) +
  818. ", thread_n = " + str(th_config.thread_n) +
  819. ", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
  820. str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
  821. }
  822. // Uncomment for debug
  823. // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) +
  824. // ", thread_n = " + str(th_config.thread_n) +
  825. // ", num_threads = " + str(th_config.num_threads) + " for
  826. // MKN = [" + str(prob_m) +
  827. // ", " + str(prob_k) + ", " + str(prob_n) + "]\n";
  828. int num_threads = th_config.num_threads;
  829. thread_k = th_config.thread_k;
  830. thread_n = th_config.thread_n;
  831. int thread_k_blocks = thread_k / 16;
  832. int thread_n_blocks = thread_n / 16;
  833. int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
  834. int blocks = sms;
  835. if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
  836. return;
  837. }
  838. TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
  839. " is not divisible by thread_n = ", thread_n);
  840. TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
  841. " is not divisible by thread_k = ", thread_k);
  842. if (group_blocks != -1) {
  843. TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
  844. " is not divisible by group_blocks = ", group_blocks);
  845. }
  846. const int4* A_ptr = (const int4*)A;
  847. const int4* B_ptr = (const int4*)B;
  848. int4* C_ptr = (int4*)C;
  849. const int4* s_ptr = (const int4*)s;
  850. int* locks = (int*)workspace;
  851. for (int i = 0; i < tot_m_blocks; i += 4) {
  852. int thread_m_blocks = tot_m_blocks - i;
  853. prob_m = tot_m - 16 * i;
  854. int par = 1;
  855. if (thread_m_blocks > 4) {
  856. // Note that parallel > 1 currently only works for inputs without any
  857. // padding
  858. par = (16 * thread_m_blocks - pad) / 64;
  859. if (par > max_par) par = max_par;
  860. prob_m = 64 * par;
  861. i += 4 * (par - 1);
  862. thread_m_blocks = 4;
  863. }
  864. // For compilation speed, we only define the kernel configurations that have
  865. // seemed useful (in terms of performance) in our testing, however many more
  866. // are, in principle, possible.
  867. if (false) {
  868. }
  869. CALL_IF(8, 8, 256)
  870. CALL_IF(16, 4, 256)
  871. CALL_IF(8, 4, 128)
  872. CALL_IF(4, 8, 128)
  873. else {
  874. throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
  875. ", " + str(prob_k) + ", " + str(prob_n) + "]" +
  876. ", groupsize = " + str(groupsize) +
  877. ", thread_m_blocks = " + str(thread_m_blocks) +
  878. ", thread_n_blocks = " + str(thread_n_blocks) +
  879. ", thread_k_blocks = " + str(thread_k_blocks));
  880. }
  881. A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
  882. C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  883. }
  884. }
  885. } // namespace marlin_dense
  886. torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  887. torch::Tensor& b_scales, torch::Tensor& workspace,
  888. int64_t size_m, int64_t size_n, int64_t size_k) {
  889. // Verify M
  890. TORCH_CHECK(size_m == a.size(0),
  891. "Shape mismatch: a.size(0) = " + str(a.size(0)) +
  892. ", size_m = " + str(size_m));
  893. // Verify K
  894. TORCH_CHECK(size_k == a.size(1),
  895. "Shape mismatch: a.size(1) = " + str(a.size(1)) +
  896. ", size_k = " + str(size_k));
  897. TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
  898. "size_k = " + str(size_k) + " is not divisible by tile_size = " +
  899. str(marlin_dense::tile_size));
  900. TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
  901. "Shape mismatch: b_q_weight.size(0) = " +
  902. str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
  903. ", tile_size = " + str(marlin_dense::tile_size));
  904. // Verify N
  905. TORCH_CHECK(b_scales.size(1) == size_n,
  906. "b_scales.size(1) = " + str(b_scales.size(1)) +
  907. ", size_n = " + str(size_n));
  908. TORCH_CHECK(
  909. b_q_weight.size(1) % marlin_dense::tile_size == 0,
  910. "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
  911. " is not divisible by tile_size = " + str(marlin_dense::tile_size));
  912. int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
  913. marlin_dense::pack_factor_4bit;
  914. TORCH_CHECK(
  915. size_n == actual_size_n,
  916. "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
  917. // Verify A device and strides
  918. TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  919. TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
  920. // Verify B device and strides
  921. TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  922. TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  923. // Verify scales device and strides
  924. TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  925. TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
  926. // Alloc C matrix
  927. const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  928. auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  929. torch::Tensor c = torch::empty({size_m, size_n}, options);
  930. // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  931. // auto -1)
  932. int thread_k = -1;
  933. // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  934. // auto -1)
  935. int thread_n = -1;
  936. // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  937. int sms = -1;
  938. // Detect groupsize
  939. if (b_scales.size(0) != 1) {
  940. TORCH_CHECK(size_k % b_scales.size(0) == 0,
  941. "size_k = " + str(size_k) +
  942. ", is not divisible by b_scales.size(0) = " +
  943. str(b_scales.size(0)));
  944. }
  945. int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0);
  946. // Verify groupsize
  947. TORCH_CHECK(groupsize == -1 || groupsize == 128,
  948. "Unexpected groupsize = " + str(groupsize));
  949. // Verify workspace size
  950. TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
  951. "size_n = " + str(size_n) +
  952. ", is not divisible by min_thread_n = " +
  953. str(marlin_dense::min_thread_n));
  954. int min_workspace_size =
  955. (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
  956. TORCH_CHECK(workspace.numel() >= min_workspace_size,
  957. "workspace.numel = " + str(workspace.numel()) +
  958. " is below min_workspace_size = " + str(min_workspace_size));
  959. int dev = a.get_device();
  960. marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
  961. b_scales.data_ptr(), size_m, size_n, size_k,
  962. workspace.data_ptr(), groupsize, dev,
  963. at::cuda::getCurrentCUDAStream(dev), thread_k,
  964. thread_n, sms, marlin_dense::max_par);
  965. return c;
  966. }