gptq_marlin.cu 87 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511
661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542
0552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299
  1. /*
  2. * Modified by Neural Magic
  3. * Copyright (C) Marlin.2024 Elias Frantar
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. /*
  18. * Adapted from https://github.com/IST-DASLab/marlin
  19. */
  20. #include "marlin.cuh"
  21. #include "marlin_dtypes.cuh"
  22. #include "core/scalar_type.hpp"
  23. #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
  24. static_assert(std::is_same<scalar_t, half>::value || \
  25. std::is_same<scalar_t, nv_bfloat16>::value, \
  26. "only float16 and bfloat16 is supported");
  27. template <typename T>
  28. inline std::string str(T x) {
  29. return std::to_string(x);
  30. }
  31. namespace marlin {
  32. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  33. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
  34. int const* __restrict__ perm_int_ptr,
  35. int4* __restrict__ out_int4_ptr, int size_m,
  36. int size_k, int block_rows) {}
  37. template <typename scalar_t, // compute dtype, half or nv_float16
  38. const aphrodite::ScalarTypeId w_type_id, // weight ScalarType id
  39. const int threads, // number of threads in a threadblock
  40. const int thread_m_blocks, // number of 16x16 blocks in the m
  41. // dimension (batchsize) of the
  42. // threadblock
  43. const int thread_n_blocks, // same for n dimension (output)
  44. const int thread_k_blocks, // same for k dimension (reduction)
  45. const int stages, // number of stages for the async global->shared
  46. // fetch pipeline
  47. const bool has_act_order, // whether act_order is enabled
  48. const int group_blocks = -1 // number of consecutive 16x16 blocks
  49. // with a separate quantization scale
  50. >
  51. __global__ void Marlin(
  52. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  53. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  54. int4* __restrict__ C, // fp16 output buffer of shape mxn
  55. int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
  56. const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
  57. // (k/groupsize)xn
  58. const int* __restrict__ g_idx, // int32 group indices of shape k
  59. int num_groups, // number of scale groups per output channel
  60. int prob_m, // batch dimension m
  61. int prob_n, // output dimension n
  62. int prob_k, // reduction dimension k
  63. int* locks, // extra global storage for barrier synchronization
  64. bool use_fp32_reduce // whether to use fp32 global reduce
  65. ) {}
  66. } // namespace marlin
  67. torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  68. torch::Tensor& b_scales, torch::Tensor& b_zeros,
  69. torch::Tensor& g_idx, torch::Tensor& perm,
  70. torch::Tensor& workspace,
  71. aphrodite::ScalarTypeTorchPtr const& b_q_type,
  72. int64_t size_m, int64_t size_n, int64_t size_k,
  73. bool is_k_full, bool has_zp) {
  74. TORCH_CHECK_NOT_IMPLEMENTED(false,
  75. "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
  76. return torch::empty({1, 1});
  77. }
  78. #else
  79. // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
  80. // output/accumulation.
  81. template <typename scalar_t>
  82. __device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
  83. const typename ScalarType<scalar_t>::FragB& frag_b,
  84. typename ScalarType<scalar_t>::FragC& frag_c) {
  85. const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  86. const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  87. float* c = reinterpret_cast<float*>(&frag_c);
  88. if constexpr (std::is_same<scalar_t, half>::value) {
  89. asm volatile(
  90. "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
  91. "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
  92. : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
  93. : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
  94. "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  95. } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
  96. asm volatile(
  97. "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
  98. "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
  99. : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
  100. : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
  101. "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  102. } else {
  103. STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
  104. }
  105. }
  106. // Instruction for loading a full 16x16 matrix fragment of operand A from shared
  107. // memory, directly in tensor core layout.
  108. template <typename scalar_t>
  109. __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
  110. const void* smem_ptr) {
  111. uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  112. uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  113. asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
  114. : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
  115. : "r"(smem));
  116. }
  117. // Lookup-table based 3-input logical operation; explicitly used for
  118. // dequantization as the compiler does not seem to automatically recognize it in
  119. // all cases.
  120. template <int lut>
  121. __device__ inline int lop3(int a, int b, int c) {
  122. int res;
  123. asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
  124. : "=r"(res)
  125. : "r"(a), "r"(b), "r"(c), "n"(lut));
  126. return res;
  127. }
  128. // Constructs destination register by taking bytes from 2 sources (based on
  129. // mask)
  130. template <int start_byte, int mask>
  131. __device__ inline uint32_t prmt(uint32_t a) {
  132. uint32_t res;
  133. asm volatile("prmt.b32 %0, %1, %2, %3;\n"
  134. : "=r"(res)
  135. : "r"(a), "n"(start_byte), "n"(mask));
  136. return res;
  137. }
  138. template <typename scalar_t, aphrodite::ScalarTypeId w_type_id>
  139. __device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
  140. //
  141. // Efficiently dequantize 4bit values packed in an int32 value into a full
  142. // B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
  143. // with some small changes:
  144. // - FP16:
  145. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
  146. // - BF16:
  147. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
  148. //
  149. template <>
  150. __device__ inline typename ScalarType<half>::FragB
  151. dequant<half, aphrodite::kU4B8.id()>(int q) {
  152. const int LO = 0x000f000f;
  153. const int HI = 0x00f000f0;
  154. const int EX = 0x64006400;
  155. // Guarantee that the `(a & b) | c` operations are LOP3s.
  156. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  157. int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  158. // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  159. // directly into `SUB` and `ADD`.
  160. const int SUB = 0x64086408;
  161. const int MUL = 0x2c002c00;
  162. const int ADD = 0xd480d480;
  163. typename ScalarType<half>::FragB frag_b;
  164. frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
  165. *reinterpret_cast<const half2*>(&SUB));
  166. frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
  167. *reinterpret_cast<const half2*>(&MUL),
  168. *reinterpret_cast<const half2*>(&ADD));
  169. return frag_b;
  170. }
  171. template <>
  172. __device__ inline typename ScalarType<nv_bfloat16>::FragB
  173. dequant<nv_bfloat16, aphrodite::kU4B8.id()>(int q) {
  174. static constexpr uint32_t MASK = 0x000f000f;
  175. static constexpr uint32_t EX = 0x43004300;
  176. // Guarantee that the `(a & b) | c` operations are LOP3s.
  177. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  178. q >>= 4;
  179. int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  180. typename ScalarType<nv_bfloat16>::FragB frag_b;
  181. static constexpr uint32_t MUL = 0x3F803F80;
  182. static constexpr uint32_t ADD = 0xC308C308;
  183. frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
  184. *reinterpret_cast<const nv_bfloat162*>(&MUL),
  185. *reinterpret_cast<const nv_bfloat162*>(&ADD));
  186. frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
  187. *reinterpret_cast<const nv_bfloat162*>(&MUL),
  188. *reinterpret_cast<const nv_bfloat162*>(&ADD));
  189. return frag_b;
  190. }
  191. template <>
  192. __device__ inline typename ScalarType<half>::FragB
  193. dequant<half, aphrodite::kU4.id()>(int q) {
  194. const int LO = 0x000f000f;
  195. const int HI = 0x00f000f0;
  196. const int EX = 0x64006400;
  197. // Guarantee that the `(a & b) | c` operations are LOP3s.
  198. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  199. int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  200. const int SUB = 0x64006400;
  201. const int MUL = 0x2c002c00;
  202. const int ADD = 0xd400d400;
  203. typename ScalarType<half>::FragB frag_b;
  204. frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
  205. *reinterpret_cast<const half2*>(&SUB));
  206. frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
  207. *reinterpret_cast<const half2*>(&MUL),
  208. *reinterpret_cast<const half2*>(&ADD));
  209. return frag_b;
  210. }
  211. template <>
  212. __device__ inline typename ScalarType<nv_bfloat16>::FragB
  213. dequant<nv_bfloat16, aphrodite::kU4.id()>(int q) {
  214. static constexpr uint32_t MASK = 0x000f000f;
  215. static constexpr uint32_t EX = 0x43004300;
  216. // Guarantee that the `(a & b) | c` operations are LOP3s.
  217. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  218. q >>= 4;
  219. int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  220. typename ScalarType<nv_bfloat16>::FragB frag_b;
  221. static constexpr uint32_t MUL = 0x3F803F80;
  222. static constexpr uint32_t ADD = 0xC300C300;
  223. frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
  224. *reinterpret_cast<const nv_bfloat162*>(&MUL),
  225. *reinterpret_cast<const nv_bfloat162*>(&ADD));
  226. frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
  227. *reinterpret_cast<const nv_bfloat162*>(&MUL),
  228. *reinterpret_cast<const nv_bfloat162*>(&ADD));
  229. return frag_b;
  230. }
  231. //
  232. // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
  233. // bf16 Reference:
  234. // - FP16:
  235. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
  236. // - BF16:
  237. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
  238. //
  239. template <>
  240. __device__ inline typename ScalarType<half>::FragB
  241. dequant<half, aphrodite::kU8B128.id()>(int q) {
  242. static constexpr uint32_t mask_for_elt_01 = 0x5250;
  243. static constexpr uint32_t mask_for_elt_23 = 0x5351;
  244. static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
  245. uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  246. uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
  247. static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
  248. typename ScalarType<half>::FragB frag_b;
  249. frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
  250. *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  251. frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
  252. *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  253. return frag_b;
  254. }
  255. template <>
  256. __device__ inline typename ScalarType<nv_bfloat16>::FragB
  257. dequant<nv_bfloat16, aphrodite::kU8B128.id()>(int q) {
  258. typename ScalarType<nv_bfloat16>::FragB frag_b;
  259. float fp32_intermediates[4];
  260. uint32_t* fp32_intermediates_casted =
  261. reinterpret_cast<uint32_t*>(fp32_intermediates);
  262. static constexpr uint32_t fp32_base = 0x4B000000;
  263. fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  264. fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  265. fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  266. fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
  267. fp32_intermediates[0] -= 8388736.f;
  268. fp32_intermediates[1] -= 8388736.f;
  269. fp32_intermediates[2] -= 8388736.f;
  270. fp32_intermediates[3] -= 8388736.f;
  271. uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
  272. bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
  273. fp32_intermediates_casted[1], 0x7632);
  274. bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
  275. fp32_intermediates_casted[3], 0x7632);
  276. return frag_b;
  277. }
  278. template <>
  279. __device__ inline typename ScalarType<half>::FragB
  280. dequant<half, aphrodite::kU8.id()>(int q) {
  281. static constexpr uint32_t mask_for_elt_01 = 0x5250;
  282. static constexpr uint32_t mask_for_elt_23 = 0x5351;
  283. static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
  284. uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  285. uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
  286. static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
  287. typename ScalarType<half>::FragB frag_b;
  288. frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
  289. *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  290. frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
  291. *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  292. return frag_b;
  293. }
  294. template <>
  295. __device__ inline typename ScalarType<nv_bfloat16>::FragB
  296. dequant<nv_bfloat16, aphrodite::kU8.id()>(int q) {
  297. typename ScalarType<nv_bfloat16>::FragB frag_b;
  298. float fp32_intermediates[4];
  299. uint32_t* fp32_intermediates_casted =
  300. reinterpret_cast<uint32_t*>(fp32_intermediates);
  301. static constexpr uint32_t fp32_base = 0x4B000000;
  302. fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  303. fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  304. fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  305. fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
  306. fp32_intermediates[0] -= 8388608.f;
  307. fp32_intermediates[1] -= 8388608.f;
  308. fp32_intermediates[2] -= 8388608.f;
  309. fp32_intermediates[3] -= 8388608.f;
  310. uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
  311. bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
  312. fp32_intermediates_casted[1], 0x7632);
  313. bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
  314. fp32_intermediates_casted[3], 0x7632);
  315. return frag_b;
  316. }
  317. // Multiply dequantized values by the corresponding quantization scale; used
  318. // only for grouped quantization.
  319. template <typename scalar_t>
  320. __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
  321. typename ScalarType<scalar_t>::FragS& frag_s,
  322. int i) {
  323. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  324. scalar_t2 s =
  325. ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
  326. frag_b[0] = __hmul2(frag_b[0], s);
  327. frag_b[1] = __hmul2(frag_b[1], s);
  328. }
  329. template <typename scalar_t>
  330. __device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
  331. typename ScalarType<scalar_t>::scalar_t2& frag_zp,
  332. int i) {
  333. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  334. scalar_t2 zp =
  335. ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
  336. frag_b[0] = __hsub2(frag_b[0], zp);
  337. frag_b[1] = __hsub2(frag_b[1], zp);
  338. }
  339. // Same as above, but for act_order (each K is multiplied individually)
  340. template <typename scalar_t>
  341. __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
  342. typename ScalarType<scalar_t>::FragS& frag_s_1,
  343. typename ScalarType<scalar_t>::FragS& frag_s_2,
  344. typename ScalarType<scalar_t>::FragS& frag_s_3,
  345. typename ScalarType<scalar_t>::FragS& frag_s_4,
  346. int i) {
  347. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  348. scalar_t2 s_val_1_2;
  349. s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
  350. s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];
  351. scalar_t2 s_val_3_4;
  352. s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];
  353. s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];
  354. frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
  355. frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
  356. }
  357. // Given 2 floats multiply by 2 scales (halves)
  358. template <typename scalar_t>
  359. __device__ inline void scale_float(float* c,
  360. typename ScalarType<scalar_t>::FragS& s) {
  361. scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
  362. c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
  363. c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
  364. }
  365. // Wait until barrier reaches `count`, then lock for current threadblock.
  366. __device__ inline void barrier_acquire(int* lock, int count) {
  367. if (threadIdx.x == 0) {
  368. int state = -1;
  369. do
  370. // Guarantee that subsequent writes by this threadblock will be visible
  371. // globally.
  372. asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
  373. : "=r"(state)
  374. : "l"(lock));
  375. while (state != count);
  376. }
  377. __syncthreads();
  378. }
  379. // Release barrier and increment visitation count.
  380. __device__ inline void barrier_release(int* lock, bool reset = false) {
  381. __syncthreads();
  382. if (threadIdx.x == 0) {
  383. if (reset) {
  384. lock[0] = 0;
  385. return;
  386. }
  387. int val = 1;
  388. // Make sure that all writes since acquiring this barrier are visible
  389. // globally, while releasing the barrier.
  390. asm volatile("fence.acq_rel.gpu;\n");
  391. asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
  392. :
  393. : "l"(lock), "r"(val));
  394. }
  395. }
  396. // For a given "a" of size [M,K] performs a permutation of the K columns based
  397. // on the given "perm" indices.
  398. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
  399. int const* __restrict__ perm_int_ptr,
  400. int4* __restrict__ out_int4_ptr, int size_m,
  401. int size_k, int block_rows) {
  402. int start_row = block_rows * blockIdx.x;
  403. int finish_row = start_row + block_rows;
  404. if (finish_row > size_m) {
  405. finish_row = size_m;
  406. }
  407. int cur_block_rows = finish_row - start_row;
  408. int row_stride = size_k * sizeof(half) / 16;
  409. auto permute_row = [&](int row) {
  410. int iters = size_k / default_threads;
  411. int rest = size_k % default_threads;
  412. int offset = row * row_stride;
  413. half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
  414. half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
  415. int base_k = 0;
  416. for (int i = 0; i < iters; i++) {
  417. int cur_k = base_k + threadIdx.x;
  418. int src_pos = perm_int_ptr[cur_k];
  419. out_half[cur_k] = a_row_half[src_pos];
  420. base_k += default_threads;
  421. }
  422. if (rest) {
  423. if (threadIdx.x < rest) {
  424. int cur_k = base_k + threadIdx.x;
  425. int src_pos = perm_int_ptr[cur_k];
  426. out_half[cur_k] = a_row_half[src_pos];
  427. }
  428. }
  429. };
  430. for (int i = 0; i < cur_block_rows; i++) {
  431. int cur_row = start_row + i;
  432. if (cur_row < size_m) {
  433. permute_row(cur_row);
  434. }
  435. }
  436. }
  437. template <typename scalar_t, // compute dtype, half or nv_float16
  438. const aphrodite::ScalarTypeId w_type_id, // weight ScalarType id
  439. const int threads, // number of threads in a threadblock
  440. const int thread_m_blocks, // number of 16x16 blocks in the m
  441. // dimension (batchsize) of the
  442. // threadblock
  443. const int thread_n_blocks, // same for n dimension (output)
  444. const int thread_k_blocks, // same for k dimension (reduction)
  445. const int stages, // number of stages for the async global->shared
  446. // fetch pipeline
  447. const bool has_act_order, // whether act_order is enabled
  448. const bool has_zp, // whether zero-points are enabled
  449. const int group_blocks = -1 // number of consecutive 16x16 blocks
  450. // with a separate quantization scale
  451. >
  452. __global__ void Marlin(
  453. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  454. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  455. int4* __restrict__ C, // fp16 output buffer of shape mxn
  456. int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
  457. const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
  458. // (k/groupsize)xn
  459. const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
  460. // (k/groupsize)x(n/pack_factor)
  461. const int* __restrict__ g_idx, // int32 group indices of shape k
  462. int num_groups, // number of scale groups per output channel
  463. int prob_m, // batch dimension m
  464. int prob_n, // output dimension n
  465. int prob_k, // reduction dimension k
  466. int* locks, // extra global storage for barrier synchronization
  467. bool use_fp32_reduce // whether to use fp32 global reduce
  468. ) {
  469. // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  470. // same size, which might involve multiple column "slices" (of width 16 *
  471. // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
  472. // example:
  473. // 0 1 3
  474. // 0 2 3
  475. // 1 2 4
  476. // While this kind of partitioning makes things somewhat more complicated, it
  477. // ensures good utilization of all SMs for many kinds of shape and GPU
  478. // configurations, while requiring as few slow global cross-threadblock
  479. // reductions as possible.
  480. using Dtype = ScalarType<scalar_t>;
  481. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  482. using FragA = typename ScalarType<scalar_t>::FragA;
  483. using FragB = typename ScalarType<scalar_t>::FragB;
  484. using FragC = typename ScalarType<scalar_t>::FragC;
  485. using FragS = typename ScalarType<scalar_t>::FragS;
  486. using FragZP = typename ScalarType<scalar_t>::FragZP;
  487. static constexpr auto w_type = aphrodite::ScalarType::from_id(w_type_id);
  488. constexpr int pack_factor = 32 / w_type.size_bits();
  489. // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  490. // better partitioning with less reductions
  491. int parallel = 1;
  492. if (prob_m > 16 * thread_m_blocks) {
  493. parallel = prob_m / (16 * thread_m_blocks);
  494. prob_m = 16 * thread_m_blocks;
  495. }
  496. int k_tiles = prob_k / 16 / thread_k_blocks;
  497. int n_tiles = prob_n / 16 / thread_n_blocks;
  498. int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
  499. if constexpr (!has_act_order && group_blocks != -1) {
  500. if (group_blocks >= thread_k_blocks) {
  501. // Ensure that the number of tiles in each stripe is a multiple of the
  502. // groupsize; this avoids an annoying special case where a stripe starts
  503. // in the middle of group.
  504. iters = (group_blocks / thread_k_blocks) *
  505. div_ceil(iters, (group_blocks / thread_k_blocks));
  506. }
  507. }
  508. int slice_row = (iters * blockIdx.x) % k_tiles;
  509. int slice_col_par = (iters * blockIdx.x) / k_tiles;
  510. int slice_col = slice_col_par;
  511. int slice_iters; // number of threadblock tiles in the current slice
  512. int slice_count =
  513. 0; // total number of active threadblocks in the current slice
  514. int slice_idx; // index of threadblock in current slice; numbered bottom to
  515. // top
  516. int par_id = 0;
  517. // We can easily implement parallel problem execution by just remapping
  518. // indices and advancing global pointers
  519. if (slice_col_par >= n_tiles) {
  520. A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
  521. C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
  522. locks += (slice_col_par / n_tiles) * n_tiles;
  523. slice_col = slice_col_par % n_tiles;
  524. par_id = slice_col_par / n_tiles;
  525. }
  526. // Compute all information about the current slice which is required for
  527. // synchronization.
  528. auto init_slice = [&]() {
  529. slice_iters =
  530. iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
  531. if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
  532. if (slice_iters == 0) return;
  533. if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
  534. slice_count = 1;
  535. slice_idx = 0;
  536. int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
  537. if (col_first <= k_tiles * (slice_col_par + 1)) {
  538. int col_off = col_first - k_tiles * slice_col_par;
  539. slice_count = div_ceil(k_tiles - col_off, iters);
  540. if (col_off > 0) slice_count++;
  541. int delta_first = iters * blockIdx.x - col_first;
  542. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  543. slice_idx = slice_count - 1;
  544. else {
  545. slice_idx = slice_count - 1 - delta_first / iters;
  546. if (col_off > 0) slice_idx--;
  547. }
  548. }
  549. if (slice_col == n_tiles) {
  550. A += 16 * thread_m_blocks * prob_k / 8;
  551. C += 16 * thread_m_blocks * prob_n / 8;
  552. locks += n_tiles;
  553. slice_col = 0;
  554. par_id++;
  555. }
  556. };
  557. init_slice();
  558. // A sizes/strides
  559. // stride of the A matrix in global memory
  560. int a_gl_stride = prob_k / 8;
  561. // stride of an A matrix tile in shared memory
  562. constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
  563. // delta between subsequent A tiles in global memory
  564. constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
  565. // between subsequent accesses within a tile
  566. int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
  567. // between shared memory writes
  568. constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
  569. // between shared memory tile reads
  570. constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
  571. // within a shared memory tile
  572. constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
  573. // overall size of a tile
  574. constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
  575. // number of shared write iterations for a tile
  576. constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
  577. // B sizes/strides
  578. int b_gl_stride = 16 * prob_n / (pack_factor * 4);
  579. constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
  580. constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
  581. constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
  582. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  583. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
  584. constexpr int b_sh_wr_delta = threads * b_thread_vecs;
  585. constexpr int b_sh_rd_delta = threads * b_thread_vecs;
  586. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  587. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
  588. // Scale sizes/strides without act_order
  589. int s_gl_stride = prob_n / 8;
  590. constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  591. constexpr int s_tb_groups =
  592. !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
  593. ? thread_k_blocks / group_blocks
  594. : 1;
  595. constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
  596. int s_gl_rd_delta = s_gl_stride;
  597. // Scale size/strides with act_order
  598. constexpr int tb_k = 16 * thread_k_blocks;
  599. constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
  600. // constexpr int act_s_row_stride = 1;
  601. // int act_s_col_stride = act_s_row_stride * num_groups;
  602. int act_s_col_stride = 1;
  603. int act_s_col_warp_stride = act_s_col_stride * 8;
  604. int tb_n_warps = thread_n_blocks / 4;
  605. int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
  606. // Zero-points sizes/strides
  607. int zp_gl_stride = (prob_n / pack_factor) / 4;
  608. constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
  609. constexpr int zp_tb_groups = s_tb_groups;
  610. constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
  611. int zp_gl_rd_delta = zp_gl_stride;
  612. // Global A read index of current thread.
  613. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  614. (threadIdx.x % a_gl_rd_delta_o);
  615. a_gl_rd += a_gl_rd_delta_o * slice_row;
  616. // Shared write index of current thread.
  617. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
  618. (threadIdx.x % a_gl_rd_delta_o);
  619. // Shared read index.
  620. int a_sh_rd =
  621. a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  622. a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  623. int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
  624. (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
  625. b_gl_rd += b_sh_stride * slice_col;
  626. b_gl_rd += b_gl_rd_delta_o * slice_row;
  627. int b_sh_wr = threadIdx.x * b_thread_vecs;
  628. int b_sh_rd = threadIdx.x * b_thread_vecs;
  629. // For act_order
  630. constexpr int k_iter_size = tb_k / b_sh_wr_iters;
  631. int slice_k_start = tb_k * slice_row;
  632. int slice_k_finish = slice_k_start + tb_k * slice_iters;
  633. int slice_k_start_shared_fetch = slice_k_start;
  634. int slice_n_offset = act_s_col_tb_stride * slice_col;
  635. // No act_order
  636. int s_gl_rd;
  637. if constexpr (!has_act_order) {
  638. if constexpr (group_blocks == -1) {
  639. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  640. } else {
  641. s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
  642. s_sh_stride * slice_col + threadIdx.x;
  643. }
  644. }
  645. int s_sh_wr = threadIdx.x;
  646. bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
  647. // Zero-points
  648. int zp_gl_rd;
  649. if constexpr (has_zp) {
  650. if constexpr (group_blocks == -1) {
  651. zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
  652. } else {
  653. zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
  654. zp_sh_stride * slice_col + threadIdx.x;
  655. }
  656. }
  657. int zp_sh_wr = threadIdx.x;
  658. bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
  659. // We use a different scale layout for grouped and column-wise quantization as
  660. // we scale a `half2` tile in column-major layout in the former and in
  661. // row-major in the latter case.
  662. int s_sh_rd;
  663. if constexpr (group_blocks != -1)
  664. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  665. (threadIdx.x % 32) / 4;
  666. else
  667. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  668. (threadIdx.x % 32) % 4;
  669. // Zero-points have the same read layout as the scales
  670. // (without column-wise case)
  671. constexpr int num_col_threads = 8;
  672. constexpr int num_row_threads = 4;
  673. constexpr int num_ints_per_thread = 8 / pack_factor;
  674. int zp_sh_rd;
  675. if constexpr (has_zp) {
  676. zp_sh_rd = num_ints_per_thread * num_col_threads *
  677. ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  678. num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
  679. }
  680. // Precompute which thread should not read memory in which iterations; this is
  681. // needed if there are more threads than required for a certain tilesize or
  682. // when the batchsize is not a multiple of 16.
  683. bool a_sh_wr_pred[a_sh_wr_iters];
  684. #pragma unroll
  685. for (int i = 0; i < a_sh_wr_iters; i++)
  686. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  687. // To ensure that writing and reading A tiles to/from shared memory, the
  688. // latter in fragment format, is fully bank conflict free, we need to use a
  689. // rather fancy XOR-based layout. The key here is that neither reads nor
  690. // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  691. // same shared memory banks. Further, it seems (based on NSight-Compute) that
  692. // each warp must also write a consecutive memory segment?
  693. auto transform_a = [&](int i) {
  694. int row = i / a_gl_rd_delta_o;
  695. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  696. };
  697. // Since the computation of this remapping is non-trivial and, due to our main
  698. // loop unrolls, all shared memory accesses are static, we simply precompute
  699. // both transformed reads and writes.
  700. int a_sh_wr_trans[a_sh_wr_iters];
  701. #pragma unroll
  702. for (int i = 0; i < a_sh_wr_iters; i++)
  703. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  704. int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  705. #pragma unroll
  706. for (int i = 0; i < b_sh_wr_iters; i++) {
  707. #pragma unroll
  708. for (int j = 0; j < thread_m_blocks; j++)
  709. a_sh_rd_trans[i][j] =
  710. transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  711. }
  712. // Since B-accesses have non-constant stride they have to be computed at
  713. // runtime; we break dependencies between subsequent accesses with a tile by
  714. // maintining multiple pointers (we have enough registers), a tiny
  715. // optimization.
  716. const int4* B_ptr[b_sh_wr_iters];
  717. #pragma unroll
  718. for (int i = 0; i < b_sh_wr_iters; i++)
  719. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  720. extern __shared__ int4 sh[];
  721. // Shared memory storage for global fetch pipelines.
  722. int4* sh_a = sh;
  723. int4* sh_b = sh_a + (stages * a_sh_stage);
  724. int4* sh_g_idx = sh_b + (stages * b_sh_stage);
  725. int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
  726. int4* sh_s = sh_zp + (stages * zp_sh_stage);
  727. // Register storage for double buffer of shared memory reads.
  728. FragA frag_a[2][thread_m_blocks];
  729. I4 frag_b_quant[2][b_thread_vecs];
  730. FragC frag_c[thread_m_blocks][4][2];
  731. FragS frag_s[2][4]; // No act-order
  732. FragS act_frag_s[2][4][4]; // For act-order
  733. int frag_qzp[2][num_ints_per_thread]; // Zero-points
  734. FragZP frag_zp; // Zero-points in fp16
  735. // Zero accumulators.
  736. auto zero_accums = [&]() {
  737. #pragma unroll
  738. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  739. reinterpret_cast<float*>(frag_c)[i] = 0;
  740. };
  741. int sh_first_group_id = -1;
  742. int sh_num_groups = -1;
  743. constexpr int sh_max_num_groups = 32;
  744. auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
  745. int last_group_id) {
  746. sh_first_group_id = first_group_id;
  747. sh_num_groups = last_group_id - first_group_id + 1;
  748. if (sh_num_groups < sh_max_num_groups) {
  749. sh_num_groups = sh_max_num_groups;
  750. }
  751. if (sh_first_group_id + sh_num_groups > num_groups) {
  752. sh_num_groups = num_groups - sh_first_group_id;
  753. }
  754. int row_offset = first_group_id * s_gl_stride;
  755. if (is_async) {
  756. for (int i = 0; i < sh_num_groups; i++) {
  757. if (threadIdx.x < s_sh_stride) {
  758. cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
  759. &scales_ptr[row_offset + (i * s_gl_stride) +
  760. slice_n_offset + threadIdx.x]);
  761. }
  762. }
  763. } else {
  764. for (int i = 0; i < sh_num_groups; i++) {
  765. if (threadIdx.x < s_sh_stride) {
  766. sh_s[(i * s_sh_stride) + threadIdx.x] =
  767. scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
  768. threadIdx.x];
  769. }
  770. }
  771. }
  772. };
  773. // Asynchronously fetch the next A, B and s tile from global to the next
  774. // shared memory pipeline location.
  775. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
  776. if (pred) {
  777. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  778. #pragma unroll
  779. for (int i = 0; i < a_sh_wr_iters; i++) {
  780. cp_async4_pred(
  781. &sh_a_stage[a_sh_wr_trans[i]],
  782. &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
  783. a_sh_wr_pred[i]);
  784. }
  785. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  786. #pragma unroll
  787. for (int i = 0; i < b_sh_wr_iters; i++) {
  788. #pragma unroll
  789. for (int j = 0; j < b_thread_vecs; j++) {
  790. cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
  791. }
  792. B_ptr[i] += b_gl_rd_delta_o;
  793. }
  794. if constexpr (has_act_order) {
  795. // Fetch g_idx thread-block portion
  796. int full_pipe = a_off;
  797. int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
  798. if (cur_k < prob_k && cur_k < slice_k_finish) {
  799. int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
  800. int4 const* cur_g_idx_stage_ptr =
  801. reinterpret_cast<int4 const*>(&g_idx[cur_k]);
  802. if (threadIdx.x < g_idx_stage) {
  803. cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
  804. &cur_g_idx_stage_ptr[threadIdx.x]);
  805. }
  806. }
  807. } else {
  808. if constexpr (group_blocks != -1) {
  809. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  810. if constexpr (group_blocks >= thread_k_blocks) {
  811. // Only fetch scales if this tile starts a new group
  812. if (pipe % (group_blocks / thread_k_blocks) == 0) {
  813. if (s_sh_wr_pred) {
  814. cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
  815. }
  816. s_gl_rd += s_gl_rd_delta;
  817. }
  818. } else {
  819. for (int i = 0; i < s_tb_groups; i++) {
  820. if (s_sh_wr_pred) {
  821. cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
  822. &scales_ptr[s_gl_rd]);
  823. }
  824. s_gl_rd += s_gl_rd_delta;
  825. }
  826. }
  827. }
  828. if constexpr (has_zp && group_blocks != -1) {
  829. int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
  830. if constexpr (group_blocks >= thread_k_blocks) {
  831. // Only fetch zero-points if this tile starts a new group
  832. if (pipe % (group_blocks / thread_k_blocks) == 0) {
  833. if (zp_sh_wr_pred) {
  834. cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
  835. }
  836. zp_gl_rd += zp_gl_rd_delta;
  837. }
  838. } else {
  839. for (int i = 0; i < zp_tb_groups; i++) {
  840. if (zp_sh_wr_pred) {
  841. cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
  842. &zp_ptr[zp_gl_rd]);
  843. }
  844. zp_gl_rd += zp_gl_rd_delta;
  845. }
  846. }
  847. }
  848. }
  849. }
  850. // Insert a fence even when we are winding down the pipeline to ensure that
  851. // waiting is also correct at this point.
  852. cp_async_fence();
  853. };
  854. auto fetch_zp_to_shared = [&]() {
  855. if (zp_sh_wr_pred) {
  856. cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
  857. }
  858. };
  859. // Wait until the next thread tile has been loaded to shared memory.
  860. auto wait_for_stage = [&]() {
  861. // We only have `stages - 2` active fetches since we are double buffering
  862. // and can only issue the next fetch when it is guaranteed that the previous
  863. // shared memory load is fully complete (as it may otherwise be
  864. // overwritten).
  865. cp_async_wait<stages - 2>();
  866. __syncthreads();
  867. };
  868. // Load the next sub-tile from the current location in the shared memory pipe
  869. // into the current register buffer.
  870. auto fetch_to_registers = [&](int k, int pipe) {
  871. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  872. #pragma unroll
  873. for (int i = 0; i < thread_m_blocks; i++)
  874. ldsm4<scalar_t>(frag_a[k % 2][i],
  875. &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
  876. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  877. #pragma unroll
  878. for (int i = 0; i < b_thread_vecs; i++) {
  879. frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
  880. &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
  881. }
  882. };
  883. bool is_same_group[stages];
  884. int same_group_id[stages];
  885. auto init_same_group = [&](int pipe) {
  886. if constexpr (!has_act_order) {
  887. is_same_group[pipe] = false;
  888. same_group_id[pipe] = 0;
  889. return;
  890. }
  891. int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
  892. int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
  893. int group_id_1 = sh_g_idx_int_ptr[0];
  894. int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
  895. is_same_group[pipe] = group_id_1 == group_id_2;
  896. same_group_id[pipe] = group_id_1;
  897. };
  898. auto fetch_scales_to_registers = [&](int k, int full_pipe) {
  899. int pipe = full_pipe % stages;
  900. if constexpr (!has_act_order) {
  901. // No act-order case
  902. if constexpr (group_blocks != -1) {
  903. if constexpr (group_blocks >= thread_k_blocks) {
  904. int4* sh_s_stage =
  905. sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
  906. (pipe / (group_blocks / thread_k_blocks)));
  907. reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
  908. } else {
  909. int warp_id = threadIdx.x / 32;
  910. int n_warps = thread_n_blocks / 4;
  911. int warp_row = warp_id / n_warps;
  912. int cur_k = warp_row * 16;
  913. cur_k += k_iter_size * (k % b_sh_wr_iters);
  914. int k_blocks = cur_k / 16;
  915. int cur_group_id = k_blocks / group_blocks;
  916. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  917. reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
  918. sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
  919. }
  920. }
  921. return;
  922. }
  923. // Act-order case
  924. // Determine K of the "current" thread-block
  925. int cur_k = slice_k_start + tb_k * full_pipe;
  926. if (cur_k >= prob_k || cur_k >= slice_k_finish) {
  927. return;
  928. }
  929. // Reset (to current thread-block) since we read g_idx portion from the
  930. // shared memory
  931. cur_k = 0;
  932. // Progress to current iteration
  933. cur_k += k_iter_size * (k % b_sh_wr_iters);
  934. // Determine "position" inside the thread-block (based on warp and
  935. // thread-id)
  936. int warp_id = threadIdx.x / 32;
  937. int n_warps =
  938. thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
  939. int warp_row = warp_id / n_warps;
  940. int warp_col = warp_id % n_warps;
  941. cur_k += warp_row * 16;
  942. int th_id = threadIdx.x % 32;
  943. cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
  944. int s_col_shift =
  945. /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
  946. (th_id / 4) * act_s_col_stride;
  947. if (is_same_group[pipe]) {
  948. if (k % 2 == 0) {
  949. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
  950. sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
  951. s_col_shift];
  952. } else {
  953. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
  954. *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
  955. }
  956. for (int i = 1; i < 4; i++) {
  957. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
  958. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
  959. }
  960. return;
  961. }
  962. int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
  963. int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
  964. constexpr int k_frag_offsets[4] = {0, 1, 8,
  965. 9}; // Tensor core offsets per thread
  966. #pragma unroll
  967. for (int i = 0; i < 4; i++) {
  968. int actual_k = cur_k + k_frag_offsets[i];
  969. int group_id = sh_g_idx_int_ptr[actual_k];
  970. int rel_group_id = group_id - sh_first_group_id;
  971. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
  972. sh_s[rel_group_id * s_sh_stride + s_col_shift];
  973. }
  974. };
  975. auto fetch_zp_to_registers = [&](int k, int full_pipe) {
  976. // This code does not handle group_blocks == 0,
  977. // which signifies act_order.
  978. // has_zp implies AWQ, which doesn't have act_order,
  979. static_assert(!has_zp || group_blocks != 0);
  980. if constexpr (has_zp) {
  981. int pipe = full_pipe % stages;
  982. if constexpr (group_blocks == -1) {
  983. for (int i = 0; i < num_ints_per_thread; i++) {
  984. frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
  985. }
  986. } else if constexpr (group_blocks >= thread_k_blocks) {
  987. int4* sh_zp_stage =
  988. sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
  989. (pipe / (group_blocks / thread_k_blocks)));
  990. for (int i = 0; i < num_ints_per_thread; i++) {
  991. frag_qzp[k % 2][i] =
  992. (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
  993. }
  994. } else {
  995. int warp_id = threadIdx.x / 32;
  996. int n_warps = thread_n_blocks / 4;
  997. int warp_row = warp_id / n_warps;
  998. int cur_k = warp_row * 16;
  999. cur_k += k_iter_size * (k % b_sh_wr_iters);
  1000. int k_blocks = cur_k / 16;
  1001. int cur_group_id = 0;
  1002. // Suppress bogus and persistent divide-by-zero warning
  1003. #pragma nv_diagnostic push
  1004. #pragma nv_diag_suppress divide_by_zero
  1005. cur_group_id = k_blocks / group_blocks;
  1006. #pragma nv_diagnostic pop
  1007. int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
  1008. sh_zp_stage += cur_group_id * zp_sh_stride;
  1009. for (int i = 0; i < num_ints_per_thread; i++) {
  1010. frag_qzp[k % 2][i] =
  1011. (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
  1012. }
  1013. }
  1014. }
  1015. };
  // Execute the actual tensor core matmul of a sub-tile.
  auto matmul = [&](int k) {
    if constexpr (has_zp) {
      FragB frag_zp_0;
      FragB frag_zp_1;
      int zp_quant_0, zp_quant_1;
      if constexpr (w_type.size_bits() == 4) {
        zp_quant_0 = frag_qzp[k % 2][0];
        zp_quant_1 = zp_quant_0 >> 8;
      } else {
        static_assert(w_type.size_bits() == 8);
        zp_quant_0 = frag_qzp[k % 2][0];
        zp_quant_1 = frag_qzp[k % 2][1];
      }
      frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0);
      frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1);
      frag_zp[0] = frag_zp_0[0];
      frag_zp[1] = frag_zp_0[1];
      frag_zp[2] = frag_zp_1[0];
      frag_zp[3] = frag_zp_1[1];
    }
    // We have the m dimension as the inner loop in order to encourage
    // overlapping dequantization and matmul operations.
  #pragma unroll
    for (int j = 0; j < 4; j++) {
      FragB frag_b0;
      FragB frag_b1;
      int b_quant_0, b_quant_1;
      if constexpr (w_type.size_bits() == 4) {
        b_quant_0 = frag_b_quant[k % 2][0][j];
        b_quant_1 = b_quant_0 >> 8;
      } else {
        static_assert(w_type.size_bits() == 8);
        int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
        b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
        b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
      }
      frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0);
      frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
      // Apply zero-point to frag_b0
      if constexpr (has_zp) {
        sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
      }
      // Apply scale to frag_b0
      if constexpr (has_act_order) {
        scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
                         act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
                         act_frag_s[k % 2][3][j], 0);
      } else {
        if constexpr (group_blocks != -1) {
          scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
        }
      }
      // Apply zero-point to frag_b1
      if constexpr (has_zp) {
        sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
      }
      // Apply scale to frag_b1
      if constexpr (has_act_order) {
        scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
                         act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
                         act_frag_s[k % 2][3][j], 1);
      } else {
        if constexpr (group_blocks != -1) {
          scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);
        }
      }
  #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
        mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
        mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
      }
    }
  };
  // Since we slice across the k dimension of a tile in order to increase the
  // number of warps while keeping the n dimension of a tile reasonable, we
  // have multiple warps that accumulate their partial sums of the same output
  // location, which we have to reduce over in the end. We do this in shared
  // memory.
  auto thread_block_reduce = [&]() {
    constexpr int red_off = threads / b_sh_stride_threads / 2;
    if (red_off >= 1) {
      int red_idx = threadIdx.x / b_sh_stride_threads;
      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
      constexpr int red_sh_delta = b_sh_stride_threads;
      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
                      (threadIdx.x % b_sh_stride_threads);
      // Parallel logarithmic shared memory reduction. We make sure to avoid
      // any unnecessary read or write iterations, e.g., for two warps, warp 1
      // writes only once and warp 0 reads only once.
  #pragma unroll
      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  #pragma unroll
        for (int i = red_off; i > 0; i /= 2) {
          if (i <= red_idx && red_idx < 2 * i) {
  #pragma unroll
            for (int j = 0; j < 4 * 2; j++) {
              int red_sh_wr =
                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
              if (i < red_off) {
                float* c_rd =
                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  #pragma unroll
                for (int k = 0; k < 4; k++)
                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                      c_rd[k] + c_wr[k];
              }
              sh[red_sh_wr] =
                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
            }
          }
          __syncthreads();
        }
        if (red_idx == 0) {
  #pragma unroll
          for (int i = 0; i < 4 * 2; i++) {
            float* c_rd =
                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  #pragma unroll
            for (int j = 0; j < 4; j++)
              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
                  c_rd[j];
          }
        }
        __syncthreads();
      }
    }
  };
  // Since multiple threadblocks may process parts of the same column slice, we
  // finally have to globally reduce over the results. As the striped
  // partitioning minimizes the number of such reductions and our outputs are
  // usually rather small, we perform this reduction serially in L2 cache.
  auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
    // We are very careful here to reduce directly in the output buffer to
    // maximize L2 cache utilization in this step. To do this, we write out
    // results in FP16 (but still reduce with FP32 compute).
    constexpr int active_threads = 32 * thread_n_blocks / 4;
    if (threadIdx.x < active_threads) {
      int c_gl_stride = prob_n / 8;
      int c_gl_wr_delta_o = 8 * c_gl_stride;
      int c_gl_wr_delta_i = 4 * (active_threads / 32);
      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
                    4 * (threadIdx.x / 32) + threadIdx.x % 4;
      c_gl_wr += (2 * thread_n_blocks) * slice_col;
      constexpr int c_sh_wr_delta = active_threads;
      int c_sh_wr = threadIdx.x;
      int row = (threadIdx.x % 32) / 4;
      if (!first) {
        // Interestingly, doing direct global accesses here really seems to
        // mess up the compiler and lead to slowdowns, hence we also use
        // async-copies even though these fetches are not actually
        // asynchronous.
  #pragma unroll
        for (int i = 0; i < thread_m_blocks * 4; i++) {
          cp_async4_pred(
              &sh[c_sh_wr + c_sh_wr_delta * i],
              &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
                 c_gl_wr_delta_i * (i % 2)],
              i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
        }
        cp_async_fence();
        cp_async_wait<0>();
      }
  #pragma unroll
      for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
          if (!first) {
            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  #pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<float*>(
                  &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
                  Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
            }
          }
          if (!last) {
            int4 c;
  #pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<scalar_t*>(&c)[j] =
                  Dtype::float2num(reinterpret_cast<float*>(
                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
            }
            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
                c;
          }
        }
      }
    }
  };
  // Globally reduce over threadblocks that compute the same column block.
  // We use a tmp C buffer to reduce in full fp32 precision.
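  // Layout note: C_tmp holds one contiguous c_size block of int4s per
  // (parallel m-chunk, column slice) pair; each active thread streams its
  // th_size accumulator chunks to and from that block.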
  auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
    constexpr int tb_m = thread_m_blocks * 16;
    constexpr int tb_n = thread_n_blocks * 16;
    constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
    constexpr int active_threads = 32 * thread_n_blocks / 4;
    bool is_th_active = threadIdx.x < active_threads;
    int par_offset = c_size * n_tiles * par_id;
    int slice_offset = c_size * slice_col;
    constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
    constexpr int th_size = num_floats * sizeof(float) / 16;
    int c_cur_offset = par_offset + slice_offset;
    if (!is_th_active) {
      return;
    }
    if (!first) {
      float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
  #pragma unroll
      for (int k = 0; k < th_size; k++) {
        sh[threadIdx.x] =
            C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
        float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
  #pragma unroll
        for (int f = 0; f < 4; f++) {
          frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
        }
      }
    }
    if (!last) {
      int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
  #pragma unroll
      for (int k = 0; k < th_size; k++) {
        C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
      }
    }
  };
  // Write out the reduced final result in the correct layout. We only actually
  // reshuffle matrix fragments in this step, the reduction above is performed
  // in fragment layout.
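  // The fragments are staged through shared memory first (note the +1 padding
  // column in c_sh_stride, presumably to avoid bank conflicts) and then
  // streamed to global memory as coalesced int4 stores.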
  auto write_result = [&]() {
    int c_gl_stride = prob_n / 8;
    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
    constexpr int c_sh_rd_delta =
        c_sh_stride * (threads / (2 * thread_n_blocks));
    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));
    c_gl_wr += (2 * thread_n_blocks) * slice_col;
    int c_sh_wr =
        (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
    c_sh_wr += 32 * (threadIdx.x / 32);
    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));
    int c_gl_wr_end = c_gl_stride * prob_m;
    // We first reorder in shared memory to guarantee the most efficient final
    // global write patterns
    auto write = [&](int idx, float c0, float c1, FragS& s) {
      scalar_t2 res =
          Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
      // For per-column quantization we finally apply the scale here (only for
      // 4-bit)
      if constexpr (!has_act_order && group_blocks == -1 &&
                    w_type.size_bits() == 4) {
        res = __hmul2(res, s[0]);
      }
      ((scalar_t2*)sh)[idx] = res;
    };
    if (threadIdx.x / 32 < thread_n_blocks / 4) {
  #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
  #pragma unroll
        for (int j = 0; j < 4; j++) {
          int wr = c_sh_wr + 8 * j;
          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
                frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
                frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
                frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
                frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
        }
        c_sh_wr += 16 * (4 * c_sh_stride);
      }
    }
    __syncthreads();
  #pragma unroll
    for (int i = 0;
         i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
         i++) {
      if (c_gl_wr < c_gl_wr_end) {
        C[c_gl_wr] = sh[c_sh_rd];
        c_gl_wr += c_gl_wr_delta;
        c_sh_rd += c_sh_rd_delta;
      }
    }
  };
  // Start global fetch and register load pipelines.
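  // Prefetches the first (stages - 1) pipeline stages; the act-order scale
  // chunk and, for channelwise zero-points, the zero-point tile are staged up
  // front before the first register loads.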
  auto start_pipes = [&]() {
  #pragma unroll
    for (int i = 0; i < stages - 1; i++) {
      if (has_act_order && i == 0) {
        int last_g_idx = slice_k_start + stages * tb_k * 2;
        if (last_g_idx >= prob_k) {
          last_g_idx = prob_k - 1;
        }
        fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
      }
      if constexpr (has_zp && group_blocks == -1) {
        if (i == 0) {
          fetch_zp_to_shared();
        }
      }
      fetch_to_shared(i, i, i < slice_iters);
    }
    zero_accums();
    wait_for_stage();
    init_same_group(0);
    fetch_to_registers(0, 0);
    fetch_scales_to_registers(0, 0);
    fetch_zp_to_registers(0, 0);
    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
    slice_k_start_shared_fetch += tb_k * (stages - 1);
  };
  if (slice_iters) {
    start_pipes();
  }
  // Main loop.
  while (slice_iters) {
    // We unroll over both the global fetch and the register load pipeline to
    // ensure all shared memory accesses are static. Note that both pipelines
    // have even length meaning that the next iteration will always start at
    // index 0.
  #pragma unroll
    for (int pipe = 0; pipe < stages;) {
  #pragma unroll
      for (int k = 0; k < b_sh_wr_iters; k++) {
        fetch_to_registers(k + 1, pipe % stages);
        fetch_scales_to_registers(k + 1, pipe);
        fetch_zp_to_registers(k + 1, pipe);
        if (k == b_sh_wr_iters - 2) {
          fetch_to_shared((pipe + stages - 1) % stages, pipe,
                          slice_iters >= stages);
          pipe++;
          wait_for_stage();
          init_same_group(pipe % stages);
        }
        matmul(k);
      }
      slice_iters--;
      if (slice_iters == 0) {
        break;
      }
    }
    a_gl_rd += a_gl_rd_delta_o * stages;
    slice_k_start += tb_k * stages;
    slice_k_start_shared_fetch += tb_k * stages;
    if constexpr (has_act_order) {
      int first_group_id = g_idx[slice_k_start];
      int last_g_idx = slice_k_start + stages * tb_k * 2;
      if (last_g_idx >= prob_k) {
        last_g_idx = prob_k - 1;
      }
      int last_group_id = g_idx[last_g_idx];
      if (last_group_id >= sh_first_group_id + sh_num_groups) {
        fetch_scales_to_shared(false, first_group_id, last_group_id);
        __syncthreads();
      }
    }
    // Process results and, if necessary, proceed to the next column slice.
    // While this pattern may not be the most readable, other ways of writing
    // the loop seemed to noticeably worsen performance after compilation.
    if (slice_iters == 0) {
      cp_async_wait<0>();
      bool last = slice_idx == slice_count - 1;
      // For per-column scales, we only fetch them here in the final step
      // before write-out
      if constexpr (!has_act_order && group_blocks == -1) {
        if constexpr (w_type.size_bits() == 8) {
          if (s_sh_wr_pred) {
            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
          }
          cp_async_fence();
        } else {
          if (last) {
            if (s_sh_wr_pred) {
              cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
            }
            cp_async_fence();
          }
        }
      }
      thread_block_reduce();
      if constexpr (!has_act_order && group_blocks == -1) {
        if constexpr (w_type.size_bits() == 8) {
          cp_async_wait<0>();
          __syncthreads();
          if (threadIdx.x / 32 < thread_n_blocks / 4) {
            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
          }
        } else {
          if (last) {
            cp_async_wait<0>();
            __syncthreads();
            if (threadIdx.x / 32 < thread_n_blocks / 4) {
              reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
              reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
            }
          }
        }
      }
      // For 8-bit channelwise, we apply the scale before the global reduction
      // that converts the fp32 results to fp16 (so that we avoid possible
      // overflow in fp16)
      if constexpr (!has_act_order && group_blocks == -1 &&
                    w_type.size_bits() == 8) {
        if (threadIdx.x / 32 < thread_n_blocks / 4) {
  #pragma unroll
          for (int i = 0; i < thread_m_blocks; i++) {
  #pragma unroll
            for (int j = 0; j < 4; j++) {
              scale_float<scalar_t>(
                  reinterpret_cast<float*>(&frag_c[i][j][0][0]),
                  frag_s[j / 2][2 * (j % 2) + 0]);
              scale_float<scalar_t>(
                  reinterpret_cast<float*>(&frag_c[i][j][0][2]),
                  frag_s[j / 2][2 * (j % 2) + 0]);
              scale_float<scalar_t>(
                  reinterpret_cast<float*>(&frag_c[i][j][1][0]),
                  frag_s[j / 2][2 * (j % 2) + 1]);
              scale_float<scalar_t>(
                  reinterpret_cast<float*>(&frag_c[i][j][1][2]),
                  frag_s[j / 2][2 * (j % 2) + 1]);
            }
          }
        }
      }
      if (slice_count > 1) {  // only globally reduce if there is more than one
                              // block in a slice
        barrier_acquire(&locks[slice_col], slice_idx);
        if (use_fp32_reduce) {
          global_reduce_fp32(slice_idx == 0, last);
        } else {
          global_reduce_fp16(slice_idx == 0, last);
        }
        barrier_release(&locks[slice_col], last);
      }
      if (last)  // only the last block in a slice actually writes the result
        write_result();
      slice_row = 0;
      slice_col_par++;
      slice_col++;
      init_slice();
      if (slice_iters) {
        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                  (threadIdx.x % a_gl_rd_delta_o);
  #pragma unroll
        for (int i = 0; i < b_sh_wr_iters; i++)
          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
        if (slice_col == 0) {
  #pragma unroll
          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
        }
        // Update slice k/n for scales loading
        if constexpr (has_act_order) {
          slice_k_start = tb_k * slice_row;
          slice_k_finish = slice_k_start + tb_k * slice_iters;
          slice_k_start_shared_fetch = slice_k_start;
          slice_n_offset = act_s_col_tb_stride * slice_col;
        } else {
          s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
          zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
        }
        start_pipes();
      }
    }
  }
}
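
// Compile-time dispatch: each __CALL_IF expansion matches one combination of
// weight type, tile shape, quantization layout and thread count, sets the
// kernel's dynamic shared-memory limit, and launches the corresponding Marlin
// template instantiation.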
#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
                  HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
  else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
           thread_n_blocks == THREAD_N_BLOCKS && \
           thread_k_blocks == THREAD_K_BLOCKS && \
           has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
    cudaFuncSetAttribute( \
        Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
               THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
               HAS_ZP, GROUP_BLOCKS>, \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
    Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
           THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
           HAS_ZP, GROUP_BLOCKS> \
        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
            A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
            num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
  }

typedef struct {
  int thread_k;
  int thread_n;
  int num_threads;
} thread_config_t;

typedef struct {
  int max_m_blocks;
  thread_config_t tb_cfg;
} exec_config_t;

thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority
    // thread_k, thread_n, num_threads
    {128, 128, 256},
    {64, 128, 128},
    {128, 64, 128},
};

thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority
    // thread_k, thread_n, num_threads
    {64, 256, 256},
    {64, 128, 128},
    {128, 64, 128},
};

int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  bool cache_scales_chunk = has_act_order && !is_k_full;
  int tb_n = th_config.thread_n;
  int tb_k = th_config.thread_k;
  // Get max scale groups per thread-block
  int tb_groups;
  if (group_size == -1) {
    tb_groups = 1;
  } else if (group_size == 0) {
    tb_groups = div_ceil(tb_k, 32);  // Worst case is 32 group size
  } else {
    tb_groups = div_ceil(tb_k, group_size);
  }
  if (cache_scales_chunk) {
    int load_groups =
        tb_groups * pipe_stages * 2;     // Chunk size is 2x pipeline over dim K
    load_groups = max(load_groups, 32);  // We load at least 32 scale groups
    return load_groups * tb_n * 2;
  } else {
    int tb_scales = tb_groups * tb_n * 2;
    return tb_scales * pipe_stages;
  }
}
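
// Checks that the staged A/B fetch pipeline (pipe_stages deep) plus the
// scales cache fits into roughly 95% of the opt-in shared-memory limit.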
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  int pack_factor = 32 / num_bits;
  // Get B size
  int tb_k = th_config.thread_k;
  int tb_n = th_config.thread_n;
  int b_size = (tb_k * tb_n / pack_factor) * 4;
  // Get A size
  int m_blocks = div_ceil(prob_m, 16);
  int tb_max_m = 16;
  while (true) {
    if (m_blocks >= max_m_blocks) {
      tb_max_m *= max_m_blocks;
      break;
    }
    max_m_blocks--;
    if (max_m_blocks == 0) {
      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
    }
  }
  int a_size = (tb_max_m * tb_k) * 2;
  float pipe_size = (a_size + b_size) * pipe_stages;
  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity
  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}

bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int max_shared_mem) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }
  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }
  // Verify min for thread K/N
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }
  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }
  // Determine cache for scales
  int scales_cache_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);
  // Check that pipeline fits into cache
  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                           num_bits, scales_cache_size, max_shared_mem)) {
    return false;
  }
  return true;
}

int determine_reduce_max_m(int prob_m, int max_par) {
  constexpr int tile_m_size = 16;
  if (prob_m <= tile_m_size) {
    return tile_m_size;
  } else if (prob_m <= tile_m_size * 2) {
    return tile_m_size * 2;
  } else if (prob_m <= tile_m_size * 3) {
    return tile_m_size * 3;
  } else if (prob_m <= tile_m_size * 4) {
    return tile_m_size * 4;
  } else {
    int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
    return tile_m_size * 4 * cur_par;
  }
}
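
// Pick a thread-block shape: try the prioritized configs for the given batch
// size and shrink max_m_blocks until one fits into shared memory; returns
// {0, {-1, -1, -1}} if nothing fits.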
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      bool has_act_order, bool is_k_full,
                                      int max_shared_mem) {
  int max_m_blocks = 4;
  while (max_m_blocks > 0) {
    if (prob_m <= 16) {
      for (auto th_config : small_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, has_act_order, is_k_full,
                            max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    } else {
      for (auto th_config : large_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, has_act_order, is_k_full,
                            max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    }
    max_m_blocks--;  // Process fewer M blocks per invocation to reduce cache
                     // usage
  }
  return exec_config_t{0, {-1, -1, -1}};
}
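
// GPTQ_CALL_IF expands __CALL_IF for thread_m_blocks 1..4, both with
// act_order (group_blocks = 0) and without it (group_blocks in {-1, 2, 4, 8});
// AWQ_CALL_IF does the same for the zero-point path (has_zp = true).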
#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
  \
  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
  \
  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
  \
  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
  \
  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)

#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
  \
  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
  \
  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
  \
  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
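
// Host-side launcher: validates the quantization type, selects a thread-block
// configuration, optionally permutes A for act_order, and dispatches the
// matching Marlin kernel instantiation over chunks of up to
// max_m_blocks * 16 rows (parallelized across row chunks via `par`).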
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
               void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
               int prob_n, int prob_k, void* workspace,
               aphrodite::ScalarType const& q_type, bool has_act_order,
               bool is_k_full, bool has_zp, int num_groups, int group_size,
               int dev, cudaStream_t stream, int thread_k, int thread_n,
               int sms, int max_par, bool use_fp32_reduce) {
  if (has_zp) {
    TORCH_CHECK(
        q_type == aphrodite::kU4 || q_type == aphrodite::kU8,
        "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
  } else {
    TORCH_CHECK(
        q_type == aphrodite::kU4B8 || q_type == aphrodite::kU8B128,
        "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        q_type.str());
  }
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");
  // TODO: remove alias when we start supporting other 8bit types
  int num_bits = q_type.size_bits();
  int tot_m = prob_m;
  int tot_m_blocks = div_ceil(tot_m, 16);
  int pad = 16 * tot_m_blocks - tot_m;
  if (sms == -1) {
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  }
  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);
  // Set thread config
  exec_config_t exec_cfg;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    exec_cfg =
        exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
  } else {
    // Auto config
    exec_cfg =
        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
                                has_act_order, is_k_full, max_shared_mem);
  }
  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
                                  prob_m, prob_n, prob_k, num_bits, group_size,
                                  has_act_order, is_k_full, max_shared_mem),
              "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
              ", thread_k = ", exec_cfg.tb_cfg.thread_k,
              ", thread_n = ", exec_cfg.tb_cfg.thread_n,
              ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
              ", group_size = ", group_size,
              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
              ", max_shared_mem = ", max_shared_mem);
  int num_threads = exec_cfg.tb_cfg.num_threads;
  thread_k = exec_cfg.tb_cfg.thread_k;
  thread_n = exec_cfg.tb_cfg.thread_n;
  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;
  int blocks = sms;
  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);
  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }
  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }
  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  int4* C_ptr = (int4*)C;
  int4* C_tmp_ptr = (int4*)C_tmp;
  const int4* s_ptr = (const int4*)s;
  const int4* zp_ptr = (const int4*)zp;
  const int* g_idx_ptr = (const int*)g_idx;
  const int* perm_ptr = (const int*)perm;
  int4* a_tmp_ptr = (int4*)a_tmp;
  int* locks = (int*)workspace;
  if (has_act_order) {
    // Permute A columns
    int block_rows = div_ceil(prob_m, blocks);
    permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
    A_ptr = a_tmp_ptr;
  }
  // If we have a full K, then we can run the non-act-order version of Marlin
  // (since the weight rows are reordered by increasing group ids, and by
  // having a full K, we have full original groups)
  if (is_k_full) {
    has_act_order = false;
  }
  // Main loop
  for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
    int thread_m_blocks = tot_m_blocks - i;
    prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > exec_cfg.max_m_blocks) {
      // Note that parallel > 1 currently only works for inputs without any
      // padding
      par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
      if (par > max_par) par = max_par;
      prob_m = (16 * exec_cfg.max_m_blocks) * par;
      i += exec_cfg.max_m_blocks * (par - 1);
      thread_m_blocks = exec_cfg.max_m_blocks;
    }
    if (false) {
    }
    GPTQ_CALL_IF(aphrodite::kU4B8, 16, 4, 256)
    GPTQ_CALL_IF(aphrodite::kU4B8, 8, 8, 256)
    GPTQ_CALL_IF(aphrodite::kU4B8, 8, 4, 128)
    GPTQ_CALL_IF(aphrodite::kU4B8, 4, 8, 128)
    GPTQ_CALL_IF(aphrodite::kU8B128, 16, 4, 256)
    GPTQ_CALL_IF(aphrodite::kU8B128, 8, 8, 256)
    GPTQ_CALL_IF(aphrodite::kU8B128, 8, 4, 128)
    GPTQ_CALL_IF(aphrodite::kU8B128, 4, 8, 128)
    AWQ_CALL_IF(aphrodite::kU4, 16, 4, 256)
    AWQ_CALL_IF(aphrodite::kU4, 8, 8, 256)
    AWQ_CALL_IF(aphrodite::kU4, 8, 4, 128)
    AWQ_CALL_IF(aphrodite::kU4, 4, 8, 128)
    AWQ_CALL_IF(aphrodite::kU8, 16, 4, 256)
    AWQ_CALL_IF(aphrodite::kU8, 8, 8, 256)
    AWQ_CALL_IF(aphrodite::kU8, 8, 4, 128)
    AWQ_CALL_IF(aphrodite::kU8, 4, 8, 128)
    else {
      TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                  ", ", prob_k, "]", ", has_act_order = ", has_act_order,
                  ", num_groups = ", num_groups, ", group_size = ", group_size,
                  ", thread_m_blocks = ", thread_m_blocks,
                  ", thread_n_blocks = ", thread_n_blocks,
                  ", thread_k_blocks = ", thread_k_blocks,
                  ", num_bits = ", num_bits);
    }
    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  }
}

}  // namespace marlin
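
// Torch entry point: validates shapes, devices and quantization metadata,
// allocates the output (and, if requested, the fp32 reduction buffer), and
// forwards to marlin::marlin_mm for half or bfloat16 activations.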
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace,
                               aphrodite::ScalarTypeTorchPtr const& b_q_type,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp,
                               bool use_fp32_reduce) {
  if (has_zp) {
    TORCH_CHECK(*b_q_type == aphrodite::kU4 || *b_q_type == aphrodite::kU8,
                "b_q_type must be u4 or u8 when has_zp = True. Got = ",
                b_q_type->str());
  } else {
    TORCH_CHECK(
        *b_q_type == aphrodite::kU4B8 || *b_q_type == aphrodite::kU8B128,
        "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        b_q_type->str());
  }
  int pack_factor = 32 / b_q_type->size_bits();
  // Verify A
  TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
              ", size_m = ", size_m);
  TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
              ", size_k = ", size_k);
  // Verify B
  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
              " is not divisible by tile_size = ", marlin::tile_size);
  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not divisible by tile_size = ", marlin::tile_size);
  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
  TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
              ", actual_size_n = ", actual_size_n);
  // Verify device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
  TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
  TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
  TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
  TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  torch::Tensor c = torch::empty({size_m, size_n}, options);
  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
  // Alloc C tmp buffer that is going to be used for the global reduce
  int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
  int reduce_n = size_n;
  auto options_fp32 =
      torch::TensorOptions().dtype(at::kFloat).device(a.device());
  if (!use_fp32_reduce) {
    reduce_max_m = 0;
    reduce_n = 0;
  }
  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;
  // Verify g_idx and perm
  TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
                  (g_idx.size(0) == size_k && perm.size(0) == size_k),
              "Unexpected g_idx.size(0) = ", g_idx.size(0),
              " and perm.size(0) = ", perm.size(0),
              ", where size_k = ", size_k);
  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;
  bool has_act_order = g_idx.size(0) != 0;
  int rank = b_scales.sizes().size();
  TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
  TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(0);
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }
  } else {
    if (num_groups > 1) {
      TORCH_CHECK(
          size_k % num_groups == 0, "size_k = ", size_k,
          ", is not divisible by b_scales.size(0) = ", b_scales.size(0));
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }
  // Verify b_zeros
  if (has_zp) {
    int rank = b_zeros.sizes().size();
    TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
    TORCH_CHECK(b_zeros.size(0) == num_groups,
                "b_zeros dim 0 = ", b_zeros.size(0),
                " is not num_groups = ", num_groups);
    TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
                "b_zeros dim 1 = ", b_zeros.size(1),
                " is not size_n / pack_factor = ", size_n / pack_factor);
  }
  // Verify workspace size
  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = ", workspace.numel(),
              " is below min_workspace_size = ", min_workspace_size);
  int dev = a.get_device();
  if (a.scalar_type() == at::ScalarType::Half) {
    marlin::marlin_mm<half>(
        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
        c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
        workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    marlin::marlin_mm<nv_bfloat16>(
        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
        b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
        perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
        workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
  } else {
    TORCH_CHECK(false, "gptq_marlin_gemm only supports bfloat16 and float16");
  }
  return c;
}

#endif