marlin_cuda_kernel.cu 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871
  1. /*
  2. * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MARLIN_CUDA_KERNEL_CUH
  17. #define MARLIN_CUDA_KERNEL_CUH
  18. #include <torch/extension.h>
  19. #include <c10/cuda/CUDAStream.h>
  20. #include <cuda.h>
  21. #include <cuda_fp16.h>
  22. #include <cuda_runtime.h>
  23. namespace aphrodite {
  24. namespace marlin {
  // Integer ceiling division: the smallest integer >= a / b. Only correct for a >= 0 and b > 0 (for negative `a` the
  // truncating division rounds the wrong way), and `a + b - 1` must not overflow int.
  // Marked __host__ __device__ because the Marlin kernel calls it from device code with runtime arguments
  // (e.g. gridDim.x). A plain constexpr host function only allows that when nvcc is given --expt-relaxed-constexpr.
  __host__ __device__ constexpr int ceildiv(int a, int b) {
    return (a + b - 1) / b;
  }
  // Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core
  // operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we
  // extensively use `#pragma unroll` throughout the kernel code to guarantee this.
  // `Vec` holds exactly `n` elements of type `T` and nothing else, so its layout is that of `T[n]`. The kernel relies on
  // this when it reinterpret_casts fragments to raw uint32_t / float / int4 views.
  template <typename T, int n>
  struct Vec {
    T elems[n];
    // Mutable element access.
    __device__ T& operator[](int i) {
      return elems[i];
    }
    // Read-only element access, so `const Vec` objects (e.g. a `const FragA&` parameter) can be indexed as well.
    __device__ const T& operator[](int i) const {
      return elems[i];
    }
  };
  // Four 32-bit ints; the kernel uses it to hold 16 bytes of packed quantized B weights per thread.
  typedef Vec<int, 4> I4;
  // Matrix fragments for tensor core instructions; their precise layout is documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  typedef Vec<half2, 4> FragA;  // per-thread share of operand A: 8 fp16 values
  typedef Vec<half2, 2> FragB;  // per-thread share of operand B: 4 fp16 values
  typedef Vec<float, 4> FragC;  // per-thread share of the fp32 accumulator: 4 values
  typedef Vec<half2, 1> FragS;  // quantization scales: 2 fp16 values
  // Predicated asynchronous 16-byte copy from global to shared memory (cp.async.cg, i.e. cached in L2 but not L1).
  // When `pred` is false the copy instruction is skipped, so the shared destination keeps its previous contents.
  // The kernel uses this for tiles of the inputs A, whose row count (the batch size) need not be a multiple of 16, and
  // for partially valid output tiles in the global reduction. Compiles to nothing below sm_80.
  __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    constexpr int BYTES = 16;
    // cp.async addresses shared memory with a 32-bit shared-window address, not a generic pointer.
    const uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
        "{\n"
        " .reg .pred p;\n"
        " setp.ne.b32 p, %0, 0;\n"
        " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
        "}\n" :: "r"(static_cast<int>(pred)), "r"(smem_addr), "l"(glob_ptr), "n"(BYTES)
    );
  #endif
  }
  // Asynchronous 16-byte global->shared copy (cp.async.cg) with an L2 cache hint. The hint is an `evict_first` policy
  // created with fraction 1.0, i.e. applied to the whole access, marking the fetched bytes as the first candidates for
  // L2 eviction. It is intended for the quantized weights B, which are read only once and should not pollute the L2
  // cache needed for the inputs A and outputs C. The kernel also uses it to fetch the scales s.
  // Compiles to nothing below sm_80.
  __device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    constexpr int BYTES = 16;
    const uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
        "{\n"
        " .reg .b64 p;\n"
        " createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
        " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
        "}\n" :: "r"(smem_addr), "l"(glob_ptr), "n"(BYTES)
    );
  #endif
  }
  // Commits every cp.async copy this thread has issued since the last commit into one group, which cp_async_wait can
  // later wait on (called a "fence" in this file). Compiles to nothing below sm_80.
  __device__ inline void cp_async_fence() {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.commit_group;\n" ::);
  #endif
  }
  // Makes the calling thread wait until no more than `max_pending` of its most recently committed cp.async groups are
  // still in flight. The count is encoded as an immediate operand, so it has to be a template (compile-time) value.
  // This only covers the thread's own copies; the kernel follows it with __syncthreads() before reading data copied by
  // other threads. Compiles to nothing below sm_80.
  template <int max_pending>
  __device__ inline void cp_async_wait() {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.wait_group %0;\n" :: "n"(max_pending));
  #endif
  }
  // One m16n8k16 tensor core MMA: frag_c += A * B, with fp16 A (row-major) and B (column-major) and fp32
  // accumulation. `frag_c` is read as the input accumulator and overwritten with the result. The register layouts are
  // the PTX "matrix fragments for mma.m16n8k16" layouts. Compiles to nothing below sm_80.
  __device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // The instruction takes A/B as packed 32-bit registers (two fp16 each) and C/D as plain floats.
    const uint32_t* regs_a = reinterpret_cast<const uint32_t*>(&a_frag);
    const uint32_t* regs_b = reinterpret_cast<const uint32_t*>(&frag_b);
    float* acc = reinterpret_cast<float*>(&frag_c);
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
        : "r"(regs_a[0]), "r"(regs_a[1]), "r"(regs_a[2]), "r"(regs_a[3]), "r"(regs_b[0]), "r"(regs_b[1]),
          "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3])
    );
  #endif
  }
  // Loads a whole 16x16 fragment of operand A from shared memory directly into tensor core register layout, using one
  // warp-wide ldmatrix.x4 (four 8x8 matrices of 16-bit elements). `smem_ptr` is a generic pointer to the shared-memory
  // row this thread supplies, per the PTX ldmatrix addressing contract. Compiles to nothing below sm_80.
  __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    const uint32_t row_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    uint32_t* regs = reinterpret_cast<uint32_t*>(&frag_a);
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
        : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3]) : "r"(row_addr)
    );
  #endif
  }
  // Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to
  // automatically recognize it in all cases.
  //
  // `lut` follows the PTX lop3.b32 convention: it is the desired function evaluated on the constants a = 0xF0,
  // b = 0xCC, c = 0xAA. For example, `(0xF0 & 0xCC) | 0xAA` selects (a & b) | c. Equivalently, bit k of `lut` is the
  // output bit for the input bits (a, b, c) = ((k >> 2) & 1, (k >> 1) & 1, k & 1).
  //
  // In device code for sm_80+ this compiles to a single lop3.b32 instruction. Every other compilation path (host code,
  // older architectures) evaluates the same table with plain bitwise operations. The function therefore always returns
  // a defined value, identical on every path, instead of falling off the end of a non-void function.
  template <int lut>
  __host__ __device__ inline int lop3(int a, int b, int c) {
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
    int res;
    asm volatile(
        "lop3.b32 %0, %1, %2, %3, %4;\n"
        : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)
    );
    return res;
  #else
    // OR together the minterm for every input combination k whose table bit is set.
    int res = 0;
    for (int k = 0; k < 8; k++) {
      if ((lut >> k) & 1)
        res |= ((k & 4) ? a : ~a) & ((k & 2) ? b : ~b) & ((k & 1) ? c : ~c);
    }
    return res;
  #endif
  }
  // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values.
  // We mostly follow the strategy in the link below, with some small changes:
  // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
  //
  // Of the 4-bit fields packed in `q`, this converts the ones at bits [3:0] and [19:16] (-> frag_b[0], low/high half)
  // and at bits [7:4] and [23:20] (-> frag_b[1]). The kernel's matmul calls it a second time with `q >> 8` to cover the
  // remaining fields. Each field v in [0, 15] becomes the signed fp16 value v - 8 (symmetric zero point).
  //
  // The trick: OR-ing a field into the mantissa of fp16 1024.0 (0x6400) gives 1024 + v for the low nibble and
  // 1024 + 16 * v for the high nibble of each byte. The low case subtracts 1032.0 (0x6408). The high case uses one
  // fused multiply-add with 1/16 (0x2c00) and -72.0 (0xd480): (1024 + 16 * v) / 16 - 72 = v - 8.
  __device__ inline FragB dequant(int q) {
    // Value-initialized (zeros) and returned outside the arch guard so that every compilation path returns a defined
    // value. On sm_80+ both elements are overwritten below, matching the guard on the Marlin kernel body.
    FragB frag_b = {};
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
    const int LO = 0x000f000f;
    const int HI = 0x00f000f0;
    const int EX = 0x64006400;
    // Guarantee that the `(a & b) | c` operations are LOP3s.
    int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
    int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
    // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`.
    const int SUB = 0x64086408;
    const int MUL = 0x2c002c00;
    const int ADD = 0xd480d480;
    frag_b[0] = __hsub2(
      *reinterpret_cast<half2*>(&lo),
      *reinterpret_cast<const half2*>(&SUB)
    );
    frag_b[1] = __hfma2(
      *reinterpret_cast<half2*>(&hi),
      *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD)
    );
  #endif
    return frag_b;
  }
  // Multiplies all 4 dequantized values in `frag_b` by the i-th fp16 scale stored in `frag_s`, broadcast to both lanes
  // of a half2. The kernel only calls this for grouped quantization. Compiles to nothing below sm_80.
  __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    const __half chosen = reinterpret_cast<__half*>(&frag_s)[i];
    const half2 factor = __half2half2(chosen);
    #pragma unroll
    for (int r = 0; r < 2; r++)
      frag_b[r] = __hmul2(frag_b[r], factor);
  #endif
  }
  // Waits until the global counter `*lock` equals `count`, then lets the whole threadblock continue. Only thread 0
  // polls, using GPU-scope acquire loads; the block then meets at __syncthreads(). The acquire load is meant to pair
  // with the `fence.acq_rel.gpu` + atomic add in barrier_release, so that global writes another block made before
  // releasing are ordered before this block's subsequent accesses. Compiles to nothing below sm_80.
  __device__ inline void barrier_acquire(int* lock, int count) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    if (threadIdx.x == 0) {
      int observed = -1;
      for (;;) {
        asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(observed) : "l"(lock));
        if (observed == count)
          break;
      }
    }
    __syncthreads();
  #endif
  }
  // Releases the global counter `*lock` after every thread of the block has reached __syncthreads(); only thread 0
  // then touches the counter.
  //  - reset == true:  stores 0 (plain store, no fence). The Marlin kernel passes `last` here, so the final block of a
  //                    slice leaves the counter at 0.
  //  - reset == false: issues a GPU-scope acq_rel fence, then a relaxed atomic add of 1. This makes the block's
  //                    earlier global writes visible to the next block that acquires the counter.
  // Compiles to nothing below sm_80.
  __device__ inline void barrier_release(int* lock, bool reset = false) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    __syncthreads();
    if (threadIdx.x != 0)
      return;
    if (reset) {
      lock[0] = 0;
      return;
    }
    int increment = 1;
    asm volatile ("fence.acq_rel.gpu;\n");
    asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(increment));
  #endif
  }
  192. template <
  193. const int threads, // number of threads in a threadblock
  194. const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock
  195. const int thread_n_blocks, // same for n dimension (output)
  196. const int thread_k_blocks, // same for k dimension (reduction)
  197. const int stages, // number of stages for the async global->shared fetch pipeline
  198. const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale
  199. >
  200. __global__ void Marlin(
  201. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  202. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  203. int4* __restrict__ C, // fp16 output buffer of shape mxn
  204. const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
  205. int prob_m, // batch dimension m
  206. int prob_n, // output dimension n
  207. int prob_k, // reduction dimension k
  208. int* locks // extra global storage for barrier synchronization
  209. ) {
  210. #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
  211. // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple
  212. // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example:
  213. // 0 1 3
  214. // 0 2 3
  215. // 1 2 4
  216. // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs
  217. // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as
  218. // possible.
  219. int k_tiles = prob_k / 16 / thread_k_blocks;
  220. int n_tiles = prob_n / 16 / thread_n_blocks;
  221. int iters = ceildiv(k_tiles * n_tiles, gridDim.x);
  222. // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case
  223. // where a stripe starts in the middle of group.
  224. if (group_blocks != -1)
  225. iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks));
  226. int slice_row = (iters * blockIdx.x) % k_tiles;
  227. int slice_col = (iters * blockIdx.x) / k_tiles;
  228. int slice_iters; // number of threadblock tiles in the current slice
  229. int slice_count = 0; // total number of active threadblocks in the current slice
  230. int slice_idx; // index of threadblock in current slice; numbered bottom to top
  231. // Compute all information about the current slice which is required for synchronization.
  232. auto init_slice = [&] () {
  233. slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col + slice_row);
  234. if (slice_iters < 0 || slice_col >= n_tiles)
  235. slice_iters = 0;
  236. if (slice_iters == 0)
  237. return;
  238. if (slice_row + slice_iters > k_tiles)
  239. slice_iters = k_tiles - slice_row;
  240. slice_count = 1;
  241. slice_idx = 0;
  242. int col_first = iters * ceildiv(k_tiles * slice_col, iters);
  243. if (col_first <= k_tiles * (slice_col + 1)) {
  244. int col_off = col_first - k_tiles * slice_col;
  245. slice_count = ceildiv(k_tiles - col_off, iters);
  246. if (col_off > 0)
  247. slice_count++;
  248. int delta_first = iters * blockIdx.x - col_first;
  249. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  250. slice_idx = slice_count - 1;
  251. else {
  252. slice_idx = slice_count - 1 - delta_first / iters;
  253. if (col_off > 0)
  254. slice_idx--;
  255. }
  256. }
  257. };
  258. init_slice();
  259. int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
  260. // We typically use `constexpr` to indicate that this value is a compile-time constant
  261. constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
  262. constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory
  263. int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
  264. constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
  265. constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads
  266. constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile
  267. constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
  268. constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile
  269. int b_gl_stride = 16 * prob_n / 32;
  270. constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
  271. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  272. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
  273. constexpr int b_sh_wr_delta = threads;
  274. constexpr int b_sh_rd_delta = threads;
  275. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  276. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
  277. int s_gl_stride = prob_n / 8;
  278. constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  279. constexpr int s_sh_stage = s_sh_stride;
  280. int s_gl_rd_delta = s_gl_stride;
  281. // Global A read index of current thread.
  282. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
  283. a_gl_rd += a_gl_rd_delta_o * slice_row;
  284. // Shared write index of current thread.
  285. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
  286. // Shared read index.
  287. int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  288. a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  289. int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  290. b_gl_rd += b_sh_stride * slice_col;
  291. b_gl_rd += b_gl_rd_delta_o * slice_row;
  292. int b_sh_wr = threadIdx.x;
  293. int b_sh_rd = threadIdx.x;
  294. int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
  295. int s_sh_wr = threadIdx.x;
  296. int s_sh_rd;
  297. // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major
  298. // layout in the former and in row-major in the latter case.
  299. if (group_blocks != -1)
  300. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
  301. else
  302. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
  303. // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than
  304. // required for a certain tilesize or when the batchsize is not a multiple of 16.
  305. bool a_sh_wr_pred[a_sh_wr_iters];
  306. #pragma unroll
  307. for (int i = 0; i < a_sh_wr_iters; i++)
  308. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  309. bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
  310. // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank
  311. // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of
  312. // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based
  313. // on NSight-Compute) that each warp must also write a consecutive memory segment?
  314. auto transform_a = [&] (int i) {
  315. int row = i / a_gl_rd_delta_o;
  316. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  317. };
  318. // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory
  319. // accesses are static, we simply precompute both transformed reads and writes.
  320. int a_sh_wr_trans[a_sh_wr_iters];
  321. #pragma unroll
  322. for (int i = 0; i < a_sh_wr_iters; i++)
  323. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  324. int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  325. #pragma unroll
  326. for (int i = 0; i < b_sh_wr_iters; i++) {
  327. #pragma unroll
  328. for (int j = 0; j < thread_m_blocks; j++)
  329. a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  330. }
  331. // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between
  332. // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization.
  333. const int4* B_ptr[b_sh_wr_iters];
  334. #pragma unroll
  335. for (int i = 0; i < b_sh_wr_iters; i++)
  336. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  337. extern __shared__ int4 sh[];
  338. // Shared memory storage for global fetch pipelines.
  339. int4* sh_a = sh;
  340. int4* sh_b = sh_a + (stages * a_sh_stage);
  341. int4* sh_s = sh_b + (stages * b_sh_stage);
  342. // Register storage for double buffer of shared memory reads.
  343. FragA frag_a[2][thread_m_blocks];
  344. I4 frag_b_quant[2];
  345. FragC frag_c[thread_m_blocks][4][2];
  346. FragS frag_s[2][4];
  347. // Zero accumulators.
  348. auto zero_accums = [&] () {
  349. #pragma unroll
  350. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  351. reinterpret_cast<float*>(frag_c)[i] = 0;
  352. };
  353. // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location.
  354. auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) {
  355. if (pred) {
  356. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  357. #pragma unroll
  358. for (int i = 0; i < a_sh_wr_iters; i++) {
  359. cp_async4_pred(
  360. &sh_a_stage[a_sh_wr_trans[i]],
  361. &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
  362. a_sh_wr_pred[i]
  363. );
  364. }
  365. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  366. #pragma unroll
  367. for (int i = 0; i < b_sh_wr_iters; i++) {
  368. cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
  369. B_ptr[i] += b_gl_rd_delta_o;
  370. }
  371. // Only fetch scales if this tile starts a new group
  372. if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
  373. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  374. if (s_sh_wr_pred)
  375. cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
  376. s_gl_rd += s_gl_rd_delta;
  377. }
  378. }
  379. // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point.
  380. cp_async_fence();
  381. };
  382. // Wait until the next thread tile has been loaded to shared memory.
  383. auto wait_for_stage = [&] () {
  384. // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when
  385. // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten).
  386. cp_async_wait<stages - 2>();
  387. __syncthreads();
  388. };
  389. // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer.
  390. auto fetch_to_registers = [&] (int k, int pipe) {
  391. // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a
  392. // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the
  393. // compiler and correspondingly a noticable drop in performance.
  394. if (group_blocks != -1) {
  395. int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
  396. reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
  397. }
  398. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  399. #pragma unroll
  400. for (int i = 0; i < thread_m_blocks; i++)
  401. ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
  402. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  403. frag_b_quant[k % 2] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
  404. };
  405. // Execute the actual tensor core matmul of a sub-tile.
  406. auto matmul = [&] (int k) {
  407. // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations.
  408. #pragma unroll
  409. for (int j = 0; j < 4; j++) {
  410. int b_quant = frag_b_quant[k % 2][j];
  411. int b_quant_shift = b_quant >> 8;
  412. FragB frag_b0 = dequant(b_quant);
  413. // If there are no groups, we can just scale the final output once and can avoid doing so for each weight.
  414. if (group_blocks != -1)
  415. scale(frag_b0, frag_s[k % 2][j], 0);
  416. FragB frag_b1 = dequant(b_quant_shift);
  417. if (group_blocks != -1)
  418. scale(frag_b1, frag_s[k % 2][j], 1);
  419. #pragma unroll
  420. for (int i = 0; i < thread_m_blocks; i++) {
  421. mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
  422. mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
  423. }
  424. }
  425. };
  426. // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n
  427. // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output
  428. // location; which we have to reduce over in the end. We do in shared memory.
  429. auto thread_block_reduce = [&] () {
  430. constexpr int red_off = threads / b_sh_stride / 2;
  431. if (red_off >= 1) {
  432. int red_idx = threadIdx.x / b_sh_stride;
  433. constexpr int red_sh_stride = b_sh_stride * 4 * 2;
  434. constexpr int red_sh_delta = b_sh_stride;
  435. int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  436. // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations,
  437. // e.g., for two warps we write only once by warp 1 and read only once by warp 0.
  438. #pragma unroll
  439. for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  440. #pragma unroll
  441. for (int i = red_off; i > 0; i /= 2) {
  442. if (i <= red_idx && red_idx < 2 * i) {
  443. #pragma unroll
  444. for (int j = 0; j < 4 * 2; j++) {
  445. int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
  446. if (i < red_off) {
  447. float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
  448. float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  449. #pragma unroll
  450. for (int k = 0; k < 4; k++)
  451. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
  452. }
  453. sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
  454. }
  455. }
  456. __syncthreads();
  457. }
  458. if (red_idx == 0) {
  459. #pragma unroll
  460. for (int i = 0; i < 4 * 2; i++) {
  461. float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  462. #pragma unroll
  463. for (int j = 0; j < 4; j++)
  464. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
  465. }
  466. }
  467. __syncthreads();
  468. }
  469. }
  470. };
  471. // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over
  472. // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather
  473. // small, we perform this reduction serially in L2 cache.
  474. auto global_reduce = [&] (bool first = false, bool last = false) {
  475. // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.
  476. // To do this, we write out results in FP16 (but still reduce with FP32 compute).
  477. constexpr int active_threads = 32 * thread_n_blocks / 4;
  478. if (threadIdx.x < active_threads) {
  479. int c_gl_stride = prob_n / 8;
  480. int c_gl_wr_delta_o = 8 * c_gl_stride;
  481. int c_gl_wr_delta_i = 4 * (active_threads / 32);
  482. int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;
  483. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  484. constexpr int c_sh_wr_delta = active_threads;
  485. int c_sh_wr = threadIdx.x;
  486. int row = (threadIdx.x % 32) / 4;
  487. if (!first) {
  488. // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,
  489. // hence we also use async-copies even though these fetches are not actually asynchronous.
  490. #pragma unroll
  491. for (int i = 0; i < thread_m_blocks * 4; i++) {
  492. cp_async4_pred(
  493. &sh[c_sh_wr + c_sh_wr_delta * i],
  494. &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
  495. i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
  496. );
  497. }
  498. cp_async_fence();
  499. cp_async_wait<0>();
  500. }
  501. #pragma unroll
  502. for (int i = 0; i < thread_m_blocks * 4; i++) {
  503. if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
  504. if (!first) {
  505. int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  506. #pragma unroll
  507. for (int j = 0; j < 2 * 4; j++) {
  508. reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
  509. reinterpret_cast<__half*>(&c_red)[j]
  510. );
  511. }
  512. }
  513. if (!last) {
  514. int4 c;
  515. #pragma unroll
  516. for (int j = 0; j < 2 * 4; j++) {
  517. reinterpret_cast<__half*>(&c)[j] = __float2half(
  518. reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]
  519. );
  520. }
  521. C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;
  522. }
  523. }
  524. }
  525. }
  526. };
  527. // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step,
  528. // the reduction above is performed in fragment layout.
  529. auto write_result = [&] () {
  530. int c_gl_stride = prob_n / 8;
  531. constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
  532. int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
  533. constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));
  534. int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
  535. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  536. int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
  537. c_sh_wr += 32 * (threadIdx.x / 32);
  538. int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
  539. int c_gl_wr_end = c_gl_stride * prob_m;
  540. // We first reorder in shared memory to guarantee the most efficient final global write patterns
  541. auto write = [&] (int idx, float c0, float c1, FragS& s) {
  542. half2 res = __halves2half2(__float2half(c0), __float2half(c1));
  543. if (group_blocks == -1) // for per-column quantization we finally apply the scale here
  544. res = __hmul2(res, s[0]);
  545. ((half2*) sh)[idx] = res;
  546. };
  547. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  548. #pragma unroll
  549. for (int i = 0; i < thread_m_blocks; i++) {
  550. #pragma unroll
  551. for (int j = 0; j < 4; j++) {
  552. int wr = c_sh_wr + 8 * j;
  553. write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
  554. write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
  555. write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
  556. write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
  557. }
  558. c_sh_wr += 16 * (4 * c_sh_stride);
  559. }
  560. }
  561. __syncthreads();
  562. #pragma unroll
  563. for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
  564. if (c_gl_wr < c_gl_wr_end) {
  565. C[c_gl_wr] = sh[c_sh_rd];
  566. c_gl_wr += c_gl_wr_delta;
  567. c_sh_rd += c_sh_rd_delta;
  568. }
  569. }
  570. };
  571. // Start global fetch and register load pipelines.
  572. auto start_pipes = [&] () {
  573. #pragma unroll
  574. for (int i = 0; i < stages - 1; i++)
  575. fetch_to_shared(i, i, i < slice_iters);
  576. zero_accums();
  577. wait_for_stage();
  578. fetch_to_registers(0, 0);
  579. a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  580. };
  581. start_pipes();
  582. // Main loop.
  583. while (slice_iters) {
  584. // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are
  585. // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0.
  586. #pragma unroll
  587. for (int pipe = 0; pipe < stages;) {
  588. #pragma unroll
  589. for (int k = 0; k < b_sh_wr_iters; k++) {
  590. fetch_to_registers(k + 1, pipe % stages);
  591. if (k == b_sh_wr_iters - 2) {
  592. fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
  593. pipe++;
  594. wait_for_stage();
  595. }
  596. matmul(k);
  597. }
  598. slice_iters--;
  599. if (slice_iters == 0)
  600. break;
  601. }
  602. a_gl_rd += a_gl_rd_delta_o * stages;
  603. // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most
  604. // readable, other ways of writing the loop seemed to noticeably worse performance after compliation.
  605. if (slice_iters == 0) {
  606. cp_async_wait<0>();
  607. bool last = slice_idx == slice_count - 1;
  608. // For per-column scales, we only fetch them here in the final step before write-out
  609. if (group_blocks == -1 && last) {
  610. if (s_sh_wr_pred)
  611. cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
  612. cp_async_fence();
  613. }
  614. thread_block_reduce();
  615. if (group_blocks == -1 && last) {
  616. cp_async_wait<0>();
  617. __syncthreads();
  618. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  619. reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
  620. reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
  621. }
  622. }
  623. if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
  624. barrier_acquire(&locks[slice_col], slice_idx);
  625. global_reduce(slice_idx == 0, last);
  626. barrier_release(&locks[slice_col], last);
  627. }
  628. if (last) // only the last block in a slice actually writes the result
  629. write_result();
  630. slice_row = 0;
  631. slice_col++;
  632. init_slice();
  633. if (slice_iters) {
  634. a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
  635. #pragma unroll
  636. for (int i = 0; i < b_sh_wr_iters; i++)
  637. B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
  638. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  639. start_pipes();
  640. }
  641. }
  642. }
  643. #endif
  644. }
  // Launch configuration shared by every Marlin instantiation dispatched through CALL_IF.
  //
  // 8 warps per block: every SM has 4 warp schedulers, and more than one warp per scheduler allows some extra
  // latency hiding. At the same time we want relatively few warps, so that each warp keeps many registers and
  // tiles stay small.
  constexpr int THREADS = 256;
  // Number of global->shared pipeline stages; 4 stages fit into the shared-memory budget below.
  constexpr int STAGES = 4;
  // Dynamic shared memory per block in bytes (96 KiB): the maximum on compute capability 8.6, which is
  // smaller than what compute capability 8.0 offers.
  constexpr int SHARED_MEM = 96 * 1024;
  // One `else if` arm of the kernel dispatch chain in marlin_cuda. When the runtime tile configuration
  // (thread_m_blocks, thread_n_blocks, thread_k_blocks, group_blocks) equals the compile-time one given here,
  // it raises that Marlin instantiation's dynamic shared-memory limit to SHARED_MEM and launches it with
  // `blocks` x THREADS threads on `stream`.
  // Must directly follow an `if` / `else if`. It reads the caller's locals thread_m_blocks, thread_n_blocks,
  // thread_k_blocks, group_blocks, blocks, stream, A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k and locks.
  #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
    else if ( \
      group_blocks == GROUP_BLOCKS && \
      thread_k_blocks == THREAD_K_BLOCKS && \
      thread_n_blocks == THREAD_N_BLOCKS && \
      thread_m_blocks == THREAD_M_BLOCKS \
    ) { \
      auto* kernel_fn = &Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>; \
      cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, SHARED_MEM); \
      kernel_fn<<<blocks, THREADS, SHARED_MEM, stream>>>( \
        A_ptr, B_ptr, C_ptr, s_ptr, \
        prob_m, prob_n, prob_k, \
        locks \
      ); \
    }
  // Status codes returned by marlin_cuda (0 means success).
  // The problem's n, k or group size does not divide into the selected tile sizes.
  constexpr int ERR_PROB_SHAPE = 1;
  // No kernel instantiation exists for the requested tile / group configuration.
  constexpr int ERR_KERN_SHAPE = 2;
  // Host-side dispatcher for the Marlin kernels.
  //
  // Splits the `prob_m` rows of A into chunks of at most 4 * 16 = 64 rows. For each chunk it launches, on
  // `stream`, the Marlin instantiation matching that chunk's (m, n, k, group) tile configuration.
  //
  // A, B, C, s: device pointers to the input, packed weights, output and scales. They are reinterpreted as
  //             `int4` (16-byte) vectors, so they are assumed to be 16-byte aligned.
  // workspace:  int lock buffer used by the cross-block reduction (documented by the caller as all zeros).
  // groupsize:  scale group size along k, or -1 for per-column scales.
  // dev:        device whose SM count is queried when `sms == -1`.
  // thread_k, thread_n: tile sizes; -1 for either one selects defaults based on `prob_m`.
  // sms:        number of thread blocks per launch; -1 uses the SM count of `dev`.
  //
  // Returns:
  //   - 0 on success, including empty problems;
  //   - ERR_PROB_SHAPE if n, k or the group size do not divide into the chosen tiles;
  //   - ERR_KERN_SHAPE as soon as a chunk has no instantiated kernel (later chunks are then not launched).
  // Launches are asynchronous, and launch errors are not checked here.
  //
  // This is host code, so it must not be guarded by `__CUDA_ARCH__`. That macro is undefined in nvcc's host
  // compilation pass, so such a guard compiles the whole body away (no launches and no return value).
  int marlin_cuda(
    const void* A,
    const void* B,
    void* C,
    void* s,
    int prob_m,
    int prob_n,
    int prob_k,
    void* workspace,
    int groupsize = -1,
    int dev = 0,
    cudaStream_t stream = 0,
    int thread_k = -1,
    int thread_n = -1,
    int sms = -1
  ) {
    int tot_m = prob_m;
    int tot_m_blocks = ceildiv(tot_m, 16);
    if (sms == -1)
      cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);

    if (thread_k == -1 || thread_n == -1) {
      if (prob_m <= 16) {
        // For small batch sizes, better partitioning is slightly more important than better compute utilization.
        thread_k = 128;
        thread_n = 128;
      } else {
        thread_k = 64;
        thread_n = 256;
      }
    }

    int thread_k_blocks = thread_k / 16;
    int thread_n_blocks = thread_n / 16;
    int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
    int blocks = sms;

    if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0))
      return ERR_PROB_SHAPE;
    if (prob_m == 0 || prob_n == 0 || prob_k == 0)
      return 0;

    const int4* A_ptr = (const int4*) A;
    const int4* B_ptr = (const int4*) B;
    int4* C_ptr = (int4*) C;
    const int4* s_ptr = (const int4*) s;
    int* locks = (int*) workspace;

    for (int i = 0; i < tot_m_blocks; i += 4) {
      // Each launch covers at most 4 blocks of 16 rows; the last chunk may be smaller.
      int thread_m_blocks = tot_m_blocks - i;
      prob_m = tot_m - 16 * i;
      if (thread_m_blocks > 4) {
        thread_m_blocks = 4;
        prob_m = 64;
      }

      // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of
      // performance) in our testing; however, many more are possible in principle.
      if (false) {}
      CALL_IF(1, 8, 8, -1)
      CALL_IF(1, 8, 8, 8)
      CALL_IF(1, 16, 4, -1)
      CALL_IF(1, 16, 4, 8)
      CALL_IF(2, 16, 4, -1)
      CALL_IF(2, 16, 4, 8)
      CALL_IF(3, 16, 4, -1)
      CALL_IF(3, 16, 4, 8)
      CALL_IF(4, 16, 4, -1)
      CALL_IF(4, 16, 4, 8)
      else
        return ERR_KERN_SHAPE; // unsupported configuration: stop instead of launching the remaining chunks

      // Advance past the rows just processed (A has k / 8 int4 per row, C has n / 8 int4 per row).
      A_ptr += 16 * thread_m_blocks * (prob_k / 8);
      C_ptr += 16 * thread_m_blocks * (prob_n / 8);
    }
    return 0;
  }
  742. #endif
  743. } // namespace marlin
  744. } // namespace aphrodite
  // Global-scope copies of aphrodite::marlin's status codes, compared against marlin_cuda's return value in
  // marlin_gemm. Keep the values identical to the namespaced definitions.
  constexpr int ERR_PROB_SHAPE = 1;
  constexpr int ERR_KERN_SHAPE = 2;
  // Multiplies `input` by the quantized `weights` (scaled with `scales`) into `output` using the Marlin kernels,
  // launched on the current CUDA stream of `input`'s device.
  //
  // input: `torch.half` input matrix of shape `(m, k)` in standard row-major layout
  // weights: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
  // output: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
  // scales: `torch.half` scales of shape `(k / groupsize, n)`; a single row selects per-column scales
  // workspace: `torch.int` tensor with at least `n / 128` entries that are all zero
  //
  // Raises (AT_ERROR) if any of the following holds:
  //   - `input` is not a CUDA tensor;
  //   - its device has compute capability < 8.0;
  //   - `scales` has no rows, or a row count that does not evenly divide k;
  //   - `workspace` has fewer than n / 128 entries;
  //   - marlin_cuda reports an unsupported problem shape or kernel configuration.
  void marlin_gemm(
    const torch::Tensor& input,
    const torch::Tensor& weights,
    torch::Tensor& output,
    const torch::Tensor& scales,
    torch::Tensor& workspace
  ) {
    // No `#if __CUDA_ARCH__` guard here. This is host code, and `__CUDA_ARCH__` is undefined in nvcc's host pass,
    // so such a guard compiled the whole op into a silent no-op. The SM80+ requirement is checked at runtime.
    if (!input.is_cuda())
      AT_ERROR("Marlin expects `input` to be a CUDA tensor.");
    int dev = input.get_device();
    // The kernels rely on SM80+ features (e.g. `cp.async`), so refuse older GPUs instead of launching there.
    int cc_major = 0;
    cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, dev);
    if (cc_major < 8)
      AT_ERROR("Marlin requires compute capability >= 8.0, but device ", dev, " reports major version ", cc_major, ".");

    // thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
    int thread_k = -1;
    // thread_n: `n` size of a thread_tile in `weights` (can usually be left as auto -1)
    int thread_n = -1;
    // sms: number of SMs to use for the kernel (can usually be left as auto -1)
    int sms = -1;

    int prob_m = input.size(0);
    int prob_n = output.size(1);
    int prob_k = input.size(1);

    // One scale row means per-column scales (groupsize -1); otherwise k is split into equally sized groups.
    int64_t num_groups = scales.size(0);
    if (num_groups < 1)
      AT_ERROR("scales must have at least one row, got ", num_groups, ".");
    int groupsize = (num_groups == 1) ? -1 : prob_k / num_groups;
    if (groupsize != -1 && groupsize * num_groups != prob_k)
      AT_ERROR("k=", prob_k, " not compatible with ", num_groups, " groups.");

    // Enforce the documented minimum workspace size; the kernels use the workspace as an int lock array.
    if (workspace.numel() < prob_n / 128)
      AT_ERROR("workspace must have at least n / 128 = ", prob_n / 128, " entries, got ", workspace.numel(), ".");

    // NOTE(review): kernels launch on the *current* device while the stream belongs to `dev`; consider a device
    // guard if `input` can live on a device other than the current one.
    int err = aphrodite::marlin::marlin_cuda(
      input.data_ptr(),
      weights.data_ptr(),
      output.data_ptr(),
      scales.data_ptr(),
      prob_m, prob_n, prob_k,
      workspace.data_ptr(),
      groupsize,
      dev,
      at::cuda::getCurrentCUDAStream(dev),
      thread_k,
      thread_n,
      sms
    );
    if (err == ERR_PROB_SHAPE) {
      AT_ERROR(
        "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")",
        " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "."
      );
    } else if (err == ERR_KERN_SHAPE) {
      AT_ERROR(
        "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "."
      );
    }
  }