fp8_marlin.cu

  1. /*
  2. * Modified by Neural Magic
  3. * Copyright (C) 2024 Elias Frantar
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. /*
  18. * Adapted from https://github.com/IST-DASLab/marlin
  19. */
  20. #include "../gptq_marlin/marlin.cuh"
  21. #include "../gptq_marlin/marlin_dtypes.cuh"
  22. using namespace marlin;
  23. #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
  24. static_assert(std::is_same<scalar_t, half>::value || \
  25. std::is_same<scalar_t, nv_bfloat16>::value, \
  26. "only float16 and bfloat16 is supported");
  27. template <typename T>
  28. inline std::string str(T x) {
  29. return std::to_string(x);
  30. }
  31. namespace fp8_marlin {
  32. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  33. template <typename scalar_t, // compute dtype, half or nv_bfloat16
  34. const int num_bits, // number of bits used for weights
  35. const int threads, // number of threads in a threadblock
  36. const int thread_m_blocks, // number of 16x16 blocks in the m
  37. // dimension (batchsize) of the
  38. // threadblock
  39. const int thread_n_blocks, // same for n dimension (output)
  40. const int thread_k_blocks, // same for k dimension (reduction)
  41. const int stages, // number of stages for the async global->shared
  42. // fetch pipeline
  43. const int group_blocks = -1 // number of consecutive 16x16 blocks
  44. // with a separate quantization scale
  45. >
  46. __global__ void Marlin(
  47. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  48. const int4* __restrict__ B, // fp8 quantized weight matrix of shape kxn
  49. int4* __restrict__ C, // fp16 output buffer of shape mxn
  50. const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
  51. // (k/groupsize)xn
  52. int num_groups, // number of scale groups per output channel
  53. int prob_m, // batch dimension m
  54. int prob_n, // output dimension n
  55. int prob_k, // reduction dimension k
  56. int* locks // extra global storage for barrier synchronization
  57. ) {}
  58. } // namespace fp8_marlin
  59. torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  60. torch::Tensor& b_scales, torch::Tensor& workspace,
  61. int64_t num_bits, int64_t size_m, int64_t size_n,
  62. int64_t size_k) {
  63. TORCH_CHECK_NOT_IMPLEMENTED(false,
  64. "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
  65. return torch::empty({1, 1});
  66. }
  67. #else
  68. // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
  69. // output/accumulation.
  70. template <typename scalar_t>
  71. __device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
  72. const typename ScalarType<scalar_t>::FragB& frag_b,
  73. typename ScalarType<scalar_t>::FragC& frag_c) {
  74. const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  75. const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  76. float* c = reinterpret_cast<float*>(&frag_c);
  77. if constexpr (std::is_same<scalar_t, half>::value) {
  78. asm volatile(
  79. "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
  80. "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
  81. : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
  82. : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
  83. "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  84. } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
  85. asm volatile(
  86. "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
  87. "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
  88. : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
  89. : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
  90. "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  91. } else {
  92. STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
  93. }
  94. }
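// Per-thread operand layout for the m16n8k16 mma above (illustrative note):
// each thread of the warp holds 4 x 32-bit registers of A (its share of a
// 16x16 fp16/bf16 tile), 2 x 32-bit registers of B (a 16x8 tile) and 4 fp32
// accumulators of C, which matches how FragA/FragB/FragC are sized in
// marlin_dtypes.cuh.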
  95. // Instruction for loading a full 16x16 matrix fragment of operand A from shared
  96. // memory, directly in tensor core layout.
  97. template <typename scalar_t>
  98. __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
  99. const void* smem_ptr) {
  100. uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  101. uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  102. asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
  103. : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
  104. : "r"(smem));
  105. }
  106. // Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to
  107. // fp16/bf16. Reference:
  108. // - FP16:
  109. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
  110. // - BF16:
  111. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
  112. template <typename scalar_t>
  113. __device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
  114. STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
  115. }
  116. template <>
  117. __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
  118. // Constants for FP8 (E4M3) and FP16 formats
  119. constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
  120. constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
  121. // Calculate MASK for extracting mantissa and exponent
  122. constexpr int MASK1 = 0x80000000;
  123. constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
  124. constexpr int MASK3 = MASK2 & 0x7fffffff;
  125. constexpr int MASK = MASK3 | (MASK3 >> 16);
  126. // Final MASK value: 0x7F007F00
  127. // Extract and shift FP8 values to FP16 format
  128. int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  129. int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
  130. // Construct and apply exponent bias
  131. constexpr int BIAS_OFFSET =
  132. (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  133. const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
  134. // Convert to half2 and apply bias
  135. typename ScalarType<half>::FragB frag_b;
  136. // Note: reverse indexing is intentional because weights are permuted
  137. frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
  138. frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
  139. return frag_b;
  140. }
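// Illustrative worked example for the conversion above: the e4m3 byte 0xC8
// (sign = 1, exponent = 9, mantissa = 0, i.e. -1.0 * 2^(9-7) = -4.0) sits in
// the high byte of a 16-bit lane as 0xC800. Masking gives sign 0x8000 and
// (0x4800 >> 1) = 0x2400, so the lane becomes 0xA400 = -2^-6 as fp16; the
// multiply by bias_reg = 2^8 (fp16 bias 15 minus e4m3 bias 7) recovers -4.0.
// Compile-time re-derivation of the mask constant (a sketch with hypothetical
// names, mirroring the constexpr arithmetic above; it relies on arithmetic
// right shift of negative ints just like the kernel itself):
namespace {
constexpr int kDq8Mask1 = 0x80000000;
constexpr int kDq8Mask2 = kDq8Mask1 >> (4 + 3);          // 0xFF000000
constexpr int kDq8Mask3 = kDq8Mask2 & 0x7fffffff;        // 0x7F000000
constexpr int kDq8Mask = kDq8Mask3 | (kDq8Mask3 >> 16);  // 0x7F007F00
static_assert(kDq8Mask == 0x7F007F00, "MASK used by dequant_8bit<half>");
}  // namespace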
  141. template <>
  142. __device__ inline typename ScalarType<nv_bfloat16>::FragB
  143. dequant_8bit<nv_bfloat16>(int q) {
  144. // Constants for FP8 (E4M3) and BF16 formats
  145. constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
  146. constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
  147. // Calculate MASK for extracting mantissa and exponent
  148. constexpr int MASK1 = 0x80000000;
  149. constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
  150. constexpr int MASK3 = MASK2 & 0x7fffffff;
  151. constexpr int MASK = MASK3 | (MASK3 >> 16);
  152. // Final MASK value: 0x7F007F00
  153. // Extract and shift FP8 values to BF16 format
  154. int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  155. int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
  156. // Construct and apply exponent bias
  157. constexpr int BIAS_OFFSET =
  158. (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  159. // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
  160. // position
  161. constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
  162. const nv_bfloat162 bias_reg =
  163. __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
  164. // Convert to bfloat162 and apply bias
  165. typename ScalarType<nv_bfloat16>::FragB frag_b;
  166. // Note: reverse indexing is intentional because weights are permuted
  167. frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
  168. frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
  169. return frag_b;
  170. }
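// Illustrative check for the bf16 path above: here RIGHT_SHIFT is 4 and the
// bias factor is 2^120 (bf16/fp32 exponent bias 127 minus e4m3 bias 7),
// encoded by writing (120 + 127) << 23 directly into fp32 exponent bits.
// Tracing the same e4m3 byte 0xC8 (-4.0): 0x8000 | (0x4800 >> 4) = 0x8480,
// which is -2^-118 as bf16, and multiplying by 2^120 again yields -4.0.
static_assert(((1 << (8 - 1)) - (1 << (4 - 1))) == 120,
              "bf16 vs fp8-e4m3 exponent bias difference used above");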
  171. // Multiply dequantized values by the corresponding quantization scale; used
  172. // only for grouped quantization.
  173. template <typename scalar_t>
  174. __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
  175. typename ScalarType<scalar_t>::FragS& frag_s,
  176. int i) {
  177. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  178. scalar_t2 s =
  179. ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
  180. frag_b[0] = __hmul2(frag_b[0], s);
  181. frag_b[1] = __hmul2(frag_b[1], s);
  182. }
  183. // Given 2 floats, multiply them by 2 scales (halves)
  184. template <typename scalar_t>
  185. __device__ inline void scale_float(float* c,
  186. typename ScalarType<scalar_t>::FragS& s) {
  187. scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
  188. c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
  189. c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
  190. }
  191. // Wait until barrier reaches `count`, then lock for current threadblock.
  192. __device__ inline void barrier_acquire(int* lock, int count) {
  193. if (threadIdx.x == 0) {
  194. int state = -1;
  195. do
  196. // Acquire load: makes writes released by the threadblocks that updated the
  197. // barrier visible to this threadblock's subsequent reads.
  198. asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
  199. : "=r"(state)
  200. : "l"(lock));
  201. while (state != count);
  202. }
  203. __syncthreads();
  204. }
  205. // Release barrier and increment visitation count.
  206. __device__ inline void barrier_release(int* lock, bool reset = false) {
  207. __syncthreads();
  208. if (threadIdx.x == 0) {
  209. if (reset) {
  210. lock[0] = 0;
  211. return;
  212. }
  213. int val = 1;
  214. // Make sure that all writes since acquiring this barrier are visible
  215. // globally, while releasing the barrier.
  216. asm volatile("fence.acq_rel.gpu;\n");
  217. asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
  218. :
  219. : "l"(lock), "r"(val));
  220. }
  221. }
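// Typical usage (see the main loop further below): threadblocks that share one
// column slice serialize their global reduction as
//   barrier_acquire(&locks[slice_col], slice_idx);
//   global_reduce(...);
//   barrier_release(&locks[slice_col], last);
// i.e. block `slice_idx` waits until the lock counter reaches its own index,
// reduces into C, then increments the counter (or resets it as the last block).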
  222. template <typename scalar_t, // compute dtype, half or nv_bfloat16
  223. const int num_bits, // number of bits used for weights
  224. const int threads, // number of threads in a threadblock
  225. const int thread_m_blocks, // number of 16x16 blocks in the m
  226. // dimension (batchsize) of the
  227. // threadblock
  228. const int thread_n_blocks, // same for n dimension (output)
  229. const int thread_k_blocks, // same for k dimension (reduction)
  230. const int stages, // number of stages for the async global->shared
  231. // fetch pipeline
  232. const int group_blocks = -1 // number of consecutive 16x16 blocks
  233. // with a separate quantization scale
  234. >
  235. __global__ void Marlin(
  236. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  237. const int4* __restrict__ B, // fp8 quantized weight matrix of shape kxn
  238. int4* __restrict__ C, // fp16 output buffer of shape mxn
  239. const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
  240. // (k/groupsize)xn
  241. int num_groups, // number of scale groups per output channel
  242. int prob_m, // batch dimension m
  243. int prob_n, // output dimension n
  244. int prob_k, // reduction dimension k
  245. int* locks // extra global storage for barrier synchronization
  246. ) {
  247. // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  248. // same size, which might involve multiple column "slices" (of width 16 *
  249. // `thread_n_blocks`). Stripes are defined as in the following example with a
  250. // 3x3 grid of tiles and 5 SMs:
  251. // 0 1 3
  252. // 0 2 3
  253. // 1 2 4
  254. // While this kind of partitioning makes things somewhat more complicated, it
  255. // ensures good utilization of all SMs for many kinds of shape and GPU
  256. // configurations, while requiring as few slow global cross-threadblock
  257. // reductions as possible.
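// Worked example for the 3x3-tile, 5-SM picture above: k_tiles = n_tiles = 3,
// parallel = 1, gridDim.x = 5, so iters = ceil(9 / 5) = 2. Block b starts at
// linear tile index 2*b, giving slice_row = (2*b) % 3 and
// slice_col_par = (2*b) / 3; e.g. block 1 covers (col 0, row 2) and then
// wraps into (col 1, row 0), exactly the two cells labeled "1" above.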
  258. using Dtype = ScalarType<scalar_t>;
  259. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  260. using FragA = typename ScalarType<scalar_t>::FragA;
  261. using FragB = typename ScalarType<scalar_t>::FragB;
  262. using FragC = typename ScalarType<scalar_t>::FragC;
  263. using FragS = typename ScalarType<scalar_t>::FragS;
  264. constexpr int pack_factor = 32 / num_bits;
  265. // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  266. // better partitioning with fewer reductions.
  267. int parallel = 1;
  268. if (prob_m > 16 * thread_m_blocks) {
  269. parallel = prob_m / (16 * thread_m_blocks);
  270. prob_m = 16 * thread_m_blocks;
  271. }
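// Illustrative example: with prob_m = 320 and thread_m_blocks = 4 (a 64-row
// tile), parallel becomes 5 and prob_m is clamped to 64; the five 64-row
// sub-problems are then addressed through slice_col_par (remapped below).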
  272. int k_tiles = prob_k / 16 / thread_k_blocks;
  273. int n_tiles = prob_n / 16 / thread_n_blocks;
  274. int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
  275. int slice_row = (iters * blockIdx.x) % k_tiles;
  276. int slice_col_par = (iters * blockIdx.x) / k_tiles;
  277. int slice_col = slice_col_par;
  278. int slice_iters; // number of threadblock tiles in the current slice
  279. int slice_count =
  280. 0; // total number of active threadblocks in the current slice
  281. int slice_idx; // index of threadblock in current slice; numbered bottom to
  282. // top
  283. // We can easily implement parallel problem execution by just remapping
  284. // indices and advancing global pointers
  285. if (slice_col_par >= n_tiles) {
  286. A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
  287. C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
  288. locks += (slice_col_par / n_tiles) * n_tiles;
  289. slice_col = slice_col_par % n_tiles;
  290. }
  291. // Compute all information about the current slice which is required for
  292. // synchronization.
  293. auto init_slice = [&]() {
  294. slice_iters =
  295. iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
  296. if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
  297. if (slice_iters == 0) return;
  298. if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
  299. slice_count = 1;
  300. slice_idx = 0;
  301. int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
  302. if (col_first <= k_tiles * (slice_col_par + 1)) {
  303. int col_off = col_first - k_tiles * slice_col_par;
  304. slice_count = div_ceil(k_tiles - col_off, iters);
  305. if (col_off > 0) slice_count++;
  306. int delta_first = iters * blockIdx.x - col_first;
  307. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  308. slice_idx = slice_count - 1;
  309. else {
  310. slice_idx = slice_count - 1 - delta_first / iters;
  311. if (col_off > 0) slice_idx--;
  312. }
  313. }
  314. if (slice_col == n_tiles) {
  315. A += 16 * thread_m_blocks * prob_k / 8;
  316. C += 16 * thread_m_blocks * prob_n / 8;
  317. locks += n_tiles;
  318. slice_col = 0;
  319. }
  320. };
  321. init_slice();
  322. // A sizes/strides
  323. // stride of the A matrix in global memory
  324. int a_gl_stride = prob_k / 8;
  325. // stride of an A matrix tile in shared memory
  326. constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
  327. // delta between subsequent A tiles in global memory
  328. constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
  329. // between subsequent accesses within a tile
  330. int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
  331. // between shared memory writes
  332. constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
  333. // between shared memory tile reads
  334. constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
  335. // within a shared memory tile
  336. constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
  337. // overall size of a tile
  338. constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
  339. // number of shared write iterations for a tile
  340. constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
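// Illustrative numbers (for the 8-bit config thread_k_blocks = 8,
// threads = 256, thread_m_blocks = 1 instantiated below): one A tile row is
// 16 * 8 = 128 fp16/bf16 values = 16 int4 loads, so
// a_sh_stride = a_gl_rd_delta_o = 16, a_sh_wr_delta = 16 * (256 / 16) = 256,
// a_sh_stage = 16 * 16 = 256 int4 and a_sh_wr_iters = 1.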
  341. // B sizes/strides
  342. int b_gl_stride = 16 * prob_n / (pack_factor * 4);
  343. constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
  344. constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
  345. constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
  346. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  347. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
  348. constexpr int b_sh_wr_delta = threads * b_thread_vecs;
  349. constexpr int b_sh_rd_delta = threads * b_thread_vecs;
  350. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  351. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
  352. // Scale sizes/strides without act_order
  353. int s_gl_stride = prob_n / 8;
  354. constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  355. // Scale size/strides with act_order
  356. constexpr int tb_k = 16 * thread_k_blocks;
  357. constexpr int g_idx_stage = 0;
  358. // constexpr int act_s_row_stride = 1;
  359. // int act_s_col_stride = act_s_row_stride * num_groups;
  360. int act_s_col_stride = 1;
  361. int act_s_col_warp_stride = act_s_col_stride * 8;
  362. int tb_n_warps = thread_n_blocks / 4;
  363. int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
  364. // Global A read index of current thread.
  365. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  366. (threadIdx.x % a_gl_rd_delta_o);
  367. a_gl_rd += a_gl_rd_delta_o * slice_row;
  368. // Shared write index of current thread.
  369. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
  370. (threadIdx.x % a_gl_rd_delta_o);
  371. // Shared read index.
  372. int a_sh_rd =
  373. a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  374. a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  375. int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
  376. (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
  377. b_gl_rd += b_sh_stride * slice_col;
  378. b_gl_rd += b_gl_rd_delta_o * slice_row;
  379. int b_sh_wr = threadIdx.x * b_thread_vecs;
  380. int b_sh_rd = threadIdx.x * b_thread_vecs;
  381. // For act_order
  382. int slice_k_start = tb_k * slice_row;
  383. int slice_k_start_shared_fetch = slice_k_start;
  384. int slice_n_offset = act_s_col_tb_stride * slice_col;
  385. // No act_order
  386. int s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  387. int s_sh_wr = threadIdx.x;
  388. bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
  389. // We scale a `half2` tile in row-major layout for column-wise quantization.
  390. int s_sh_rd =
  391. 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
  392. // Precompute which thread should not read memory in which iterations; this is
  393. // needed if there are more threads than required for a certain tilesize or
  394. // when the batchsize is not a multiple of 16.
  395. bool a_sh_wr_pred[a_sh_wr_iters];
  396. #pragma unroll
  397. for (int i = 0; i < a_sh_wr_iters; i++)
  398. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  399. // To ensure that writing and reading A tiles to/from shared memory, the
  400. // latter in fragment format, is fully bank conflict free, we need to use a
  401. // rather fancy XOR-based layout. The key here is that neither reads nor
  402. // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  403. // same shared memory banks. Further, it seems (based on NSight-Compute) that
  404. // each warp must also write a consecutive memory segment?
  405. auto transform_a = [&](int i) {
  406. int row = i / a_gl_rd_delta_o;
  407. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  408. };
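// Illustrative effect of the swizzle above (note that `^` binds weaker than
// `+`, so the whole row-local index is XORed with `row`): with
// a_gl_rd_delta_o = 16, element (row 2, col 3) maps to (16*2 + 3) ^ 2 = 33,
// i.e. column 1 of row 2, so consecutive rows use a permuted column order
// and their int4 accesses land in different shared memory banks.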
  409. // Since the computation of this remapping is non-trivial and, due to our main
  410. // loop unrolls, all shared memory accesses are static, we simply precompute
  411. // both transformed reads and writes.
  412. int a_sh_wr_trans[a_sh_wr_iters];
  413. #pragma unroll
  414. for (int i = 0; i < a_sh_wr_iters; i++)
  415. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  416. int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  417. #pragma unroll
  418. for (int i = 0; i < b_sh_wr_iters; i++) {
  419. #pragma unroll
  420. for (int j = 0; j < thread_m_blocks; j++)
  421. a_sh_rd_trans[i][j] =
  422. transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  423. }
  424. // Since B-accesses have non-constant stride they have to be computed at
  425. // runtime; we break dependencies between subsequent accesses within a tile by
  426. // maintaining multiple pointers (we have enough registers), a tiny
  427. // optimization.
  428. const int4* B_ptr[b_sh_wr_iters];
  429. #pragma unroll
  430. for (int i = 0; i < b_sh_wr_iters; i++)
  431. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  432. extern __shared__ int4 sh[];
  433. // Shared memory storage for global fetch pipelines.
  434. int4* sh_a = sh;
  435. int4* sh_b = sh_a + (stages * a_sh_stage);
  436. int4* sh_g_idx = sh_b + (stages * b_sh_stage);
  437. int4* sh_s = sh_g_idx + (stages * g_idx_stage);
  438. // Register storage for double buffer of shared memory reads.
  439. FragA frag_a[2][thread_m_blocks];
  440. I4 frag_b_quant[2][b_thread_vecs];
  441. FragC frag_c[thread_m_blocks][4][2];
  442. FragS frag_s[2][4];
  443. // Zero accumulators.
  444. auto zero_accums = [&]() {
  445. #pragma unroll
  446. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  447. reinterpret_cast<float*>(frag_c)[i] = 0;
  448. };
  449. int sh_first_group_id = -1;
  450. int sh_num_groups = -1;
  451. constexpr int sh_max_num_groups = 32;
  452. auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
  453. int last_group_id) {
  454. sh_first_group_id = first_group_id;
  455. sh_num_groups = last_group_id - first_group_id + 1;
  456. if (sh_num_groups < sh_max_num_groups) {
  457. sh_num_groups = sh_max_num_groups;
  458. }
  459. if (sh_first_group_id + sh_num_groups > num_groups) {
  460. sh_num_groups = num_groups - sh_first_group_id;
  461. }
  462. int row_offset = first_group_id * s_gl_stride;
  463. if (is_async) {
  464. for (int i = 0; i < sh_num_groups; i++) {
  465. if (threadIdx.x < s_sh_stride) {
  466. cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
  467. &scales_ptr[row_offset + (i * s_gl_stride) +
  468. slice_n_offset + threadIdx.x]);
  469. }
  470. }
  471. } else {
  472. for (int i = 0; i < sh_num_groups; i++) {
  473. if (threadIdx.x < s_sh_stride) {
  474. sh_s[(i * s_sh_stride) + threadIdx.x] =
  475. scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
  476. threadIdx.x];
  477. }
  478. }
  479. }
  480. };
  481. // Asynchronously fetch the next A, B and s tile from global to the next
  482. // shared memory pipeline location.
  483. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
  484. if (pred) {
  485. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  486. #pragma unroll
  487. for (int i = 0; i < a_sh_wr_iters; i++) {
  488. cp_async4_pred(
  489. &sh_a_stage[a_sh_wr_trans[i]],
  490. &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
  491. a_sh_wr_pred[i]);
  492. }
  493. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  494. #pragma unroll
  495. for (int i = 0; i < b_sh_wr_iters; i++) {
  496. #pragma unroll
  497. for (int j = 0; j < b_thread_vecs; j++) {
  498. cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
  499. }
  500. B_ptr[i] += b_gl_rd_delta_o;
  501. }
  502. }
  503. // Insert a fence even when we are winding down the pipeline to ensure that
  504. // waiting is also correct at this point.
  505. cp_async_fence();
  506. };
  507. // Wait until the next thread tile has been loaded to shared memory.
  508. auto wait_for_stage = [&]() {
  509. // We only have `stages - 2` active fetches since we are double buffering
  510. // and can only issue the next fetch when it is guaranteed that the previous
  511. // shared memory load is fully complete (as it may otherwise be
  512. // overwritten).
  513. cp_async_wait<stages - 2>();
  514. __syncthreads();
  515. };
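// Assuming pipe_stages = 4 as defined in marlin.cuh, cp_async_wait<2> lets up
// to two fetch groups remain in flight, so the block can compute on the
// current stage while the following stages are still being filled.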
  516. // Load the next sub-tile from the current location in the shared memory pipe
  517. // into the current register buffer.
  518. auto fetch_to_registers = [&](int k, int pipe) {
  519. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  520. #pragma unroll
  521. for (int i = 0; i < thread_m_blocks; i++)
  522. ldsm4<scalar_t>(frag_a[k % 2][i],
  523. &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
  524. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  525. #pragma unroll
  526. for (int i = 0; i < b_thread_vecs; i++) {
  527. frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
  528. &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
  529. }
  530. };
  531. bool is_same_group[stages];
  532. int same_group_id[stages];
  533. auto init_same_group = [&](int pipe) {
  534. is_same_group[pipe] = false;
  535. same_group_id[pipe] = 0;
  536. return;
  537. };
  538. // Execute the actual tensor core matmul of a sub-tile.
  539. auto matmul = [&](int k) {
  540. // We have the m dimension as the inner loop in order to encourage overlapping
  541. // dequantization and matmul operations.
  542. #pragma unroll
  543. for (int j = 0; j < 4; j++) {
  544. FragB frag_b0;
  545. FragB frag_b1;
  546. int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
  547. int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
  548. int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
  549. frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
  550. frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
  551. #pragma unroll
  552. for (int i = 0; i < thread_m_blocks; i++) {
  553. mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
  554. mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
  555. }
  556. }
  557. };
  558. // Since we slice across the k dimension of a tile in order to increase the
  559. // number of warps while keeping the n dimension of a tile reasonable, we have
  560. // multiple warps that accumulate their partial sums of the same output
  561. // location, which we have to reduce over in the end; we do this in shared memory.
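// For example, in the (num_bits = 8, 8x8 blocks, 256 threads) configuration
// instantiated below, b_sh_stride_threads = 64, so four groups of 64 threads
// accumulate the same outputs and red_off = 2, i.e. the partial sums are
// combined in two pairwise rounds through shared memory.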
  562. auto thread_block_reduce = [&]() {
  563. constexpr int red_off = threads / b_sh_stride_threads / 2;
  564. if (red_off >= 1) {
  565. int red_idx = threadIdx.x / b_sh_stride_threads;
  566. constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
  567. constexpr int red_sh_delta = b_sh_stride_threads;
  568. int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
  569. (threadIdx.x % b_sh_stride_threads);
  570. // Parallel logarithmic shared memory reduction. We make sure to avoid any
  571. // unnecessary read or write iterations, e.g., for two warps we write only
  572. // once by warp 1 and read only once by warp 0.
  573. #pragma unroll
  574. for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  575. #pragma unroll
  576. for (int i = red_off; i > 0; i /= 2) {
  577. if (i <= red_idx && red_idx < 2 * i) {
  578. #pragma unroll
  579. for (int j = 0; j < 4 * 2; j++) {
  580. int red_sh_wr =
  581. red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
  582. if (i < red_off) {
  583. float* c_rd =
  584. reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
  585. float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  586. #pragma unroll
  587. for (int k = 0; k < 4; k++)
  588. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
  589. c_rd[k] + c_wr[k];
  590. }
  591. sh[red_sh_wr] =
  592. reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
  593. }
  594. }
  595. __syncthreads();
  596. }
  597. if (red_idx == 0) {
  598. #pragma unroll
  599. for (int i = 0; i < 4 * 2; i++) {
  600. float* c_rd =
  601. reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  602. #pragma unroll
  603. for (int j = 0; j < 4; j++)
  604. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
  605. c_rd[j];
  606. }
  607. }
  608. __syncthreads();
  609. }
  610. }
  611. };
  612. // Since multiple threadblocks may process parts of the same column slice, we
  613. // finally have to globally reduce over the results. As the striped
  614. // partitioning minimizes the number of such reductions and our outputs are
  615. // usually rather small, we perform this reduction serially in L2 cache.
  616. auto global_reduce = [&](bool first = false, bool last = false) {
  617. // We are very careful here to reduce directly in the output buffer to
  618. // maximize L2 cache utilization in this step. To do this, we write out
  619. // results in FP16 (but still reduce with FP32 compute).
  620. constexpr int active_threads = 32 * thread_n_blocks / 4;
  621. if (threadIdx.x < active_threads) {
  622. int c_gl_stride = prob_n / 8;
  623. int c_gl_wr_delta_o = 8 * c_gl_stride;
  624. int c_gl_wr_delta_i = 4 * (active_threads / 32);
  625. int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
  626. 4 * (threadIdx.x / 32) + threadIdx.x % 4;
  627. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  628. constexpr int c_sh_wr_delta = active_threads;
  629. int c_sh_wr = threadIdx.x;
  630. int row = (threadIdx.x % 32) / 4;
  631. if (!first) {
  632. // Interestingly, doing direct global accesses here really seems to mess up
  633. // the compiler and lead to slowdowns, hence we also use async-copies even
  634. // though these fetches are not actually asynchronous.
  635. #pragma unroll
  636. for (int i = 0; i < thread_m_blocks * 4; i++) {
  637. cp_async4_pred(
  638. &sh[c_sh_wr + c_sh_wr_delta * i],
  639. &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
  640. c_gl_wr_delta_i * (i % 2)],
  641. i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
  642. }
  643. cp_async_fence();
  644. cp_async_wait<0>();
  645. }
  646. #pragma unroll
  647. for (int i = 0; i < thread_m_blocks * 4; i++) {
  648. if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
  649. if (!first) {
  650. int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  651. #pragma unroll
  652. for (int j = 0; j < 2 * 4; j++) {
  653. reinterpret_cast<float*>(
  654. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
  655. Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
  656. }
  657. }
  658. if (!last) {
  659. int4 c;
  660. #pragma unroll
  661. for (int j = 0; j < 2 * 4; j++) {
  662. reinterpret_cast<scalar_t*>(&c)[j] =
  663. Dtype::float2num(reinterpret_cast<float*>(
  664. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
  665. }
  666. C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
  667. c;
  668. }
  669. }
  670. }
  671. }
  672. };
  673. // Write out the reduced final result in the correct layout. We only actually
  674. // reshuffle matrix fragments in this step; the reduction above is performed
  675. // in fragment layout.
  676. auto write_result = [&]() {
  677. int c_gl_stride = prob_n / 8;
  678. constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
  679. int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
  680. constexpr int c_sh_rd_delta =
  681. c_sh_stride * (threads / (2 * thread_n_blocks));
  682. int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  683. (threadIdx.x % (2 * thread_n_blocks));
  684. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  685. int c_sh_wr =
  686. (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
  687. c_sh_wr += 32 * (threadIdx.x / 32);
  688. int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  689. (threadIdx.x % (2 * thread_n_blocks));
  690. int c_gl_wr_end = c_gl_stride * prob_m;
  691. // We first reorder in shared memory to guarantee the most efficient final
  692. // global write patterns
  693. auto write = [&](int idx, float c0, float c1, FragS& s) {
  694. scalar_t2 res =
  695. Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
  696. ((scalar_t2*)sh)[idx] = res;
  697. };
  698. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  699. #pragma unroll
  700. for (int i = 0; i < thread_m_blocks; i++) {
  701. #pragma unroll
  702. for (int j = 0; j < 4; j++) {
  703. int wr = c_sh_wr + 8 * j;
  704. write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
  705. frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
  706. write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
  707. frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
  708. write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
  709. frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
  710. write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
  711. frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
  712. }
  713. c_sh_wr += 16 * (4 * c_sh_stride);
  714. }
  715. }
  716. __syncthreads();
  717. #pragma unroll
  718. for (int i = 0;
  719. i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
  720. i++) {
  721. if (c_gl_wr < c_gl_wr_end) {
  722. C[c_gl_wr] = sh[c_sh_rd];
  723. c_gl_wr += c_gl_wr_delta;
  724. c_sh_rd += c_sh_rd_delta;
  725. }
  726. }
  727. };
  728. // Start global fetch and register load pipelines.
  729. auto start_pipes = [&]() {
  730. #pragma unroll
  731. for (int i = 0; i < stages - 1; i++) {
  732. fetch_to_shared(i, i, i < slice_iters);
  733. }
  734. zero_accums();
  735. wait_for_stage();
  736. init_same_group(0);
  737. fetch_to_registers(0, 0);
  738. a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  739. slice_k_start_shared_fetch += tb_k * (stages - 1);
  740. };
  741. if (slice_iters) {
  742. start_pipes();
  743. }
  744. // Main loop.
  745. while (slice_iters) {
  746. // We unroll over both the global fetch and the register load pipeline to
  747. // ensure all shared memory accesses are static. Note that both pipelines
  748. // have even length meaning that the next iteration will always start at
  749. // index 0.
  750. #pragma unroll
  751. for (int pipe = 0; pipe < stages;) {
  752. #pragma unroll
  753. for (int k = 0; k < b_sh_wr_iters; k++) {
  754. fetch_to_registers(k + 1, pipe % stages);
  755. if (k == b_sh_wr_iters - 2) {
  756. fetch_to_shared((pipe + stages - 1) % stages, pipe,
  757. slice_iters >= stages);
  758. pipe++;
  759. wait_for_stage();
  760. init_same_group(pipe % stages);
  761. }
  762. matmul(k);
  763. }
  764. slice_iters--;
  765. if (slice_iters == 0) {
  766. break;
  767. }
  768. }
  769. a_gl_rd += a_gl_rd_delta_o * stages;
  770. slice_k_start += tb_k * stages;
  771. slice_k_start_shared_fetch += tb_k * stages;
  772. // Process results and, if necessary, proceed to the next column slice.
  773. // While this pattern may not be the most readable, other ways of writing
  774. // the loop seemed to result in noticeably worse performance after compilation.
  775. if (slice_iters == 0) {
  776. cp_async_wait<0>();
  777. bool last = slice_idx == slice_count - 1;
  778. // For per-column scales, we only fetch them here in the final step before
  779. // write-out
  780. if (s_sh_wr_pred) {
  781. cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
  782. }
  783. cp_async_fence();
  784. thread_block_reduce();
  785. cp_async_wait<0>();
  786. __syncthreads();
  787. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  788. reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
  789. reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
  790. }
  791. // For 8-bit channelwise, we apply the scale before the global reduction
  792. // that converts the fp32 results to fp16 (so that we avoid possible
  793. // overflow in fp16)
  794. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  795. #pragma unroll
  796. for (int i = 0; i < thread_m_blocks; i++) {
  797. #pragma unroll
  798. for (int j = 0; j < 4; j++) {
  799. scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][0]),
  800. frag_s[j / 2][2 * (j % 2) + 0]);
  801. scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][2]),
  802. frag_s[j / 2][2 * (j % 2) + 0]);
  803. scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][0]),
  804. frag_s[j / 2][2 * (j % 2) + 1]);
  805. scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][2]),
  806. frag_s[j / 2][2 * (j % 2) + 1]);
  807. }
  808. }
  809. }
  810. if (slice_count > 1) { // only globally reduce if there is more than one
  811. // block in a slice
  812. barrier_acquire(&locks[slice_col], slice_idx);
  813. global_reduce(slice_idx == 0, last);
  814. barrier_release(&locks[slice_col], last);
  815. }
  816. if (last) // only the last block in a slice actually writes the result
  817. write_result();
  818. slice_row = 0;
  819. slice_col_par++;
  820. slice_col++;
  821. init_slice();
  822. if (slice_iters) {
  823. a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  824. (threadIdx.x % a_gl_rd_delta_o);
  825. #pragma unroll
  826. for (int i = 0; i < b_sh_wr_iters; i++)
  827. B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
  828. if (slice_col == 0) {
  829. #pragma unroll
  830. for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
  831. }
  832. // Update slice k/n for scales loading
  833. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  834. start_pipes();
  835. }
  836. }
  837. }
  838. }
  839. #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
  840. THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \
  841. else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
  842. thread_n_blocks == THREAD_N_BLOCKS && \
  843. thread_k_blocks == THREAD_K_BLOCKS && \
  844. group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
  845. cudaFuncSetAttribute( \
  846. Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
  847. THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>, \
  848. cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
  849. Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
  850. THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS> \
  851. <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
  852. A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \
  853. locks); \
  854. }
  855. typedef struct {
  856. int thread_k;
  857. int thread_n;
  858. int num_threads;
  859. } thread_config_t;
  860. typedef struct {
  861. int max_m_blocks;
  862. thread_config_t tb_cfg;
  863. } exec_config_t;
  864. thread_config_t small_batch_thread_configs[] = {
  865. // Ordered by priority
  866. // thread_k, thread_n, num_threads
  867. {128, 128, 256},
  868. {64, 128, 128},
  869. {128, 64, 128},
  870. };
  871. thread_config_t large_batch_thread_configs[] = {
  872. // Ordered by priority
  873. // thread_k, thread_n, num_threads
  874. {64, 256, 256},
  875. {64, 128, 128},
  876. {128, 64, 128},
  877. };
  878. int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
  879. int prob_n, int prob_k, int num_bits,
  880. int group_size) {
  881. int tb_n = th_config.thread_n;
  882. // Get max scale groups per thread-block
  883. // Fixed for channelwise
  884. int tb_groups = 1;
  885. int tb_scales = tb_groups * tb_n * 2;
  886. return tb_scales * pipe_stages;
  887. }
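// Note on get_scales_cache_size above: for channelwise fp8 this is simply one
// fp16/bf16 scale row per pipeline stage, e.g. thread_n = 256 gives
// 256 * 2 bytes * pipe_stages (4 in marlin.cuh) = 2 KiB reserved for scales.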
  888. bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
  889. int prob_m, int prob_n, int prob_k, int num_bits,
  890. int scales_cache_size, int max_shared_mem) {
  891. int pack_factor = 32 / num_bits;
  892. // Get B size
  893. int tb_k = th_config.thread_k;
  894. int tb_n = th_config.thread_n;
  895. int b_size = (tb_k * tb_n / pack_factor) * 4;
  896. // Get A size
  897. int m_blocks = div_ceil(prob_m, 16);
  898. int tb_max_m = 16;
  899. while (true) {
  900. if (m_blocks >= max_m_blocks) {
  901. tb_max_m *= max_m_blocks;
  902. break;
  903. }
  904. max_m_blocks--;
  905. if (max_m_blocks == 0) {
  906. TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
  907. }
  908. }
  909. int a_size = (tb_max_m * tb_k) * 2;
  910. float pipe_size = (a_size + b_size) * pipe_stages;
  911. TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
  912. return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
  913. }
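// Illustrative sizing for is_valid_cache_size above (a sketch assuming
// pipe_stages = 4 and prob_m >= 64): for the large-batch config
// {thread_k = 64, thread_n = 256, 256 threads} with num_bits = 8 and
// max_m_blocks = 4, b_size = 16 KiB and a_size = 8 KiB per stage, so
// pipe_size = (16 + 8) KiB * 4 = 96 KiB, which must fit into 95% of the
// opt-in shared memory left after the scale cache.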
  914. bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
  915. int prob_m, int prob_n, int prob_k, int num_bits,
  916. int group_size, int max_shared_mem) {
  917. // Sanity
  918. if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
  919. th_config.num_threads == -1) {
  920. return false;
  921. }
  922. // Verify K/N are divisible by thread K/N
  923. if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
  924. return false;
  925. }
  926. // Verify min for thread K/N
  927. if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
  928. return false;
  929. }
  930. // num_threads must be at least 128 (= 4 warps)
  931. if (th_config.num_threads < 128) {
  932. return false;
  933. }
  934. // Determine cache for scales
  935. int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n,
  936. prob_k, num_bits, group_size);
  937. // Check that pipeline fits into cache
  938. if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
  939. num_bits, scales_cache_size, max_shared_mem)) {
  940. return false;
  941. }
  942. return true;
  943. }
  944. exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
  945. int num_bits, int group_size,
  946. int max_shared_mem) {
  947. int max_m_blocks = 4;
  948. while (max_m_blocks > 0) {
  949. if (prob_m <= 16) {
  950. for (auto th_config : small_batch_thread_configs) {
  951. if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
  952. num_bits, group_size, max_shared_mem)) {
  953. return exec_config_t{max_m_blocks, th_config};
  954. }
  955. }
  956. } else {
  957. for (auto th_config : large_batch_thread_configs) {
  958. if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
  959. num_bits, group_size, max_shared_mem)) {
  960. return exec_config_t{max_m_blocks, th_config};
  961. }
  962. }
  963. }
  964. max_m_blocks--; // Process fewer M blocks per invocation to reduce cache
  965. // usage
  966. }
  967. return exec_config_t{0, {-1, -1, -1}};
  968. }
  969. #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
  970. __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  971. __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  972. __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  973. __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)
  974. template <typename scalar_t>
  975. void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m,
  976. int prob_n, int prob_k, void* workspace, int num_bits,
  977. int num_groups, int group_size, int dev,
  978. cudaStream_t stream, int thread_k, int thread_n, int sms,
  979. int max_par) {
  980. TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
  981. TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
  982. ", ", prob_n, ", ", prob_k, "]");
  983. int tot_m = prob_m;
  984. int tot_m_blocks = div_ceil(tot_m, 16);
  985. int pad = 16 * tot_m_blocks - tot_m;
  986. if (sms == -1) {
  987. cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  988. }
  989. int max_shared_mem = 0;
  990. cudaDeviceGetAttribute(&max_shared_mem,
  991. cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  992. TORCH_CHECK(max_shared_mem > 0);
  993. // Set thread config
  994. exec_config_t exec_cfg;
  995. if (thread_k != -1 && thread_n != -1) {
  996. // User-defined config
  997. exec_cfg =
  998. exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
  999. } else {
  1000. // Auto config
  1001. exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits,
  1002. group_size, max_shared_mem);
  1003. }
  1004. TORCH_CHECK(
  1005. exec_cfg.max_m_blocks > 0 &&
  1006. is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,
  1007. prob_n, prob_k, num_bits, group_size, max_shared_mem),
  1008. "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
  1009. ", thread_k = ", exec_cfg.tb_cfg.thread_k,
  1010. ", thread_n = ", exec_cfg.tb_cfg.thread_n,
  1011. ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m,
  1012. ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
  1013. ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem);
  1014. int num_threads = exec_cfg.tb_cfg.num_threads;
  1015. thread_k = exec_cfg.tb_cfg.thread_k;
  1016. thread_n = exec_cfg.tb_cfg.thread_n;
  1017. int thread_k_blocks = thread_k / 16;
  1018. int thread_n_blocks = thread_n / 16;
  1019. int blocks = sms;
  1020. TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
  1021. " is not divisible by thread_n = ", thread_n);
  1022. TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
  1023. " is not divisible by thread_k = ", thread_k);
  1024. int group_blocks = -1;
  1025. const int4* A_ptr = (const int4*)A;
  1026. const int4* B_ptr = (const int4*)B;
  1027. int4* C_ptr = (int4*)C;
  1028. const int4* s_ptr = (const int4*)s;
  1029. int* locks = (int*)workspace;
  1030. // Main loop
  1031. for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
  1032. int thread_m_blocks = tot_m_blocks - i;
  1033. prob_m = tot_m - 16 * i;
  1034. int par = 1;
  1035. if (thread_m_blocks > exec_cfg.max_m_blocks) {
  1036. // Note that parallel > 1 currently only works for inputs without any
  1037. // padding
  1038. par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
  1039. if (par > max_par) par = max_par;
  1040. prob_m = (16 * exec_cfg.max_m_blocks) * par;
  1041. i += exec_cfg.max_m_blocks * (par - 1);
  1042. thread_m_blocks = exec_cfg.max_m_blocks;
  1043. }
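// Illustrative walk-through: for tot_m = 320 with exec_cfg.max_m_blocks = 4
// (and max_par >= 5), the first iteration sees thread_m_blocks = 20, sets
// par = 320 / 64 = 5 and prob_m = 320, and advances i past all 20 m-blocks,
// so a single kernel launch covers the whole batch as five 64-row
// sub-problems.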
  1044. // Define kernel configurations
  1045. if (false) {
  1046. }
  1047. CALL_IF(8, 32, 2, 256)
  1048. CALL_IF(8, 16, 4, 256)
  1049. CALL_IF(8, 8, 8, 256)
  1050. CALL_IF(8, 8, 4, 128)
  1051. CALL_IF(8, 4, 8, 128)
  1052. else {
  1053. TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
  1054. str(prob_n) + ", " + str(prob_k) + "]" +
  1055. ", num_groups = " + str(num_groups) +
  1056. ", group_size = " + str(group_size) +
  1057. ", thread_m_blocks = " + str(thread_m_blocks) +
  1058. ", thread_n_blocks = " + str(thread_n_blocks) +
  1059. ", thread_k_blocks = " + str(thread_k_blocks));
  1060. }
  1061. A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
  1062. C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  1063. }
  1064. }
  1065. } // namespace fp8_marlin
  1066. torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  1067. torch::Tensor& b_scales, torch::Tensor& workspace,
  1068. int64_t num_bits, int64_t size_m, int64_t size_n,
  1069. int64_t size_k) {
  1070. // Verify num_bits
  1071. TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
  1072. int pack_factor = 32 / num_bits;
  1073. // Verify A
  1074. TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
  1075. ", size_m = ", size_m);
  1076. TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
  1077. ", size_k = ", size_k);
  1078. // Verify B
  1079. TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
  1080. " is not divisible by tile_size = ", marlin::tile_size);
  1081. TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
  1082. "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
  1083. ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
  1084. TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
  1085. "b_q_weight.size(1) = ", b_q_weight.size(1),
  1086. " is not divisible by tile_size = ", marlin::tile_size);
  1087. int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
  1088. TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
  1089. ", actual_size_n = ", actual_size_n);
  1090. // Verify device and strides
  1091. TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  1092. TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
  1093. TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  1094. TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  1095. TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  1096. TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
  1097. // Alloc buffers
  1098. const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  1099. auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  1100. torch::Tensor c = torch::empty({size_m, size_n}, options);
  1101. // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  1102. // auto -1)
  1103. int thread_k = -1;
  1104. // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  1105. // auto -1)
  1106. int thread_n = -1;
  1107. // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  1108. int sms = -1;
  1109. // Detect groupsize and act_order
  1110. int num_groups = -1;
  1111. int group_size = -1;
  1112. int b_rank = b_scales.sizes().size();
  1113. TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
  1114. TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
  1115. " is not size_n = ", size_n);
  1116. // Channelwise only for FP8
  1117. TORCH_CHECK(b_scales.size(0) == 1,
  1117.             "b_scales must have a single (channelwise) scale group");
  1118. num_groups = b_scales.size(0);
  1119. // Verify workspace size
  1120. TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
  1121. ", is not divisible by min_thread_n = ", marlin::min_thread_n);
  1122. int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
  1123. TORCH_CHECK(workspace.numel() >= min_workspace_size,
  1124. "workspace.numel = ", workspace.numel(),
  1125. " is below min_workspace_size = ", min_workspace_size);
  1126. int dev = a.get_device();
  1127. if (a.scalar_type() == at::ScalarType::Half) {
  1128. fp8_marlin::marlin_mm_f16i4<half>(
  1129. a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
  1130. b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
  1131. workspace.data_ptr(), num_bits, num_groups, group_size, dev,
  1132. at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
  1133. marlin::max_par);
  1134. } else if (a.scalar_type() == at::ScalarType::BFloat16) {
  1135. fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
  1136. a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
  1137. c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
  1138. size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
  1139. dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
  1140. marlin::max_par);
  1141. } else {
  1142. TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
  1143. }
  1144. return c;
  1145. }
  1146. #endif
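// Expected shapes, as enforced by the checks in fp8_marlin_gemm above
// (summary note): `a` is [size_m, size_k] in fp16 or bf16, `b_q_weight` is
// [size_k / tile_size, size_n * tile_size / pack_factor] packed int32
// (Marlin-repacked fp8 weights), `b_scales` is [1, size_n] (channelwise), and
// `workspace` must hold at least (size_n / min_thread_n) * max_par ints used
// as lock counters.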