1
0

marlin_cuda_kernel_zero.cu 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098
  1. /*
  2. * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <torch/extension.h>
  17. #include <c10/cuda/CUDAStream.h>
  18. #include <ATen/cuda/CUDAContext.h>
  19. #include <cuda.h>
  20. #include <cuda_fp16.h>
  21. #include <cuda_runtime.h>
  22. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  23. #else
  24. namespace aphrodite {
  25. namespace marlin {
  // Integer division of `a` by `b` rounded up; exact ceiling for non-negative `a` and positive `b`.
  constexpr int ceildiv(int a, int b) {
    const int bumped = a + b - 1; // push any non-zero remainder over the next multiple of `b`
    return bumped / b;
  }
  // Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core
  // operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we
  // extensively use `#pragma unroll` throughout the kernel code to guarantee this.
  template <typename T, int n>
  struct Vec {
    T elems[n];
    // Mutable access to element `i`; no bounds checking is performed.
    __device__ T& operator[](int i) {
      return elems[i];
    }
    // Read-only access to element `i`, so that a `Vec` received by `const&` can be indexed without casting away
    // constness first; no bounds checking is performed.
    __device__ const T& operator[](int i) const {
      return elems[i];
    }
  };
  // Four packed 32-bit integers, i.e. the same size as one 16-byte `int4`.
  using I4 = Vec<int, 4>;
  // Matrix fragments for tensor core instructions; their precise layout is documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>; // quantization scales
  // Predicated asynchronous global->shared copy of one 16-byte chunk. The `cp.async` is only executed when `pred`
  // is true; otherwise nothing is copied. Used for inputs A, where predication handles batchsizes that are not
  // multiples of 16.
  //   smem_ptr: generic pointer to the shared memory destination
  //   glob_ptr: pointer to the global memory source
  //   pred:     whether this thread should actually issue the copy
  __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
    constexpr int BYTES = 16;
    // The PTX predicate is derived from a 32-bit register, so widen the flag to 0/1 first.
    const int issue = pred ? 1 : 0;
    // `cp.async` takes a shared-state-space address, not a generic one.
    const uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" :: "r"(issue), "r"(dst), "l"(glob_ptr), "n"(BYTES)
    );
  }
  // Asynchronous global->shared copy of one 16-byte chunk, issued with an L2 `evict_first` cache policy to hint
  // that the values may be evicted immediately. Used for quantized weights B, which are only accessed precisely
  // once and should thus not pollute the L2 cache which we need for inputs A and outputs C.
  //   smem_ptr: generic pointer to the shared memory destination
  //   glob_ptr: pointer to the global memory source
  __device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
    constexpr int BYTES = 16;
    // `cp.async` takes a shared-state-space address, not a generic one.
    const uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
      "{\n"
      " .reg .b64 p;\n"
      " createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
      " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
      "}\n" :: "r"(dst), "l"(glob_ptr), "n"(BYTES)
    );
  }
  // Async copy fence: issues `cp.async.commit_group`, closing the batch of async copies this thread has issued so
  // far so that `cp_async_wait` can later wait on it.
  __device__ inline void cp_async_fence() {
    asm volatile("cp.async.commit_group;\n");
  }
  // Block the calling thread until at most `n` committed async copy stages are still pending.
  // `n` is a template parameter because `cp.async.wait_group` takes it as an immediate operand.
  template <int n>
  __device__ inline void cp_async_wait() {
    asm volatile(
      "cp.async.wait_group %0;\n"
      :
      : "n"(n)
    );
  }
  // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation.
  // Computes `frag_c += a_frag * frag_b` in place: the four floats of `frag_c` are passed both as the accumulator
  // input and as the destination of the instruction.
  __device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {
    // View the fragments as raw 32-bit registers, which is what the PTX operand constraints expect.
    const uint32_t* ra = reinterpret_cast<const uint32_t*>(&a_frag);
    const uint32_t* rb = reinterpret_cast<const uint32_t*>(&frag_b);
    float* acc = reinterpret_cast<float*>(&frag_c);
    asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
      : "r"(ra[0]), "r"(ra[1]), "r"(ra[2]), "r"(ra[3]), "r"(rb[0]), "r"(rb[1]),
        "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3])
    );
  }
  // Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core
  // layout (`ldmatrix ... x4`, filling all four 32-bit registers of `frag_a`).
  //   frag_a:   destination fragment
  //   smem_ptr: generic pointer to the shared memory location this thread supplies as its row address
  __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
    uint32_t* regs = reinterpret_cast<uint32_t*>(&frag_a);
    // `ldmatrix` takes a shared-state-space address, not a generic one.
    const uint32_t src = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3]) : "r"(src)
    );
  }
  // Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem
  // to automatically recognize it in all cases.
  // `lut` is the 8-bit truth table of the bitwise function applied to (`a`, `b`, `c`); it is a template parameter
  // because the instruction takes it as an immediate operand.
  template <int lut>
  __device__ inline int lop3(int a, int b, int c) {
    int out;
    asm volatile(
      "lop3.b32 %0, %1, %2, %3, %4;\n"
      : "=r"(out) : "r"(a), "r"(b), "r"(c), "n"(lut)
    );
    return out;
  }
  // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values.
  // We mostly follow the strategy in the link below, with some small changes:
  // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
  // Only the 4-bit fields at bit offsets 0, 4, 16 and 20 of `q` are consumed by one call.
  __device__ inline FragB dequant(int q) {
    const int LO = 0x000f000f; // low nibble of each 16-bit half
    const int HI = 0x00f000f0; // second nibble of each 16-bit half
    const int EX = 0x64006400; // fp16 bit pattern of 1024.0 in both halves
    // Truth table of `(a & b) | c`; spelled out so that the operation is guaranteed to be a single LOP3.
    constexpr int AND_OR = (0xf0 & 0xcc) | 0xaa;
    // Dropping a nibble into the mantissa of 1024.0 yields the fp16 value `1024 + v` (low) or `1024 + 16 * v` (high).
    int lo = lop3<AND_OR>(q, LO, EX);
    int hi = lop3<AND_OR>(q, HI, EX);
    // Remove the 1024 bias again. With these constants the results are the plain unsigned nibble values in
    // [0, 15]; no `-8` offset is fused in here.
    const int SUB = 0x64006400; // 1024.0
    const int MUL = 0x2c002c00; // 1/16
    const int ADD = 0xd400d400; // -64.0
    const half2 lo_biased = *reinterpret_cast<half2*>(&lo);
    const half2 hi_biased = *reinterpret_cast<half2*>(&hi);
    FragB frag_b;
    // (1024 + v) - 1024 = v
    frag_b[0] = __hsub2(lo_biased, *reinterpret_cast<const half2*>(&SUB));
    // (1024 + 16 * v) / 16 - 64 = v
    frag_b[1] = __hfma2(
      hi_biased,
      *reinterpret_cast<const half2*>(&MUL),
      *reinterpret_cast<const half2*>(&ADD)
    );
    return frag_b;
  }
  // Apply the quantization parameters to dequantized values: every element of `frag_b` becomes `b * s + z`, where
  // `s` and `z` are the `i`-th fp16 entries of `frag_s` and `frag_z`, broadcast to both halves of a `half2`.
  // Used only for grouped quantization.
  __device__ inline void scale(FragB& frag_b, FragS& frag_s, FragS& frag_z, int i) {
    const __half* s_vals = reinterpret_cast<__half*>(&frag_s);
    const __half* z_vals = reinterpret_cast<__half*>(&frag_z);
    const half2 s = __half2half2(s_vals[i]);
    const half2 z = __half2half2(z_vals[i]);
    #pragma unroll
    for (int part = 0; part < 2; part++)
      frag_b[part] = __hfma2(frag_b[part], s, z);
  }
  // Wait until the barrier value at `lock` equals `count`. Only thread 0 polls; all other threads of the
  // threadblock wait for it at the trailing `__syncthreads()`.
  __device__ inline void barrier_acquire(int* lock, int count) {
    if (threadIdx.x == 0) {
      int observed;
      do {
        // GPU-scope acquire load; meant to pair with the fence + increment performed in `barrier_release`.
        asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(observed) : "l"(lock));
      } while (observed != count);
    }
    __syncthreads();
  }
  // Release barrier and increment visitation count. With `reset` set, the counter is instead stored back to 0 by a
  // plain write. All threads of the threadblock synchronize first; only thread 0 touches `lock`.
  __device__ inline void barrier_release(int* lock, bool reset = false) {
    __syncthreads();
    if (threadIdx.x != 0)
      return;
    if (reset) {
      lock[0] = 0;
      return;
    }
    const int inc = 1;
    // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier.
    asm volatile ("fence.acq_rel.gpu;\n");
    asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(inc));
  }
  172. template <
  173. const int threads, // number of threads in a threadblock
  174. const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock
  175. const int thread_n_blocks, // same for n dimension (output)
  176. const int thread_k_blocks, // same for k dimension (reduction)
  177. const int stages, // number of stages for the async global->shared fetch pipeline
  178. const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale
  179. >
  180. __global__ void Marlin(
  181. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  182. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  183. int4* __restrict__ C, // fp16 output buffer of shape mxn
  184. const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
  185. const int4* __restrict__ z, // fp16 quantization zeros of shape (k/groupsize)xn
  186. int prob_m, // batch dimension m
  187. int prob_n, // output dimension n
  188. int prob_k, // reduction dimension k
  189. int* locks // extra global storage for barrier synchronization
  190. ) {
  191. // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple
  192. // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example:
  193. // 0 1 3
  194. // 0 2 3
  195. // 1 2 4
  196. // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs
  197. // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as
  198. // possible.
  199. int k_tiles = prob_k / 16 / thread_k_blocks;
  200. int n_tiles = prob_n / 16 / thread_n_blocks;
  201. int iters = ceildiv(k_tiles * n_tiles, gridDim.x);
  202. // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case
  203. // where a stripe starts in the middle of group.
  204. if (group_blocks != -1)
  205. iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks));
  206. int slice_row = (iters * blockIdx.x) % k_tiles;
  207. int slice_col = (iters * blockIdx.x) / k_tiles;
  208. int slice_iters; // number of threadblock tiles in the current slice
  209. int slice_count = 0; // total number of active threadblocks in the current slice
  210. int slice_idx; // index of threadblock in current slice; numbered bottom to top
  211. // Compute all information about the current slice which is required for synchronization.
  212. auto init_slice = [&] () {
  213. slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col + slice_row);
  214. if (slice_iters < 0 || slice_col >= n_tiles)
  215. slice_iters = 0;
  216. if (slice_iters == 0)
  217. return;
  218. if (slice_row + slice_iters > k_tiles)
  219. slice_iters = k_tiles - slice_row;
  220. slice_count = 1;
  221. slice_idx = 0;
  222. int col_first = iters * ceildiv(k_tiles * slice_col, iters);
  223. if (col_first <= k_tiles * (slice_col + 1)) {
  224. int col_off = col_first - k_tiles * slice_col;
  225. slice_count = ceildiv(k_tiles - col_off, iters);
  226. if (col_off > 0)
  227. slice_count++;
  228. int delta_first = iters * blockIdx.x - col_first;
  229. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  230. slice_idx = slice_count - 1;
  231. else {
  232. slice_idx = slice_count - 1 - delta_first / iters;
  233. if (col_off > 0)
  234. slice_idx--;
  235. }
  236. }
  237. };
  238. init_slice();
  239. int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
  240. // We typically use `constexpr` to indicate that this value is a compile-time constant
  241. constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
  242. constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory
  243. int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
  244. constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
  245. constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads
  246. constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile
  247. constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
  248. constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile
  249. int b_gl_stride = 16 * prob_n / 32;
  250. constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
  251. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  252. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
  253. constexpr int b_sh_wr_delta = threads;
  254. constexpr int b_sh_rd_delta = threads;
  255. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  256. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
  257. int s_gl_stride = prob_n / 8;
  258. constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  259. constexpr int s_sh_stage = s_sh_stride;
  260. int s_gl_rd_delta = s_gl_stride;
  261. // Global A read index of current thread.
  262. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
  263. a_gl_rd += a_gl_rd_delta_o * slice_row;
  264. // Shared write index of current thread.
  265. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
  266. // Shared read index.
  267. int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  268. a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  269. int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  270. b_gl_rd += b_sh_stride * slice_col;
  271. b_gl_rd += b_gl_rd_delta_o * slice_row;
  272. int b_sh_wr = threadIdx.x;
  273. int b_sh_rd = threadIdx.x;
  274. int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
  275. int s_sh_wr = threadIdx.x;
  276. int s_sh_rd;
  277. // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major
  278. // layout in the former and in row-major in the latter case.
  279. if (group_blocks != -1)
  280. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
  281. else
  282. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
  283. // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than
  284. // required for a certain tilesize or when the batchsize is not a multiple of 16.
  285. bool a_sh_wr_pred[a_sh_wr_iters];
  286. #pragma unroll
  287. for (int i = 0; i < a_sh_wr_iters; i++)
  288. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  289. bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
  290. // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank
  291. // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of
  292. // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based
  293. // on NSight-Compute) that each warp must also write a consecutive memory segment?
  294. auto transform_a = [&] (int i) {
  295. int row = i / a_gl_rd_delta_o;
  296. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  297. };
  298. // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory
  299. // accesses are static, we simply precompute both transformed reads and writes.
  300. int a_sh_wr_trans[a_sh_wr_iters];
  301. #pragma unroll
  302. for (int i = 0; i < a_sh_wr_iters; i++)
  303. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  304. int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  305. #pragma unroll
  306. for (int i = 0; i < b_sh_wr_iters; i++) {
  307. #pragma unroll
  308. for (int j = 0; j < thread_m_blocks; j++)
  309. a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  310. }
  311. // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependencies between
  312. // subsequent accesses within a tile by maintaining multiple pointers (we have enough registers), a tiny optimization.
  313. const int4* B_ptr[b_sh_wr_iters];
  314. #pragma unroll
  315. for (int i = 0; i < b_sh_wr_iters; i++)
  316. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  317. extern __shared__ int4 sh[];
  318. // Shared memory storage for global fetch pipelines.
  319. int4* sh_a = sh;
  320. int4* sh_b = sh_a + (stages * a_sh_stage);
  321. int4* sh_s = sh_b + (stages * b_sh_stage);
  322. int4* sh_z = sh_s + (stages * s_sh_stage);
  323. // Register storage for double buffer of shared memory reads.
  324. FragA frag_a[2][thread_m_blocks];
  325. I4 frag_b_quant[2];
  326. FragC frag_c[thread_m_blocks][4][2];
  327. FragS frag_s[2][4];
  328. FragS frag_z[2][4];
  329. // Zero accumulators.
  330. auto zero_accums = [&] () {
  331. #pragma unroll
  332. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  333. reinterpret_cast<float*>(frag_c)[i] = 0;
  334. };
  335. // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location.
  336. auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) {
  337. if (pred) {
  338. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  339. #pragma unroll
  340. for (int i = 0; i < a_sh_wr_iters; i++) {
  341. cp_async4_pred(
  342. &sh_a_stage[a_sh_wr_trans[i]],
  343. &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
  344. a_sh_wr_pred[i]
  345. );
  346. }
  347. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  348. #pragma unroll
  349. for (int i = 0; i < b_sh_wr_iters; i++) {
  350. cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
  351. B_ptr[i] += b_gl_rd_delta_o;
  352. }
  353. // Only fetch scales if this tile starts a new group
  354. if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
  355. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  356. int4* sh_z_stage = sh_z + s_sh_stage * pipe;
  357. if (s_sh_wr_pred) {
  358. cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
  359. cp_async4_stream(&sh_z_stage[s_sh_wr], &z[s_gl_rd]);
  360. }
  361. s_gl_rd += s_gl_rd_delta;
  362. }
  363. }
  364. // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point.
  365. cp_async_fence();
  366. };
  367. // Wait until the next thread tile has been loaded to shared memory.
  368. auto wait_for_stage = [&] () {
  369. // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when
  370. // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten).
  371. cp_async_wait<stages - 2>();
  372. __syncthreads();
  373. };
  374. // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer.
  375. auto fetch_to_registers = [&] (int k, int pipe) {
  376. // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a
  377. // significant bottleneck, while some theoretically better attempts have led to bad instruction ordering by the
  378. // compiler and correspondingly a noticeable drop in performance.
  379. if (group_blocks != -1) {
  380. int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
  381. int4* sh_z_stage = sh_z + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
  382. reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
  383. reinterpret_cast<int4*>(&frag_z[k % 2])[0] = sh_z_stage[s_sh_rd];
  384. }
  385. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  386. #pragma unroll
  387. for (int i = 0; i < thread_m_blocks; i++)
  388. ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
  389. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  390. frag_b_quant[k % 2] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
  391. };
  392. // Execute the actual tensor core matmul of a sub-tile.
  393. auto matmul = [&] (int k) {
  394. // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations.
  395. #pragma unroll
  396. for (int j = 0; j < 4; j++) {
  397. int b_quant = frag_b_quant[k % 2][j];
  398. int b_quant_shift = b_quant >> 8;
  399. FragB frag_b0 = dequant(b_quant);
  400. // If there are no groups, we can just scale the final output once and can avoid doing so for each weight.
  401. if (group_blocks != -1)
  402. scale(frag_b0, frag_s[k % 2][j], frag_z[k % 2][j], 0);
  403. FragB frag_b1 = dequant(b_quant_shift);
  404. if (group_blocks != -1)
  405. scale(frag_b1, frag_s[k % 2][j], frag_z[k % 2][j], 1);
  406. #pragma unroll
  407. for (int i = 0; i < thread_m_blocks; i++) {
  408. mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
  409. mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
  410. }
  411. }
  412. };
  413. // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n
  414. // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output
  415. // location, which we have to reduce over in the end. We do so in shared memory.
  416. auto thread_block_reduce = [&] () {
  417. constexpr int red_off = threads / b_sh_stride / 2;
  418. if (red_off >= 1) {
  419. int red_idx = threadIdx.x / b_sh_stride;
  420. constexpr int red_sh_stride = b_sh_stride * 4 * 2;
  421. constexpr int red_sh_delta = b_sh_stride;
  422. int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  423. // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations,
  424. // e.g., for two warps we write only once by warp 1 and read only once by warp 0.
  425. #pragma unroll
  426. for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  427. #pragma unroll
  428. for (int i = red_off; i > 0; i /= 2) {
  429. if (i <= red_idx && red_idx < 2 * i) {
  430. #pragma unroll
  431. for (int j = 0; j < 4 * 2; j++) {
  432. int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
  433. if (i < red_off) {
  434. float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
  435. float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  436. #pragma unroll
  437. for (int k = 0; k < 4; k++)
  438. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
  439. }
  440. sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
  441. }
  442. }
  443. __syncthreads();
  444. }
  445. if (red_idx == 0) {
  446. #pragma unroll
  447. for (int i = 0; i < 4 * 2; i++) {
  448. float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  449. #pragma unroll
  450. for (int j = 0; j < 4; j++)
  451. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
  452. }
  453. }
  454. __syncthreads();
  455. }
  456. }
  457. };
  458. // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over
  459. // the results. As the striped partitioning minimizes the number of such reductions and our outputs are usually rather
  460. // small, we perform this reduction serially in L2 cache.
  461. auto global_reduce = [&] (bool first = false, bool last = false) {
  462. // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.
  463. // To do this, we write out results in FP16 (but still reduce with FP32 compute).
  464. constexpr int active_threads = 32 * thread_n_blocks / 4;
  465. if (threadIdx.x < active_threads) {
  466. int c_gl_stride = prob_n / 8;
  467. int c_gl_wr_delta_o = 8 * c_gl_stride;
  468. int c_gl_wr_delta_i = 4 * (active_threads / 32);
  469. int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;
  470. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  471. constexpr int c_sh_wr_delta = active_threads;
  472. int c_sh_wr = threadIdx.x;
  473. int row = (threadIdx.x % 32) / 4;
  474. if (!first) {
  475. // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,
  476. // hence we also use async-copies even though these fetches are not actually asynchronous.
  477. #pragma unroll
  478. for (int i = 0; i < thread_m_blocks * 4; i++) {
  479. cp_async4_pred(
  480. &sh[c_sh_wr + c_sh_wr_delta * i],
  481. &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
  482. i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
  483. );
  484. }
  485. cp_async_fence();
  486. cp_async_wait<0>();
  487. }
  488. #pragma unroll
  489. for (int i = 0; i < thread_m_blocks * 4; i++) {
  490. if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
  491. if (!first) {
  492. int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  493. #pragma unroll
  494. for (int j = 0; j < 2 * 4; j++) {
  495. reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
  496. reinterpret_cast<__half*>(&c_red)[j]
  497. );
  498. }
  499. }
  500. if (!last) {
  501. int4 c;
  502. #pragma unroll
  503. for (int j = 0; j < 2 * 4; j++) {
  504. reinterpret_cast<__half*>(&c)[j] = __float2half(
  505. reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]
  506. );
  507. }
  508. C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;
  509. }
  510. }
  511. }
  512. }
  513. };
  514. // Write out the reduced final result in the correct layout. We only actually reshuffle matrix fragments in this step,
  515. // the reduction above is performed in fragment layout.
  516. auto write_result = [&] () {
  517. int c_gl_stride = prob_n / 8;
  518. constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
  519. int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
  520. constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));
  521. int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
  522. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  523. int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
  524. c_sh_wr += 32 * (threadIdx.x / 32);
  525. int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
  526. int c_gl_wr_end = c_gl_stride * prob_m;
  527. // We first reorder in shared memory to guarantee the most efficient final global write patterns
  528. auto write = [&] (int idx, float c0, float c1, FragS& s, FragS& z) {
  529. half2 res = __halves2half2(__float2half(c0), __float2half(c1));
  530. if (group_blocks == -1) // for per-column quantization we finally apply the scale here
  531. res = __hfma2(res, s[0], z[0]);
  532. ((half2*) sh)[idx] = res;
  533. };
  534. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  535. #pragma unroll
  536. for (int i = 0; i < thread_m_blocks; i++) {
  537. #pragma unroll
  538. for (int j = 0; j < 4; j++) {
  539. int wr = c_sh_wr + 8 * j;
  540. write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0], frag_z[j / 2][2 * (j % 2) + 0]);
  541. write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0], frag_z[j / 2][2 * (j % 2) + 0]);
  542. write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1], frag_z[j / 2][2 * (j % 2) + 1]);
  543. write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1], frag_z[j / 2][2 * (j % 2) + 1]);
  544. }
  545. c_sh_wr += 16 * (4 * c_sh_stride);
  546. }
  547. }
  548. __syncthreads();
  549. #pragma unroll
  550. for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
  551. if (c_gl_wr < c_gl_wr_end) {
  552. C[c_gl_wr] = sh[c_sh_rd];
  553. c_gl_wr += c_gl_wr_delta;
  554. c_sh_rd += c_sh_rd_delta;
  555. }
  556. }
  557. };
  558. // Start global fetch and register load pipelines.
  559. auto start_pipes = [&] () {
  560. #pragma unroll
  561. for (int i = 0; i < stages - 1; i++)
  562. fetch_to_shared(i, i, i < slice_iters);
  563. zero_accums();
  564. wait_for_stage();
  565. fetch_to_registers(0, 0);
  566. a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  567. };
  568. start_pipes();
  569. // Main loop.
  570. while (slice_iters) {
  571. // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are
  572. // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0.
  573. #pragma unroll
  574. for (int pipe = 0; pipe < stages;) {
  575. #pragma unroll
  576. for (int k = 0; k < b_sh_wr_iters; k++) {
  577. fetch_to_registers(k + 1, pipe % stages);
  578. if (k == b_sh_wr_iters - 2) {
  579. fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
  580. pipe++;
  581. wait_for_stage();
  582. }
  583. matmul(k);
  584. }
  585. slice_iters--;
  586. if (slice_iters == 0)
  587. break;
  588. }
  589. a_gl_rd += a_gl_rd_delta_o * stages;
  590. // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most
  591. // readable, other ways of writing the loop seemed to result in noticeably worse performance after compilation.
  592. if (slice_iters == 0) {
  593. cp_async_wait<0>();
  594. bool last = slice_idx == slice_count - 1;
  595. // For per-column scales, we only fetch them here in the final step before write-out
  596. if (group_blocks == -1 && last) {
  597. if (s_sh_wr_pred) {
  598. cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
  599. cp_async4_stream(&sh_z[s_sh_wr], &z[s_gl_rd]);
  600. }
  601. cp_async_fence();
  602. }
  603. thread_block_reduce();
  604. if (group_blocks == -1 && last) {
  605. cp_async_wait<0>();
  606. __syncthreads();
  607. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  608. reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
  609. reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
  610. reinterpret_cast<int4*>(&frag_z)[0] = sh_z[s_sh_rd + 0];
  611. reinterpret_cast<int4*>(&frag_z)[1] = sh_z[s_sh_rd + 4];
  612. }
  613. }
  614. if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
  615. barrier_acquire(&locks[slice_col], slice_idx);
  616. global_reduce(slice_idx == 0, last);
  617. barrier_release(&locks[slice_col], last);
  618. }
  619. if (last) // only the last block in a slice actually writes the result
  620. write_result();
  621. slice_row = 0;
  622. slice_col++;
  623. init_slice();
  624. if (slice_iters) {
  625. a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
  626. #pragma unroll
  627. for (int i = 0; i < b_sh_wr_iters; i++)
  628. B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
  629. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  630. start_pipes();
  631. }
  632. }
  633. }
  634. }
  // Threads per block. Eight warps are a good choice: every SM has 4 warp schedulers, so running more than one warp
  // per scheduler allows some extra latency hiding, while keeping the warp count low leaves many registers per warp
  // and permits small tiles.
  constexpr int THREADS = 256;
  // Number of pipeline stages; 4 stages fit into shared memory.
  constexpr int STAGES = 4;
  // Dynamic shared memory requested per block, in bytes. This is sized for compute capability 8.6, whose
  // per-block maximum is smaller than the one available on compute capability 8.0.
  constexpr int SHARED_MEM = 96 * 1024;
  // Expands to one `else if` arm of the kernel dispatch chain: when the runtime tile configuration equals the given
  // compile-time configuration, the matching Marlin instantiation has its dynamic shared memory limit raised to
  // SHARED_MEM and is then launched with `blocks` blocks of THREADS threads on `stream`.
  // The expansion site must provide: thread_m_blocks, thread_n_blocks, thread_k_blocks, group_blocks, blocks, stream,
  // A_ptr, B_ptr, C_ptr, s_ptr, z_ptr, prob_m, prob_n, prob_k and locks.
  #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
    else if ( \
      thread_m_blocks == THREAD_M_BLOCKS && \
      thread_n_blocks == THREAD_N_BLOCKS && \
      thread_k_blocks == THREAD_K_BLOCKS && \
      group_blocks == GROUP_BLOCKS \
    ) { \
      auto marlin_kernel = \
        Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>; \
      cudaFuncSetAttribute(marlin_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHARED_MEM); \
      marlin_kernel<<<blocks, THREADS, SHARED_MEM, stream>>>( \
        A_ptr, B_ptr, C_ptr, s_ptr, z_ptr, \
        prob_m, prob_n, prob_k, \
        locks \
      ); \
    }
  // Non-zero status codes returned by marlin_cuda().
  // The problem dimensions are not divisible by the selected thread tile (or group) size.
  constexpr int ERR_PROB_SHAPE = 1;
  // No kernel has been instantiated for the requested tile / group configuration.
  constexpr int ERR_KERN_SHAPE = 2;
  // Host-side launcher for the Marlin kernel.
  //
  // The `prob_m` rows of the problem are processed in chunks of at most 64 rows (4 blocks of 16); one kernel launch is
  // issued per chunk and the A / C pointers are advanced between launches.
  //
  // A, B, C, s, z: buffers handed to the kernel, reinterpreted as `int4` arrays (A, B, s, z read-only, C written).
  // prob_m, prob_n, prob_k: problem dimensions.
  // workspace: buffer reinterpreted as the `int` lock array passed to the kernel.
  // groupsize: quantization group size along k, or -1 for a single group per column.
  // dev: device queried for its SM count when `sms` is -1.
  // stream: CUDA stream the kernels are launched on.
  // thread_k, thread_n: thread tile sizes; if either is -1 both are chosen automatically from `prob_m`.
  // sms: number of thread blocks to launch; -1 uses the multiprocessor count of `dev`.
  //
  // Returns 0 on success (also when any problem dimension is 0, in which case nothing is launched),
  // ERR_PROB_SHAPE if `prob_n` / `prob_k` are not divisible by the tile sizes (or `prob_k` by the group block count),
  // and ERR_KERN_SHAPE if no kernel exists for the requested configuration. Nothing further is launched once an
  // unsupported configuration is detected.
  int marlin_cuda(
    const void* A,
    const void* B,
    void* C,
    void* s,
    void* z,
    int prob_m,
    int prob_n,
    int prob_k,
    void* workspace,
    int groupsize = -1,
    int dev = 0,
    cudaStream_t stream = 0,
    int thread_k = -1,
    int thread_n = -1,
    int sms = -1
  ) {
    const int tot_m = prob_m;
    const int tot_m_blocks = ceildiv(tot_m, 16);
    // NOTE(review): the status of this query is not checked; on failure `sms` stays -1 and the launch below fails.
    if (sms == -1)
      cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
    if (thread_k == -1 || thread_n == -1) {
      if (prob_m <= 16) {
        // For small batch sizes, better partitioning is slightly more important than better compute utilization.
        thread_k = 128;
        thread_n = 128;
      } else {
        thread_k = 64;
        thread_n = 256;
      }
    }
    // Non-positive tile sizes can never match a kernel and a zero would divide by zero in the checks below.
    if (thread_k <= 0 || thread_n <= 0)
      return ERR_KERN_SHAPE;

    int thread_k_blocks = thread_k / 16;
    int thread_n_blocks = thread_n / 16;
    int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
    int blocks = sms;

    if (prob_n % thread_n != 0 || prob_k % thread_k != 0)
      return ERR_PROB_SHAPE;
    // A group size in [0, 16) yields zero group blocks: there is no such kernel, and `prob_k % 0` is undefined.
    if (group_blocks == 0)
      return ERR_KERN_SHAPE;
    if (group_blocks != -1 && prob_k % group_blocks != 0)
      return ERR_PROB_SHAPE;
    if (prob_m == 0 || prob_n == 0 || prob_k == 0)
      return 0;

    // The names below are referenced by the CALL_IF macro and must keep their spelling.
    const int4* A_ptr = (const int4*) A;
    const int4* B_ptr = (const int4*) B;
    int4* C_ptr = (int4*) C;
    const int4* s_ptr = (const int4*) s;
    const int4* z_ptr = (const int4*) z;
    int* locks = (int*) workspace;

    for (int i = 0; i < tot_m_blocks; i += 4) {
      int thread_m_blocks = tot_m_blocks - i;
      prob_m = tot_m - 16 * i;
      if (thread_m_blocks > 4) {
        // Full chunk: 4 blocks of 16 rows.
        thread_m_blocks = 4;
        prob_m = 64;
      }

      // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of
      // performance) in our testing, however many more are, in principle, possible.
      if (false) {}
      CALL_IF(1,  8,  8, -1)
      CALL_IF(1,  8,  8,  8)
      CALL_IF(1, 16,  4, -1)
      CALL_IF(1, 16,  4,  8)
      CALL_IF(2, 16,  4, -1)
      CALL_IF(2, 16,  4,  8)
      CALL_IF(3, 16,  4, -1)
      CALL_IF(3, 16,  4,  8)
      CALL_IF(4, 16,  4, -1)
      CALL_IF(4, 16,  4,  8)
      else
        return ERR_KERN_SHAPE;

      // Advance to the next chunk of rows (pointers are in units of int4).
      A_ptr += 16 * thread_m_blocks * (prob_k / 8);
      C_ptr += 16 * thread_m_blocks * (prob_n / 8);
    }

    return 0;
  }
  // Expands Marlin-packed 4-bit weights into a dense half-precision matrix.
  //
  // Each thread loads one int4 (four packed 32-bit words) from `qweight`. For every word it calls dequant() on the
  // word and on the word shifted right by 8 bits, applies scale() with the matching scale / zero fragments and stores
  // the eight resulting halves. With the 32-thread launch used in this file, a block writes rows
  // [16 * blockIdx.x, 16 * blockIdx.x + 16) and columns [64 * blockIdx.y, 64 * blockIdx.y + 64) of `out`.
  //
  // qweight: packed weights; the row of 32-bit words belonging to blockIdx.x starts at `blockIdx.x * n * 2`.
  // scales, zeros: per-group values, addressed as `group * n + column`.
  // out: destination matrix, addressed with a row stride of `n`.
  // m: unused.
  // n: number of output columns.
  // groupsize: rows per quantization group, or -1 for a single group.
  __global__ void dequant_marlin(
    const uint32_t* qweight,
    const half* scales,
    const half* zeros,
    half* out,
    int m,
    int n,
    int groupsize
  ) {
    const int lane = threadIdx.x;
    // `blockIdx.x * 16` is unsigned, so a `groupsize` of -1 is converted to UINT_MAX by this division and the
    // quotient is 0 for every in-range row, i.e. the single per-column group.
    const int group_idx = blockIdx.x * 16 / groupsize;

    // Four consecutive packed words owned by this thread.
    const uint32_t* packed_src = qweight + blockIdx.x * n * 2 + blockIdx.y * 128 + lane * 4;
    int4 packed = *reinterpret_cast<const int4*>(packed_src);
    uint32_t* words = reinterpret_cast<uint32_t*>(&packed);

    const half* scale_src = scales + group_idx * n + blockIdx.y * 64 + lane / 4 * 8;
    const half* zero_src = zeros + group_idx * n + blockIdx.y * 64 + lane / 4 * 8;
    const FragS* scale_frags = reinterpret_cast<const FragS*>(scale_src);
    const FragS* zero_frags = reinterpret_cast<const FragS*>(zero_src);

    half* dst = out + (blockIdx.x * 16 + (lane % 4) * 2) * n + blockIdx.y * 64 + lane / 4;
    for (int w = 0; w < 4; ++w, dst += 16) {
      FragS frag_s = scale_frags[w];
      FragS frag_z = zero_frags[w];
      FragB frag_lo = dequant(words[w]);
      FragB frag_hi = dequant(words[w] >> 8);
      scale(frag_lo, frag_s, frag_z, 0);
      scale(frag_hi, frag_s, frag_z, 1);

      // First fragment: rows +0, +1, +8, +9 of this thread's column.
      dst[0] = frag_lo[0].x;
      dst[n] = frag_lo[0].y;
      dst[8 * n] = frag_lo[1].x;
      dst[9 * n] = frag_lo[1].y;
      // Second fragment: same rows, eight columns further right.
      dst[8] = frag_hi[0].x;
      dst[n + 8] = frag_hi[0].y;
      dst[8 * n + 8] = frag_hi[1].x;
      dst[9 * n + 8] = frag_hi[1].y;
    }
  }
  // Repacks a 16-row by 8-word tile of AWQ-packed 4-bit weights into the Marlin layout.
  //
  // Expected launch (see the host wrapper in this file): grid (m / 16, n / 8), 32 threads per block.
  // Phase 1 unpacks the tile into shared memory as four 16x16 blocks of nibbles, undoing the nibble interleave given
  // by `order_map`. Phase 2 lets every thread gather eight nibbles per 16x16 block and emit them as one 32-bit word.
  //
  // in: source words, addressed as `row * n + column`.
  // out: destination words; each block writes 128 words starting at `blockIdx.x * n * 16 + blockIdx.y * 128`.
  // m: unused.
  // n: number of 32-bit words per input row.
  __global__ void awq_to_marlin(
    uint32_t* in,
    uint32_t* out,
    int m,
    int n
  ) {
    const uint32_t row = blockIdx.x * 16;
    const uint32_t col = blockIdx.y * 8;
    const uint32_t t = threadIdx.x;

    // Marlin packs four 16x16 blocks at a time. Rows are padded from 16 to 18 bytes (presumably to reduce
    // shared-memory bank conflicts).
    const int pad_len = 18;
    __shared__ uint8_t block[4][16][pad_len];

    // Unpack: thread t handles input word column t % 8 for rows t / 8, t / 8 + 4, ...
    const int row_offset = t / 8;
    const int col_offset = t % 8;
    const int order_map[8] = {0, 2, 4, 6, 1, 3, 5, 7};
    for (int offset = row_offset; offset < 16; offset += 4) {
      uint32_t v = in[(row + offset) * n + col + col_offset];
      #pragma unroll
      for (int i = 0; i < 8; i += 1) {
        block[col_offset / 2][offset][8 * (col_offset % 2) + order_map[i]] = v & 0xf;
        v >>= 4;
      }
    }

    // The repack phase reads nibbles written by other threads, so the shared-memory writes above must be made
    // visible first. Lock-step warp execution cannot be relied on for this.
    __syncthreads();

    // Repack.
    // ref: _get_perms @ https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py
    const uint32_t srow = (t % 4) * 2;
    const uint32_t scol = t / 4;
    // Position of the j-th nibble of an output word relative to (srow, scol).
    const uint32_t row_delta[8] = {0, 8, 0, 8, 1, 9, 1, 9};
    const uint32_t col_delta[8] = {0, 0, 8, 8, 0, 0, 8, 8};
    #pragma unroll
    for (int i = 0; i < 4; i += 1) {
      uint32_t pack = 0;
      #pragma unroll
      for (int j = 0; j < 8; ++j) {
        const uint32_t nibble = block[i][srow + row_delta[j]][scol + col_delta[j]];
        pack |= nibble << (4 * j);
      }
      out[blockIdx.x * n * 16 + blockIdx.y * 128 + t * 4 + i] = pack;
    }
  }
  // Repacks a 16-row by 64-column tile of GPTQ-packed 4-bit weights into the Marlin layout.
  //
  // Expected launch (see the host wrapper in this file): grid (m / 2, n / 64), 32 threads per block.
  // Each input word holds the 4-bit values of eight consecutive rows of one column: logical row r lives in input row
  // r / 8 at bit offset 4 * (r % 8). Phase 1 unpacks the tile into shared memory as four 16x16 blocks of nibbles;
  // phase 2 lets every thread gather eight nibbles per 16x16 block and emit them as one 32-bit word.
  //
  // in: source words, addressed as `input_row * n + column`.
  // out: destination words; each block writes 128 words starting at `blockIdx.x * n * 2 + blockIdx.y * 128`.
  // g_idx: optional row permutation; when non-null, logical row r of the output is read from row g_idx[r].
  // m: unused.
  // n: number of columns.
  __global__ void gptq_to_marlin(
    uint32_t* in,
    uint32_t* out,
    int* g_idx,
    int m,
    int n
  ) {
    const uint32_t col = blockIdx.y * 64;
    const uint32_t t = threadIdx.x;

    // Marlin packs four 16x16 blocks at a time. Rows are padded from 16 to 18 bytes (presumably to reduce
    // shared-memory bank conflicts).
    const int pad_len = 18;
    __shared__ uint8_t block[4][16][pad_len];

    // Unpack: for each of the 16 logical rows, thread t extracts the nibbles of columns t and t + 32.
    #pragma unroll
    for (int i = 0; i < 16; i += 1) {
      uint32_t source_row = g_idx ? g_idx[blockIdx.x * 16 + i] : (blockIdx.x * 16 + i);
      int in_row = source_row >> 3;
      int in_subrow = source_row & 0x07;
      int in_row_shift = in_subrow << 2;
      for (int offset = t; offset < 64; offset += 32) {
        uint32_t v = in[in_row * n + col + offset];
        block[offset / 16][i][offset % 16] = (v >> in_row_shift) & 0xf;
      }
    }

    // The repack phase reads nibbles written by other threads, so the shared-memory writes above must be made
    // visible first. Lock-step warp execution cannot be relied on for this.
    __syncthreads();

    // Repack.
    // ref: _get_perms @ https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py
    const uint32_t srow = (t % 4) * 2;
    const uint32_t scol = t / 4;
    // Position of the j-th nibble of an output word relative to (srow, scol).
    const uint32_t row_delta[8] = {0, 8, 0, 8, 1, 9, 1, 9};
    const uint32_t col_delta[8] = {0, 0, 8, 8, 0, 0, 8, 8};
    #pragma unroll
    for (int i = 0; i < 4; i += 1) {
      uint32_t pack = 0;
      #pragma unroll
      for (int j = 0; j < 8; ++j) {
        const uint32_t nibble = block[i][srow + row_delta[j]][scol + col_delta[j]];
        pack |= nibble << (4 * j);
      }
      out[blockIdx.x * n * 2 + blockIdx.y * 128 + t * 4 + i] = pack;
    }
  }
  865. } // namespace marlin
  866. } // namespace aphrodite
  // Status codes that marlin_gemm_zero() compares the result of aphrodite::marlin::marlin_cuda() against.
  // Problem shape incompatible with the thread tile sizes.
  constexpr int ERR_PROB_SHAPE = 1;
  // No kernel implementation for the requested configuration.
  constexpr int ERR_KERN_SHAPE = 2;
  // Multiplies `input` by Marlin-packed 4-bit weights (with scales and zero points) and writes the result to `output`.
  // For `m >= 256` the weights are first dequantized into a temporary `(k, n)` matrix and multiplied with cuBLAS;
  // otherwise the fused Marlin kernel is used.
  //
  // input: `torch.half` input matrix of shape `(m, k)` in standard row-major layout
  // weights: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
  // output: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
  // scales: `torch.half` scales of shape `(k / groupsize, n)`; a single row selects per-column quantization
  // zeros: `torch.half` zero points, addressed exactly like `scales`
  // workspace: `torch.int` tensor with at least `n / 128` entries that are all zero
  //
  // Raises a c10::Error if the shapes are inconsistent, if no kernel exists for the configuration, or if a launch or
  // the cuBLAS call fails.
  void marlin_gemm_zero(
    const torch::Tensor& input,
    const torch::Tensor& weights,
    torch::Tensor& output,
    const torch::Tensor& scales,
    const torch::Tensor& zeros,
    torch::Tensor& workspace
  ) {
    // thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
    int thread_k = -1;
    // thread_n: `n` size of a thread_tile in `weights` (can usually be left as auto -1)
    int thread_n = -1;
    // sms: number of SMs to use for the kernel (can usually be left as auto -1)
    int sms = -1;

    int prob_m = input.size(0);
    int prob_n = output.size(1);
    int prob_k = input.size(1);

    const int64_t num_groups = scales.size(0);
    // Without this check the group size computation below would divide by zero.
    TORCH_CHECK(num_groups > 0, "scales must contain at least one group.");
    int groupsize = (num_groups == 1) ? -1 : static_cast<int>(prob_k / num_groups);
    // A group size of 0 (more groups than rows) is never valid and would divide by zero further down.
    if (groupsize != -1 && (groupsize == 0 || groupsize * num_groups != prob_k))
      TORCH_CHECK(false, "k=", prob_k, " not compatible with ", scales.size(0), " groups.");

    int dev = input.get_device();

    if (prob_m >= 256) {
      // The dequantization kernel handles one 16-row by 64-column tile per block and assumes `2 * n` packed words per
      // weight row; any other shape would leave parts of the temporary unwritten or write out of bounds.
      TORCH_CHECK(
        prob_n % 64 == 0 && weights.size(0) * 16 == prob_k && weights.size(1) == 2 * static_cast<int64_t>(prob_n),
        "Packed weights of shape (", weights.size(0), ", ", weights.size(1), ")",
        " not compatible with k=", prob_k, ", n=", prob_n, "."
      );

      auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
      at::Tensor temp_dq = torch::empty({prob_k, prob_n}, options);

      dim3 block, grid;
      grid.x = weights.size(0);
      grid.y = weights.size(1) / 128;
      block.x = 32;
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      aphrodite::marlin::dequant_marlin<<<grid, block, 0, stream>>>(
        (const uint32_t*) weights.data_ptr(),
        (const half*) scales.data_ptr(),
        (const half*) zeros.data_ptr(),
        (half*) temp_dq.data_ptr(),
        prob_k,
        prob_n,
        groupsize
      );
      const cudaError_t launch_status = cudaGetLastError();
      TORCH_CHECK(
        launch_status == cudaSuccess,
        "dequant_marlin launch failed: ", cudaGetErrorString(launch_status)
      );

      // cuBLAS is column-major: computing (n x m) = temp_dq^T-view (n x k) * input^T-view (k x m) yields the
      // row-major (m x n) output.
      const half alpha = __float2half(1.0f);
      const half beta = __float2half(0.0f);
      const cublasStatus_t gemm_status = cublasHgemm(
        at::cuda::getCurrentCUDABlasHandle(),
        CUBLAS_OP_N,
        CUBLAS_OP_N,
        prob_n, prob_m, prob_k,
        &alpha, (half*) temp_dq.data_ptr(), prob_n,
        (const half*) input.data_ptr(), prob_k,
        &beta, (half*) output.data_ptr(), prob_n
      );
      TORCH_CHECK(
        gemm_status == CUBLAS_STATUS_SUCCESS,
        "cublasHgemm failed with status ", static_cast<int>(gemm_status), "."
      );
    } else {
      int err = aphrodite::marlin::marlin_cuda(
        input.data_ptr(),
        weights.data_ptr(),
        output.data_ptr(),
        scales.data_ptr(),
        zeros.data_ptr(),
        prob_m, prob_n, prob_k,
        workspace.data_ptr(),
        groupsize,
        dev,
        at::cuda::getCurrentCUDAStream(dev),
        thread_k,
        thread_n,
        sms
      );
      if (err == ERR_PROB_SHAPE) {
        TORCH_CHECK(
          false,
          "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")",
          " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "."
        );
      } else if (err == ERR_KERN_SHAPE) {
        TORCH_CHECK(
          false,
          "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "."
        );
      }
      const cudaError_t launch_status = cudaGetLastError();
      TORCH_CHECK(
        launch_status == cudaSuccess,
        "Marlin kernel launch failed: ", cudaGetErrorString(launch_status)
      );
    }
  }
  // Converts GPTQ-packed 4-bit weights to the Marlin layout.
  //
  // W: contiguous CUDA `torch.int` tensor of shape `(m, n)` with `m` even and `n` a multiple of 64.
  // g_idx: optional row permutation as a CUDA `torch.int` tensor with at least `8 * m` entries; pass a tensor on the
  //        meta device to disable it.
  // Returns a new `torch.int` tensor of shape `(m / 2, n * 2)` on the device of `W`.
  // Raises a c10::Error if a precondition is violated or the kernel launch fails.
  torch::Tensor gptq_to_marlin(
    torch::Tensor W,
    torch::Tensor g_idx
  ){
    // These are runtime checks rather than assert() so that they stay active in release builds.
    TORCH_CHECK(W.dim() == 2, "gptq_to_marlin: W must be a 2-D tensor.");
    TORCH_CHECK(W.is_cuda(), "gptq_to_marlin: W must be a CUDA tensor.");
    TORCH_CHECK(W.is_contiguous(), "gptq_to_marlin: W must be contiguous.");
    TORCH_CHECK(W.dtype() == at::kInt, "gptq_to_marlin: W must have dtype int32.");
    const int m = W.sizes()[0];
    const int n = W.sizes()[1];
    TORCH_CHECK(m % 2 == 0, "gptq_to_marlin: W.size(0)=", m, " must be a multiple of 2.");
    TORCH_CHECK(n % 64 == 0, "gptq_to_marlin: W.size(1)=", n, " must be a multiple of 64.");

    const bool use_g_idx = !g_idx.device().is_meta();
    if (use_g_idx) {
      TORCH_CHECK(g_idx.is_cuda(), "gptq_to_marlin: g_idx must be a CUDA tensor.");
      TORCH_CHECK(g_idx.dtype() == at::kInt, "gptq_to_marlin: g_idx must have dtype int32.");
      TORCH_CHECK(g_idx.is_contiguous(), "gptq_to_marlin: g_idx must be contiguous.");
      // The kernel reads g_idx[16 * blockIdx.x + i] for blockIdx.x < m / 2 and i < 16.
      TORCH_CHECK(
        g_idx.numel() >= 8 * static_cast<int64_t>(m),
        "gptq_to_marlin: g_idx has ", g_idx.numel(), " entries but ", 8 * static_cast<int64_t>(m), " are required."
      );
    }

    auto result = at::empty(
      {m / 2, n * 2}, at::TensorOptions().dtype(at::kInt).device(W.device()));
    // A grid with a zero dimension is an invalid launch configuration; there is nothing to convert anyway.
    if (m == 0 || n == 0)
      return result;

    const dim3 threads(32);
    // marlin packs 16 x 64 block and gptq packs 8 x 1
    const dim3 blocks(m / 2, n / 64);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    aphrodite::marlin::gptq_to_marlin<<<blocks, threads, 0, stream>>>(
      (uint32_t*)W.data_ptr(),
      (uint32_t*)result.data_ptr(),
      use_g_idx ? (int*)g_idx.data_ptr() : NULL,
      m,
      n
    );
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
      launch_status == cudaSuccess,
      "gptq_to_marlin launch failed: ", cudaGetErrorString(launch_status)
    );
    return result;
  }
  // Converts AWQ-packed 4-bit weights to the Marlin layout.
  //
  // W: contiguous CUDA `torch.int` tensor of shape `(m, n)` with `m` a multiple of 16 and `n` a multiple of 8.
  // Returns a new `torch.int` tensor of shape `(m / 16, n * 16)` on the device of `W`.
  // Raises a c10::Error if a precondition is violated or the kernel launch fails.
  torch::Tensor awq_to_marlin(
    torch::Tensor W
  ){
    // These are runtime checks rather than assert() so that they stay active in release builds.
    TORCH_CHECK(W.dim() == 2, "awq_to_marlin: W must be a 2-D tensor.");
    TORCH_CHECK(W.is_cuda(), "awq_to_marlin: W must be a CUDA tensor.");
    // The kernel addresses W as `row * n + column`.
    TORCH_CHECK(W.is_contiguous(), "awq_to_marlin: W must be contiguous.");
    TORCH_CHECK(W.dtype() == at::kInt, "awq_to_marlin: W must have dtype int32.");
    const int m = W.sizes()[0];
    const int n = W.sizes()[1];
    TORCH_CHECK(m % 16 == 0, "awq_to_marlin: W.size(0)=", m, " must be a multiple of 16.");
    // Each block converts 8 packed columns; a remainder would be left unconverted in the output.
    TORCH_CHECK(n % 8 == 0, "awq_to_marlin: W.size(1)=", n, " must be a multiple of 8.");

    auto result = at::empty(
      {m / 16, n * 16}, at::TensorOptions().dtype(at::kInt).device(W.device()));
    // A grid with a zero dimension is an invalid launch configuration; there is nothing to convert anyway.
    if (m == 0 || n == 0)
      return result;

    const dim3 threads(32);
    // marlin packs 16 x 64 block and awq packs 1 x 8
    const dim3 blocks(m / 16, n / 8);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    aphrodite::marlin::awq_to_marlin<<<blocks, threads, 0, stream>>>(
      (uint32_t*)W.data_ptr(),
      (uint32_t*)result.data_ptr(),
      m,
      n
    );
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
      launch_status == cudaSuccess,
      "awq_to_marlin launch failed: ", cudaGetErrorString(launch_status)
    );
    return result;
  }
  996. #endif