marlin_24_cuda_kernel.cu 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136
  1. /*
  2. * Notice: This file was modified by Neuralmagic inc to include 8-bit support
  3. *
  4. * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
  5. * Rights Reserved.
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the "License");
  8. * you may not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * http://www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an "AS IS" BASIS,
  15. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. */
  19. #include <torch/all.h>
  20. #include <ATen/cuda/CUDAContext.h>
  21. #include <c10/cuda/CUDAGuard.h>
  22. #include <cuda.h>
  23. #include <cuda_fp16.h>
  24. #include <cuda_runtime.h>
  25. #include <iostream>
  26. #include "common/base.h"
  27. #include "../../../core/scalar_type.hpp"
  28. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  29. #else
  30. #include "common/mem.h"
  31. #include "common/mma.h"
  32. #endif
  // Formats a value as text.
  //
  // Thin wrapper around std::to_string, so `T` must be a type that one of the
  // std::to_string overloads accepts.
  //
  // @param x  value to format
  // @return   the string produced by std::to_string(x)
  template <typename T>
  inline std::string str(T x) {
    std::string text = std::to_string(x);
    return text;
  }
  37. namespace marlin_24 {
  // Compile-time tuning constants.

  // Threads per block. 256 threads are 8 warps of 32, which is a good choice
  // since every SM has 4 schedulers and having more than 1 warp per scheduler
  // allows some more latency hiding. At the same time, we want relatively few
  // warps to have many registers per warp and small tiles.
  static constexpr int THREADS{256};

  // NOTE(review): presumably the value passed as the `stages` template
  // argument of Marlin_24 (depth of the async global->shared fetch pipeline);
  // the launch site is not visible here -- confirm.
  static constexpr int STAGES{4};

  // NOTE(review): presumably the smallest n-extent a thread tile may have;
  // its use is not visible here -- confirm.
  static constexpr int min_thread_n{128};

  // NOTE(review): presumably the edge length of the 16x16 blocks that the
  // kernel's parameter comments refer to -- confirm.
  static constexpr int tile_size{16};

  // NOTE(review): presumably an upper bound on the number of batch partitions
  // processed in parallel; its use is not visible here -- confirm.
  static constexpr int max_par{64};
  46. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  // Placeholder Marlin_24 kernel, compiled only when the enclosing
  // `__CUDA_ARCH__ < 800` guard is active. It declares the same template and
  // function parameters as the full kernel in the `#else` branch but has an
  // empty body, so it performs no work.
  //
  // Template parameters:
  //   num_bits         weight bits
  //   threads          number of threads in a threadblock
  //   thread_m_blocks  number of 16x16 blocks in the m dimension (batchsize)
  //                    of the threadblock
  //   thread_n_blocks  same for n dimension (output)
  //   thread_k_blocks  same for k dimension (reduction)
  //   stages           number of stages for the async global->shared fetch
  //                    pipeline
  //   group_blocks     number of consecutive 16x16 blocks with a separate
  //                    quantization scale (defaults to -1)
  //
  // Function parameters (all unused here):
  //   A       fp16 input matrix of shape mxk
  //   B       4bit quantized weight matrix of shape kxn
  //   meta    2bit metadata information about 2:4 format on B
  //   C       fp16 output buffer of shape mxn
  //   s       fp16 quantization scales of shape (k/groupsize)xn
  //   prob_m  batch dimension m
  //   prob_n  output dimension n
  //   prob_k  reduction dimension k
  //   locks   extra global storage for barrier synchronization
  template <const int num_bits, const int threads, const int thread_m_blocks,
            const int thread_n_blocks, const int thread_k_blocks,
            const int stages, const int group_blocks = -1>
  __global__ void Marlin_24(const int4* __restrict__ A,
                            const int4* __restrict__ B,
                            const int4* __restrict__ meta,
                            int4* __restrict__ C,
                            const int4* __restrict__ s,
                            int prob_m, int prob_n, int prob_k, int* locks) {
    // Intentionally empty: the real implementation lives in the `#else` branch.
  }
  72. torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  73. torch::Tensor& b_meta,
  74. torch::Tensor& b_scales,
  75. torch::Tensor& workspace,
  76. aphrodite::ScalarTypeTorchPtr const& b_q_type,
  77. int64_t size_m, int64_t size_n,
  78. int64_t size_k) {
  79. TORCH_CHECK_NOT_IMPLEMENTED(
  80. false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0");
  81. return torch::empty({1, 1});
  82. }
  83. #else
  84. template <const int num_bits, // weight bits
  85. const int threads, // number of threads in a threadblock
  86. const int thread_m_blocks, // number of 16x16 blocks in the m
  87. // dimension (batchsize) of the
  88. // threadblock
  89. const int thread_n_blocks, // same for n dimension (output)
  90. const int thread_k_blocks, // same for k dimension (reduction)
  91. const int stages, // number of stages for the async global->shared
  92. // fetch pipeline
  93. const int group_blocks = -1 // number of consecutive 16x16 blocks
  94. // with a separate quantization scale
  95. >
  96. __global__ void Marlin_24(
  97. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  98. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  99. const int4* __restrict__ meta, // 2bit metadata information about 2:4
  100. // format on B
  101. int4* __restrict__ C, // fp16 output buffer of shape mxn
  102. const int4* __restrict__ s, // fp16 quantization scales of shape
  103. // (k/groupsize)xn
  104. int prob_m, // batch dimension m
  105. int prob_n, // output dimension n
  106. int prob_k, // reduction dimension k
  107. int* locks // extra global storage for barrier synchronization
  108. ) {
  109. // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  110. // same size, which might involve multiple column "slices" (of width 16 *
  111. // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
  112. // example:
  113. // 0 1 3
  114. // 0 2 3
  115. // 1 2 4
  116. // While this kind of partitioning makes things somewhat more complicated, it
  117. // ensures good utilization of all SMs for many kinds of shape and GPU
  118. // configurations, while requiring as few slow global cross-threadblock
  119. // reductions as possible.
  120. // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  121. // better partitioning with fewer reductions
  122. int parallel = 1;
  123. if (prob_m > 16 * thread_m_blocks) {
  124. parallel = prob_m / (16 * thread_m_blocks);
  125. prob_m = 16 * thread_m_blocks;
  126. }
  127. // number of thread_k_blocks in k-dim
  128. int k_tiles = prob_k / 32 / thread_k_blocks;
  129. // number of thread_n_blocks in n-dim
  130. int n_tiles = prob_n / 16 / thread_n_blocks;
  131. // iters needed to cover all slices
  132. int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
  133. // Ensure that the number of tiles in each stripe is a multiple of the
  134. // groupsize; this avoids an annoying special case where a stripe starts in
  135. // the middle of a group.
  136. if (group_blocks != -1)
  137. iters = (group_blocks / thread_k_blocks) *
  138. ceildiv(iters, (group_blocks / thread_k_blocks));
  139. int slice_row = (iters * blockIdx.x) % k_tiles;
  140. int slice_col_par = (iters * blockIdx.x) / k_tiles;
  141. int slice_col = slice_col_par;
  142. // number of threadblock tiles in the current slice
  143. int slice_iters;
  144. // total number of active threadblocks in the current slice
  145. int slice_count = 0;
  146. // index of threadblock in current slice; numbered bottom to top
  147. int slice_idx;
  148. // We can easily implement parallel problem execution by just remapping
  149. // indices and advancing global pointers
  150. if (slice_col_par >= n_tiles) {
  151. A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
  152. C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
  153. locks += (slice_col_par / n_tiles) * n_tiles;
  154. slice_col = slice_col_par % n_tiles;
  155. }
  156. // Compute all information about the current slice which is required for
  157. // synchronization.
  158. auto init_slice = [&]() {
  159. slice_iters =
  160. iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
  161. if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
  162. if (slice_iters == 0) return;
  163. if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
  164. slice_count = 1;
  165. slice_idx = 0;
  166. int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
  167. if (col_first <= k_tiles * (slice_col_par + 1)) {
  168. int col_off = col_first - k_tiles * slice_col_par;
  169. slice_count = ceildiv(k_tiles - col_off, iters);
  170. if (col_off > 0) slice_count++;
  171. int delta_first = iters * blockIdx.x - col_first;
  172. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  173. slice_idx = slice_count - 1;
  174. else {
  175. slice_idx = slice_count - 1 - delta_first / iters;
  176. if (col_off > 0) slice_idx--;
  177. }
  178. }
  179. if (slice_col == n_tiles) {
  180. A += 16 * thread_m_blocks * prob_k / 8;
  181. C += 16 * thread_m_blocks * prob_n / 8;
  182. locks += n_tiles;
  183. slice_col = 0;
  184. }
  185. };
  186. init_slice();
  187. // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements
  188. int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
  189. // stride of an A matrix tile in shared memory
  190. constexpr int a_sh_stride = 32 * thread_k_blocks / 8;
  191. // delta between subsequent A tiles in global memory
  192. constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8;
  193. // between subsequent accesses within a tile
  194. int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
  195. // between shared memory writes
  196. constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
  197. // between shared memory tile reads //RLC: 2 * #warps k-dim
  198. constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4));
  199. // within a shared memory tile
  200. constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
  201. // overall size of a tile
  202. constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
  203. // number of shared write iterations for a tile
  204. constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
  205. constexpr int pack_factor = 32 / num_bits;
  206. int b_gl_stride = 16 * prob_n / (pack_factor * 4);
  207. constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
  208. constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
  209. constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
  210. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  211. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
  212. constexpr int b_sh_wr_delta = threads * b_thread_vecs;
  213. constexpr int b_sh_rd_delta = threads * b_thread_vecs;
  214. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  215. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
  216. int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16
  217. constexpr int m_sh_stride =
  218. (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp
  219. int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks;
  220. int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride);
  221. constexpr int m_sh_wr_delta = threads / 2;
  222. constexpr int m_sh_rd_delta = threads / 2;
  223. constexpr int m_sh_stage = m_sh_stride * thread_k_blocks;
  224. constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta);
  225. int s_gl_stride = prob_n / 8;
  226. constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  227. constexpr int s_sh_stage = s_sh_stride;
  228. int s_gl_rd_delta = s_gl_stride;
  229. // Global A read index of current thread.
  230. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  231. (threadIdx.x % a_gl_rd_delta_o);
  232. a_gl_rd += a_gl_rd_delta_o * slice_row;
  233. // Shared write index of current thread.
  234. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
  235. (threadIdx.x % a_gl_rd_delta_o);
  236. // Shared read index.
  237. int a_sh_rd =
  238. a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  239. a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  240. int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
  241. (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
  242. b_gl_rd += b_sh_stride * slice_col;
  243. b_gl_rd += b_gl_rd_delta_o * slice_row;
  244. int b_sh_wr = threadIdx.x * b_thread_vecs;
  245. int b_sh_rd = threadIdx.x * b_thread_vecs;
  246. int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
  247. (threadIdx.x % (m_sh_stride));
  248. m_gl_rd += (m_sh_stride)*slice_col;
  249. m_gl_rd += m_gl_rd_delta_o * slice_row;
  250. int m_sh_wr = threadIdx.x;
  251. int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
  252. int s_gl_rd;
  253. if constexpr (group_blocks == -1) {
  254. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  255. } else {
  256. s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
  257. s_sh_stride * slice_col + threadIdx.x;
  258. }
  259. int s_sh_wr = threadIdx.x;
  260. int s_sh_rd;
  261. // We use a different scale layout for grouped and column-wise quantization as
  262. // we scale a `half2` tile in column-major layout in the former and in
  263. // row-major in the latter case.
  264. if (group_blocks != -1) {
  265. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  266. (threadIdx.x % 32) / 4;
  267. } else {
  268. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  269. (threadIdx.x % 32) / 4;
  270. }
  271. // Precompute which thread should not read memory in which iterations; this is
  272. // needed if there are more threads than required for a certain tilesize or
  273. // when the batchsize is not a multiple of 16.
  274. bool a_sh_wr_pred[a_sh_wr_iters];
  275. #pragma unroll
  276. for (int i = 0; i < a_sh_wr_iters; i++) {
  277. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  278. }
  279. bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
  280. // To ensure that writing and reading A tiles to/from shared memory, the
  281. // latter in fragment format, is fully bank conflict free, we need to use a
  282. // rather fancy XOR-based layout. The key here is that neither reads nor
  283. // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  284. // same shared memory banks. Further, it seems (based on NSight-Compute) that
  285. // each warp must also write a consecutive memory segment?
  286. auto transform_a = [&](int i) {
  287. int row = i / a_gl_rd_delta_o;
  288. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  289. };
  290. // Since the computation of this remapping is non-trivial and, due to our main
  291. // loop unrolls, all shared memory accesses are static, we simply precompute
  292. // both transformed reads and writes.
  293. int a_sh_wr_trans[a_sh_wr_iters];
  294. #pragma unroll
  295. for (int i = 0; i < a_sh_wr_iters; i++)
  296. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  297. int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks];
  298. #pragma unroll
  299. for (int i = 0; i < b_sh_wr_iters; i++) {
  300. #pragma unroll
  301. for (int j = 0; j < thread_m_blocks; j++) {
  302. a_sh_rd_trans[0][i][j] =
  303. transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  304. a_sh_rd_trans[1][i][j] =
  305. transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2);
  306. }
  307. }
  308. // Since B-accesses have non-constant stride they have to be computed at
  309. // runtime; we break dependencies between subsequent accesses within a tile by
  310. // maintaining multiple pointers (we have enough registers), a tiny
  311. // optimization.
  312. const int4* B_ptr[b_sh_wr_iters];
  313. #pragma unroll
  314. for (int i = 0; i < b_sh_wr_iters; i++)
  315. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  316. bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta;
  317. const int4* meta_ptr[m_sh_iters];
  318. #pragma unroll
  319. for (int i = 0; i < m_sh_iters; i++)
  320. meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd;
  321. extern __shared__ int4 sh[];
  322. // Shared memory storage for global fetch pipelines.
  323. int4* sh_a = sh;
  324. int4* sh_b = sh_a + (stages * a_sh_stage);
  325. int4* sh_s = sh_b + (stages * b_sh_stage);
  326. int4* sh_m = sh_s + (stages * s_sh_stage);
  327. // Register storage for double buffer of shared memory reads.
  328. FragA frag_a[2][thread_m_blocks][2];
  329. I4 frag_b_quant[2][b_thread_vecs];
  330. FragM frag_m[2][2];
  331. FragC frag_c[thread_m_blocks][4][2];
  332. FragS frag_s[2][4];
  333. // Zero accumulators.
  334. auto zero_accums = [&]() {
  335. #pragma unroll
  336. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  337. reinterpret_cast<float*>(frag_c)[i] = 0;
  338. };
  339. // Asynchronously fetch the next A, B and s tile from global to the next
  340. // shared memory pipeline location.
  341. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
  342. if (pred) {
  343. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  344. #pragma unroll
  345. for (int i = 0; i < a_sh_wr_iters; i++) {
  346. cp_async4_pred(
  347. &sh_a_stage[a_sh_wr_trans[i]],
  348. &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
  349. a_sh_wr_pred[i]);
  350. }
  351. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  352. #pragma unroll
  353. for (int i = 0; i < b_sh_wr_iters; i++) {
  354. #pragma unroll
  355. for (int j = 0; j < b_thread_vecs; j++) {
  356. cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
  357. }
  358. B_ptr[i] += b_gl_rd_delta_o;
  359. }
  360. int4* sh_meta_stage = sh_m + m_sh_stage * pipe;
  361. #pragma unroll
  362. for (int i = 0; i < m_sh_iters; i++) {
  363. if (m_sh_wr_pred)
  364. cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
  365. meta_ptr[i] += m_gl_rd_delta_o;
  366. }
  367. // Only fetch scales if this tile starts a new group
  368. if constexpr (group_blocks != -1) {
  369. // This assumes group_blocks >= thread_k_blocks
  370. // and would need to be modified to support smaller groups.
  371. static_assert(group_blocks >= thread_k_blocks);
  372. if (pipe % (group_blocks / thread_k_blocks) == 0) {
  373. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  374. if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
  375. s_gl_rd += s_gl_rd_delta;
  376. }
  377. }
  378. }
  379. // Insert a fence even when we are winding down the pipeline to ensure that
  380. // waiting is also correct at this point.
  381. cp_async_fence();
  382. };
  383. // Wait until the next thread tile has been loaded to shared memory.
  384. auto wait_for_stage = [&]() {
  385. // We only have `stages - 2` active fetches since we are double buffering
  386. // and can only issue the next fetch when it is guaranteed that the previous
  387. // shared memory load is fully complete (as it may otherwise be
  388. // overwritten).
  389. cp_async_wait<stages - 2>();
  390. __syncthreads();
  391. };
  392. // Load the next sub-tile from the current location in the shared memory pipe
  393. // into the current register buffer.
  394. auto fetch_to_registers = [&](int k, int pipe) {
  395. // It may seem inefficient that we reload the groups for every sub-tile;
  396. // however, this does not seem to be a significant bottleneck, while some
  397. // theoretically better attempts have led to bad instruction ordering by
  398. // the compiler and correspondingly a noticeable drop in performance.
  399. if constexpr (group_blocks != -1) {
  400. // This assumes group_blocks >= thread_k_blocks
  401. // and would need to be modified to support smaller groups.
  402. static_assert(group_blocks >= thread_k_blocks);
  403. int4* sh_s_stage =
  404. sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
  405. (pipe / (group_blocks / thread_k_blocks)));
  406. reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
  407. }
  408. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  409. #pragma unroll
  410. for (int i = 0; i < thread_m_blocks; i++) {
  411. ldsm4(frag_a[k % 2][i][0],
  412. &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]);
  413. ldsm4(frag_a[k % 2][i][1],
  414. &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]);
  415. }
  416. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  417. #pragma unroll
  418. for (int i = 0; i < b_thread_vecs; i++) {
  419. frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
  420. &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
  421. }
  422. // Load meta with ldsm4
  423. int4* sh_m_stage = sh_m + m_sh_stage * pipe;
  424. ldsm4_m(frag_m[k % 2][0],
  425. &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]);
  426. };
  427. // Execute the actual tensor core matmul of a sub-tile.
  428. auto matmul = [&](int k) {
  429. // We have the m dimension as the inner loop in order to encourage overlapping
  430. // dequantization and matmul operations.
  431. #pragma unroll
  432. for (int j = 0; j < 4; j++) {
  433. FragB frag_b0;
  434. FragB frag_b1;
  435. if constexpr (num_bits == 4) {
  436. int b_quant = frag_b_quant[k % 2][0][j];
  437. int b_quant_shift = b_quant >> 8;
  438. frag_b0 = dequant_4bit(b_quant);
  439. frag_b1 = dequant_4bit(b_quant_shift);
  440. } else {
  441. int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
  442. int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
  443. int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
  444. frag_b0 = dequant_8bit(b_quant_0);
  445. frag_b1 = dequant_8bit(b_quant_1);
  446. }
  447. // If there are no groups, we can just scale the final output once and can
  448. // avoid doing so for each weight.
  449. if constexpr (group_blocks != -1) {
  450. scale(frag_b0, frag_s[k % 2][j], 0);
  451. }
  452. if constexpr (group_blocks != -1) {
  453. scale(frag_b1, frag_s[k % 2][j], 1);
  454. }
  455. #pragma unroll
  456. for (int i = 0; i < thread_m_blocks; i++) {
  457. mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0],
  458. frag_m[k % 2][j / 2], j % 2);
  459. }
  460. }
  461. };
  462. // Since we slice across the k dimension of a tile in order to increase the
  463. // number of warps while keeping the n dimension of a tile reasonable, we have
  464. // multiple warps that accumulate their partial sums of the same output
  465. // location; which we have to reduce over in the end. We do this in shared memory.
  466. auto thread_block_reduce = [&]() {
  467. constexpr int red_off = threads / b_sh_stride_threads / 2;
  468. if (red_off >= 1) {
  469. int red_idx = threadIdx.x / b_sh_stride_threads;
  470. constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
  471. constexpr int red_sh_delta = b_sh_stride_threads;
  472. int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
  473. (threadIdx.x % b_sh_stride_threads);
  474. // Parallel logarithmic shared memory reduction. We make sure to avoid any
  475. // unnecessary read or write iterations, e.g., for two warps we write only
  476. // once by warp 1 and read only once by warp 0.
  477. #pragma unroll
  478. for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  479. #pragma unroll
  480. for (int i = red_off; i > 0; i /= 2) {
  481. if (i <= red_idx && red_idx < 2 * i) {
  482. #pragma unroll
  483. for (int j = 0; j < 4 * 2; j++) {
  484. int red_sh_wr =
  485. red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
  486. if (i < red_off) {
  487. float* c_rd =
  488. reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
  489. float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  490. #pragma unroll
  491. for (int k = 0; k < 4; k++)
  492. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
  493. c_rd[k] + c_wr[k];
  494. }
  495. sh[red_sh_wr] =
  496. reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
  497. }
  498. }
  499. __syncthreads();
  500. }
  501. if (red_idx == 0) {
  502. #pragma unroll
  503. for (int i = 0; i < 4 * 2; i++) {
  504. float* c_rd =
  505. reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  506. #pragma unroll
  507. for (int j = 0; j < 4; j++)
  508. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
  509. c_rd[j];
  510. }
  511. }
  512. __syncthreads();
  513. }
  514. }
  515. };
  516. // Since multiple threadblocks may process parts of the same column slice, we
  517. // finally have to globally reduce over the results. As the striped
  518. // partitioning minimizes the number of such reductions and our outputs are
  519. // usually rather small, we perform this reduction serially in L2 cache.
  520. auto global_reduce = [&](bool first = false, bool last = false) {
  521. // We are very careful here to reduce directly in the output buffer to
  522. // maximize L2 cache utilization in this step. To do this, we write out
  523. // results in FP16 (but still reduce with FP32 compute).
  524. constexpr int active_threads = 32 * thread_n_blocks / 4;
  525. if (threadIdx.x < active_threads) {
  526. int c_gl_stride = prob_n / 8;
  527. int c_gl_wr_delta_o = 2 * 4 * c_gl_stride;
  528. int c_gl_wr_delta_i =
  529. c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28)
  530. int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) +
  531. 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
  532. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  533. constexpr int c_sh_wr_delta = active_threads;
  534. int c_sh_wr = threadIdx.x;
  535. int col = 2 * ((threadIdx.x % 32) % 4);
  536. if (!first) {
  537. // Interestingly, doing direct global accesses here really seems to mess up
  538. // the compiler and lead to slowdowns, hence we also use async-copies even
  539. // though these fetches are not actually asynchronous.
  540. #pragma unroll
  541. for (int i = 0; i < thread_m_blocks * 4; i++) {
  542. cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
  543. &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
  544. c_gl_wr_delta_i * (i % 2)],
  545. i < (thread_m_blocks - 1) * 4 ||
  546. 8 * (i / 2) + col + (i % 2) < prob_m);
  547. }
  548. cp_async_fence();
  549. cp_async_wait<0>();
  550. }
  551. #pragma unroll
  552. for (int i = 0; i < thread_m_blocks * 4; i++) {
  553. if (i < (thread_m_blocks - 1) * 4 ||
  554. 8 * (i / 2) + col + (i % 2) < prob_m) {
  555. if (!first) {
  556. int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  557. #pragma unroll
  558. for (int j2 = 0; j2 < 2; j2++) {
  559. #pragma unroll
  560. for (int j1 = 0; j1 < 4; j1++) {
  561. reinterpret_cast<float*>(
  562. &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
  563. 4 * ((i % 4) / 2) + i % 2] +=
  564. __half2float(
  565. reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]);
  566. }
  567. }
  568. }
  569. if (!last) {
  570. int4 c;
  571. #pragma unroll
  572. for (int j2 = 0; j2 < 2; j2++) {
  573. #pragma unroll
  574. for (int j1 = 0; j1 < 4; j1++) {
  575. reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] =
  576. __float2half(reinterpret_cast<float*>(
  577. &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
  578. 4 * ((i % 4) / 2) + i % 2]);
  579. }
  580. }
  581. C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
  582. c;
  583. }
  584. }
  585. }
  586. }
  587. };
  588. // Write out the reduced final result in the correct layout. We only actually
  589. // reshuffle matrix fragments in this step, the reduction above is performed
  590. // in fragment layout.
  591. auto write_result = [&]() {
  592. int c_gl_stride = prob_n / 8;
  593. constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC:
  594. constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC:
  595. constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC:
  596. int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
  597. int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  598. (threadIdx.x % (2 * thread_n_blocks));
  599. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  600. int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) +
  601. ((threadIdx.x % 32) / 4); // RLC:
  602. c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4)
  603. constexpr int c_sh_rd_delta =
  604. c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC:
  605. int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) +
  606. (threadIdx.x % (2 * 2 * thread_n_blocks));
  607. int c_gl_wr_end = c_gl_stride * prob_m;
  608. auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0,
  609. float c4, float c5, float c6, float c7, FragS& s1) {
  610. uint2 res[2];
  611. res[0] = to_half4(c0, c1, c2, c3);
  612. res[1] = to_half4(c4, c5, c6, c7);
  613. half2* tmp = (half2*)&res;
  614. // for per-column quantization we finally apply the scale here
  615. if constexpr (group_blocks == -1 && num_bits == 4) {
  616. tmp[0] = __hmul2(tmp[0], s0[0]);
  617. tmp[1] = __hmul2(tmp[1], s0[1]);
  618. tmp[2] = __hmul2(tmp[2], s1[0]);
  619. tmp[3] = __hmul2(tmp[3], s1[1]);
  620. }
  621. ((int4*)sh)[idx] = *((int4*)&res[0]);
  622. };
  623. // RLC: only warp 0 and 1 baseline example
  624. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  625. #pragma unroll
  626. for (int i = 0; i < thread_m_blocks; i++) {
  627. int wr = c_sh_wr;
  628. write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0],
  629. frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2],
  630. frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2],
  631. frag_s[0][2]);
  632. write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1],
  633. frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0],
  634. frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3],
  635. frag_c[i][3][0][3], frag_s[0][2]);
  636. write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0],
  637. frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0],
  638. frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2],
  639. frag_c[i][3][1][2], frag_s[0][2]);
  640. write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1],
  641. frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1],
  642. frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3],
  643. frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]);
  644. c_sh_wr += 8 * c_sh_stride_2;
  645. }
  646. }
  647. __syncthreads();
  648. #pragma unroll
  649. for (int i = 0;
  650. i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
  651. i++) {
  652. if (c_gl_wr < c_gl_wr_end) {
  653. C[c_gl_wr] = sh[c_sh_rd];
  654. c_gl_wr += c_gl_wr_delta;
  655. c_sh_rd += c_sh_rd_delta;
  656. }
  657. }
  658. };
  659. // Start global fetch and register load pipelines.
  660. auto start_pipes = [&]() {
  661. #pragma unroll
  662. for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
  663. zero_accums();
  664. wait_for_stage();
  665. fetch_to_registers(0, 0);
  666. a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  667. };
  668. start_pipes();
  669. // Main loop.
  670. while (slice_iters) {
  671. // We unroll over both the global fetch and the register load pipeline to
  672. // ensure all shared memory accesses are static. Note that both pipelines have
  673. // even length meaning that the next iteration will always start at index 0.
  674. #pragma unroll
  675. for (int pipe = 0; pipe < stages;) {
  676. fetch_to_shared((pipe + stages - 1) % stages, pipe,
  677. slice_iters >= stages);
  678. matmul(pipe);
  679. wait_for_stage();
  680. fetch_to_registers(pipe + 1, (pipe + 1) % stages);
  681. pipe++;
  682. slice_iters--;
  683. if (slice_iters == 0) break;
  684. }
  685. a_gl_rd += a_gl_rd_delta_o * stages;
  686. // Process results and, if necessary, proceed to the next column slice.
    // While this pattern may not be the most readable, other ways of writing
    // the loop seemed to yield noticeably worse performance after compilation.
  689. if (slice_iters == 0) {
  690. cp_async_wait<0>();
  691. bool last = slice_idx == slice_count - 1;
  692. // For per-column scales, we only fetch them here in the final step before
  693. // write-out
  694. if constexpr (group_blocks == -1) {
  695. if constexpr (num_bits == 8) {
  696. if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
  697. cp_async_fence();
  698. } else {
  699. if (last) {
  700. if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
  701. cp_async_fence();
  702. }
  703. }
  704. }
  705. thread_block_reduce();
  706. if constexpr (group_blocks == -1) {
  707. if constexpr (num_bits == 8) {
  708. cp_async_wait<0>();
  709. __syncthreads();
  710. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  711. *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
  712. }
  713. } else {
  714. if (last) {
  715. cp_async_wait<0>();
  716. __syncthreads();
  717. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  718. *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
  719. }
  720. }
  721. }
  722. }
  723. // For 8-bit channelwise, we apply the scale before the global reduction
  724. // that converts the fp32 results to fp16 (so that we avoid possible
  725. // overflow in fp16)
  726. if constexpr (group_blocks == -1 && num_bits == 8) {
  727. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  728. #pragma unroll
  729. for (int i = 0; i < thread_m_blocks; i++) {
  730. scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0],
  731. &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0],
  732. &frag_c[i][0][0][2], &frag_c[i][1][0][2],
  733. &frag_c[i][2][0][2], &frag_c[i][3][0][2],
  734. frag_s[0][2]);
  735. scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1],
  736. &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0],
  737. &frag_c[i][0][0][3], &frag_c[i][1][0][3],
  738. &frag_c[i][2][0][3], &frag_c[i][3][0][3],
  739. frag_s[0][2]);
  740. scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0],
  741. &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0],
  742. &frag_c[i][0][1][2], &frag_c[i][1][1][2],
  743. &frag_c[i][2][1][2], &frag_c[i][3][1][2],
  744. frag_s[0][2]);
  745. scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1],
  746. &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0],
  747. &frag_c[i][0][1][3], &frag_c[i][1][1][3],
  748. &frag_c[i][2][1][3], &frag_c[i][3][1][3],
  749. frag_s[0][2]);
  750. }
  751. }
  752. }
  753. if (slice_count > 1) { // only globally reduce if there is more than one
  754. // block in a slice
  755. barrier_acquire(&locks[slice_col], slice_idx);
  756. global_reduce(slice_idx == 0, last);
  757. barrier_release(&locks[slice_col], last);
  758. }
  759. if (last) // only the last block in a slice actually writes the result
  760. write_result();
  761. slice_row = 0;
  762. slice_col_par++;
  763. slice_col++;
  764. init_slice();
  765. if (slice_iters) {
  766. a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  767. (threadIdx.x % a_gl_rd_delta_o);
  768. #pragma unroll
  769. for (int i = 0; i < b_sh_wr_iters; i++)
  770. B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
  771. #pragma unroll
  772. for (int i = 0; i < m_sh_iters; i++)
  773. meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
  774. if (slice_col == 0) {
  775. #pragma unroll
  776. for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
  777. #pragma unroll
  778. for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
  779. }
  780. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  781. start_pipes();
  782. }
  783. }
  784. }
  785. }
  786. #endif
  787. #define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
  788. THREAD_K_BLOCKS, GROUP_BLOCKS) \
  789. else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
  790. thread_n_blocks == THREAD_N_BLOCKS && \
  791. thread_k_blocks == THREAD_K_BLOCKS && \
  792. group_blocks == GROUP_BLOCKS) { \
  793. cudaFuncSetAttribute( \
  794. Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
  795. THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
  796. cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
  797. Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
  798. THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
  799. <<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
  800. C_ptr, s_ptr, prob_n, \
  801. prob_m, prob_k, locks); \
  802. }
  803. void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
  804. void* s, int prob_m, int prob_n, int prob_k,
  805. void* workspace, int num_bits, int groupsize = -1,
  806. int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
  807. int thread_m = -1, int sms = -1, int max_par = 16) {
  808. int tot_n = prob_n;
  809. int tot_n_blocks = ceildiv(tot_n, 16);
  810. int pad = 16 * tot_n_blocks - tot_n;
  811. if (sms == -1) {
  812. cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  813. }
  814. TORCH_CHECK(sms > 0);
  815. int max_shared_mem = 0;
  816. cudaDeviceGetAttribute(&max_shared_mem,
  817. cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  818. TORCH_CHECK(max_shared_mem > 0);
  819. if (thread_k == -1 || thread_m == -1) {
  820. if (prob_n <= 16) {
  821. // For small batchizes, better partitioningif is slightly more important
  822. // than better compute utilization
  823. thread_k = 128;
  824. thread_m = 128;
  825. } else if (prob_n <= 256) {
  826. thread_k = 64;
  827. thread_m = 256;
  828. } else {
  829. thread_k = 32;
  830. thread_m = 512;
  831. }
  832. }
  833. int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
  834. int thread_m_blocks = thread_m / 16;
  835. int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
  836. int blocks = sms;
  837. TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m,
  838. " is not divisible by thread_m = ", thread_m);
  839. TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
  840. " is not divisible by thread_k = ", thread_k);
  841. if (group_blocks != -1) {
  842. TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2,
  843. " is not divisible by group_blocks = ", group_blocks);
  844. }
  845. TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
  846. ", ", prob_n, ", ", prob_k, "]");
  847. const int4* A_ptr = (const int4*)A;
  848. const int4* B_ptr = (const int4*)B;
  849. const int4* meta_ptr = (const int4*)meta;
  850. int4* C_ptr = (int4*)C;
  851. const int4* s_ptr = (const int4*)s;
  852. constexpr int max_m_blocks = 4;
  853. int* locks = (int*)workspace;
  854. for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
  855. int thread_n_blocks = tot_n_blocks - i;
  856. prob_n = tot_n - 16 * i;
  857. int par = 1;
  858. if (thread_n_blocks > max_m_blocks) {
  859. // Note that parallel > 1 currently only works for inputs without any
  860. // padding
  861. par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
  862. if (par > max_par) par = max_par;
  863. prob_n = (max_m_blocks * 16) * par;
  864. i += max_m_blocks * (par - 1);
  865. thread_n_blocks = max_m_blocks;
  866. }
  867. // For compilation speed, we only define the kernel configurations that have
  868. // seemed useful (in terms of performance) in our testing, however many more
  869. // are, in principle, possible.
  870. // the false is start of the CALL_IF macros
  871. if (false) {
  872. } // BMxBNxBK, group
  873. // 4-bit
  874. CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
  875. CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
  876. CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
  877. CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
  878. CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64
  879. CALL_IF_2_4(4, 16, 2, 2, 4)
  880. CALL_IF_2_4(4, 16, 3, 2, -1)
  881. CALL_IF_2_4(4, 16, 3, 2, 4)
  882. CALL_IF_2_4(4, 16, 4, 2, -1)
  883. CALL_IF_2_4(4, 16, 4, 2, 4)
  884. CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64
  885. CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64
  886. CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64
  887. CALL_IF_2_4(4, 32, 2, 1, 4)
  888. CALL_IF_2_4(4, 32, 3, 1, -1)
  889. CALL_IF_2_4(4, 32, 3, 1, 4)
  890. CALL_IF_2_4(4, 32, 4, 1, -1)
  891. CALL_IF_2_4(4, 32, 4, 1, 4)
  892. // 8-bit
  893. CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
  894. CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
  895. CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
  896. CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
  897. CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64
  898. CALL_IF_2_4(8, 16, 2, 2, 4)
  899. CALL_IF_2_4(8, 16, 3, 2, -1)
  900. CALL_IF_2_4(8, 16, 3, 2, 4)
  901. CALL_IF_2_4(8, 16, 4, 2, -1)
  902. CALL_IF_2_4(8, 16, 4, 2, 4)
  903. CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64
  904. CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64
  905. CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64
  906. CALL_IF_2_4(8, 32, 2, 1, 4)
  907. CALL_IF_2_4(8, 32, 3, 1, -1)
  908. CALL_IF_2_4(8, 32, 3, 1, 4)
  909. CALL_IF_2_4(8, 32, 4, 1, -1)
  910. CALL_IF_2_4(8, 32, 4, 1, 4)
  911. else {
  912. throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
  913. ", " + str(prob_k) + ", " + str(prob_n) + "]" +
  914. ", groupsize = " + str(groupsize) +
  915. ", thread_m_blocks = " + str(thread_m_blocks) +
  916. ", thread_n_blocks = " + str(thread_n_blocks) +
  917. ", thread_k_blocks = " + str(thread_k_blocks));
  918. }
  919. A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par;
  920. C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par;
  921. }
  922. }
  923. } // namespace marlin_24
  924. torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  925. torch::Tensor& b_meta,
  926. torch::Tensor& b_scales,
  927. torch::Tensor& workspace,
  928. aphrodite::ScalarTypeTorchPtr const& b_q_type,
  929. int64_t size_m, int64_t size_n,
  930. int64_t size_k) {
  931. // Verify num_bits
  932. TORCH_CHECK(*b_q_type == aphrodite::kU4B8 || *b_q_type == aphrodite::kU8B128,
  933. "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
  934. int pack_factor = 32 / b_q_type->size_bits();
  935. // Verify M
  936. TORCH_CHECK(size_m == a.size(0),
  937. "Shape mismatch: a.size(0) = " + str(a.size(0)) +
  938. ", size_m = " + str(size_m));
  939. // Verify K
  940. TORCH_CHECK(size_k == a.size(1),
  941. "Shape mismatch: a.size(1) = " + str(a.size(1)) +
  942. ", size_k = " + str(size_k));
  943. TORCH_CHECK(size_k % marlin_24::tile_size == 0,
  944. "size_k = " + str(size_k) + " is not divisible by tile_size = " +
  945. str(marlin_24::tile_size));
  946. TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0),
  947. "Shape mismatch: b_q_weight.size(0) = " +
  948. str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
  949. ", tile_size = " + str(marlin_24::tile_size));
  950. // Verify N
  951. TORCH_CHECK(b_scales.size(1) == size_n,
  952. "b_scales.size(1) = " + str(b_scales.size(1)) +
  953. ", size_n = " + str(size_n));
  954. TORCH_CHECK(
  955. b_q_weight.size(1) % marlin_24::tile_size == 0,
  956. "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
  957. " is not divisible by tile_size = " + str(marlin_24::tile_size));
  958. int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
  959. TORCH_CHECK(
  960. size_n == actual_size_n,
  961. "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
  962. // Verify meta
  963. TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
  964. "b_meta.size(0) = ", b_meta.size(0),
  965. " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2);
  966. TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1),
  967. " is not size_n * 2 = ", size_n * 2);
  968. // Verify A device and strides
  969. TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  970. TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
  971. // Verify B device and strides
  972. TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  973. TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  974. // Verify b_meta device and strides
  975. TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU");
  976. TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous");
  977. // Verify scales device and strides
  978. TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  979. TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
  980. // Alloc C matrix
  981. const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  982. auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  983. torch::Tensor c = torch::empty({size_m, size_n}, options);
  984. int thread_k = -1;
  985. int thread_m = -1;
  986. int sms = -1;
  987. int max_par = marlin_24::max_par;
  988. int groupsize = -1;
  989. if (b_scales.size(0) > 1) {
  990. TORCH_CHECK(size_k % b_scales.size(0) == 0,
  991. "size_k = " + str(size_k) +
  992. ", is not divisible by b_scales.size(0) = " +
  993. str(b_scales.size(0)));
  994. groupsize = size_k / b_scales.size(0);
  995. groupsize /= 2; // Because of 24
  996. }
  997. // Verify groupsize
  998. TORCH_CHECK(groupsize == -1 || groupsize == 64,
  999. "Unexpected groupsize = " + str(groupsize));
  1000. // Verify workspace size
  1001. TORCH_CHECK(size_n % marlin_24::min_thread_n == 0,
  1002. "size_n = " + str(size_n) +
  1003. ", is not divisible by min_thread_n = " +
  1004. str(marlin_24::min_thread_n));
  1005. int min_workspace_size =
  1006. (size_n / marlin_24::min_thread_n) * marlin_24::max_par;
  1007. TORCH_CHECK(workspace.numel() >= min_workspace_size,
  1008. "workspace.numel = " + str(workspace.numel()) +
  1009. " is below min_workspace_size = " + str(min_workspace_size));
  1010. int dev = a.get_device();
  1011. marlin_24::marlin_cuda_2_4(
  1012. a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
  1013. b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
  1014. b_q_type->size_bits(), groupsize, dev,
  1015. at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
  1016. return c;
  1017. }