gptq_marlin.cu

  1. /*
  2. * Modified by Neural Magic
  3. * Copyright (C) Marlin.2024 Elias Frantar
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. /*
  18. * Adapted from https://github.com/IST-DASLab/marlin
  19. */
  20. #include "marlin.cuh"
  21. #include "marlin_dtypes.cuh"
  22. #include "core/scalar_type.hpp"
  23. #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
  24. static_assert(std::is_same<scalar_t, half>::value || \
  25. std::is_same<scalar_t, nv_bfloat16>::value, \
  26. "only float16 and bfloat16 is supported");
  27. template <typename T>
  28. inline std::string str(T x) {
  29. return std::to_string(x);
  30. }
  31. namespace marlin {
  32. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  33. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
  34. int const* __restrict__ perm_int_ptr,
  35. int4* __restrict__ out_int4_ptr, int size_m,
  36. int size_k, int block_rows) {}
  37. template <typename scalar_t, // compute dtype, half or nv_bfloat16
  38. const aphrodite::ScalarTypeId w_type_id, // weight ScalarType id
  39. const int threads, // number of threads in a threadblock
  40. const int thread_m_blocks, // number of 16x16 blocks in the m
  41. // dimension (batchsize) of the
  42. // threadblock
  43. const int thread_n_blocks, // same for n dimension (output)
  44. const int thread_k_blocks, // same for k dimension (reduction)
  45. const int stages, // number of stages for the async global->shared
  46. // fetch pipeline
  47. const bool has_act_order, // whether act_order is enabled
  48. const int group_blocks = -1, // number of consecutive 16x16 blocks
  49. // with a separate quantization scale
  50. const bool is_zp_float // whether the zero point is stored as float16
  51. >
  52. __global__ void Marlin(
  53. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  54. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  55. int4* __restrict__ C, // fp16 output buffer of shape mxn
  56. int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
  57. const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
  58. // (k/groupsize)xn
  59. const int* __restrict__ g_idx, // int32 group indices of shape k
  60. int num_groups, // number of scale groups per output channel
  61. int prob_m, // batch dimension m
  62. int prob_n, // output dimension n
  63. int prob_k, // reduction dimension k
  64. int* locks, // extra global storage for barrier synchronization
  65. bool use_fp32_reduce // whether to use fp32 global reduce
  66. ) {}
  67. } // namespace marlin
  68. torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  69. torch::Tensor& b_scales, torch::Tensor& b_zeros,
  70. torch::Tensor& g_idx, torch::Tensor& perm,
  71. torch::Tensor& workspace,
  72. aphrodite::ScalarTypeTorchPtr const& b_q_type,
  73. int64_t size_m, int64_t size_n, int64_t size_k,
  74. bool is_k_full, bool has_zp, bool is_zp_float) {
  75. TORCH_CHECK_NOT_IMPLEMENTED(false,
  76. "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
  77. return torch::empty({1, 1});
  78. }
  79. #else
  80. // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
  81. // output/accumulation.
  82. template <typename scalar_t>
  83. __device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
  84. const typename ScalarType<scalar_t>::FragB& frag_b,
  85. typename ScalarType<scalar_t>::FragC& frag_c) {
  86. const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  87. const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  88. float* c = reinterpret_cast<float*>(&frag_c);
  89. if constexpr (std::is_same<scalar_t, half>::value) {
  90. asm volatile(
  91. "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
  92. "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
  93. : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
  94. : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
  95. "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  96. } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
  97. asm volatile(
  98. "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
  99. "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
  100. : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
  101. : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
  102. "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  103. } else {
  104. STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
  105. }
  106. }
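// Note: mma.sync.m16n8k16 is a warp-wide instruction. Per thread it consumes
// 4 b32 registers of A (8 fp16/bf16 values of a 16x16 tile), 2 b32 registers
// of B (4 values of a 16x8 tile) and accumulates into 4 fp32 registers of a
// 16x8 C tile, which matches the FragA/FragB/FragC sizes used in this file.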
  107. // Instruction for loading a full 16x16 matrix fragment of operand A from shared
  108. // memory, directly in tensor core layout.
  109. template <typename scalar_t>
  110. __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
  111. const void* smem_ptr) {
  112. uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  113. uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  114. asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
  115. : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
  116. : "r"(smem));
  117. }
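// Note: ldmatrix.x4 loads four 8x8 b16 sub-tiles (i.e. one full 16x16 tile)
// in a single warp-wide instruction; every thread passes a shared-memory row
// address and receives 4 b32 registers arranged exactly as the A operand of
// the m16n8k16 mma above, so no extra shuffling is needed in between.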
  118. // Lookup-table based 3-input logical operation; explicitly used for
  119. // dequantization as the compiler does not seem to automatically recognize it in
  120. // all cases.
  121. template <int lut>
  122. __device__ inline int lop3(int a, int b, int c) {
  123. int res;
  124. asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
  125. : "=r"(res)
  126. : "r"(a), "r"(b), "r"(c), "n"(lut));
  127. return res;
  128. }
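// Note: the `lut` template argument is the 8-bit truth table of the desired
// 3-input function, obtained by applying it to the constants 0xf0 (a),
// 0xcc (b) and 0xaa (c). E.g. (0xf0 & 0xcc) | 0xaa == 0xea selects the
// function (a & b) | c, which is how the dequantizers below request a
// masked-or in a single LOP3 instruction.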
  129. // Constructs destination register by taking bytes from 2 sources (based on
  130. // mask)
  131. template <int start_byte, int mask>
  132. __device__ inline uint32_t prmt(uint32_t a) {
  133. uint32_t res;
  134. asm volatile("prmt.b32 %0, %1, %2, %3;\n"
  135. : "=r"(res)
  136. : "r"(a), "n"(start_byte), "n"(mask));
  137. return res;
  138. }
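// Note: prmt.b32 treats {a, start_byte} as an 8-byte vector and each nibble
// of `mask` selects one of those bytes for the result. With start_byte
// 0x64646464 and mask 0x5250 this pairs q.byte0 and q.byte2 each with a 0x64
// high byte, i.e. two fp16 values equal to 1024 + byte, which the int8
// dequantizers below turn into the final value by subtracting a magic number.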
  139. template <typename scalar_t, aphrodite::ScalarTypeId w_type_id>
  140. __device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
  141. //
  142. // Efficiently dequantize 4bit values packed in an int32 value into a full
  143. // B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
  144. // with some small changes:
  145. // - FP16:
  146. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
  147. // - BF16:
  148. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
  149. //
  150. template <>
  151. __device__ inline typename ScalarType<half>::FragB
  152. dequant<half, aphrodite::kU4B8.id()>(int q) {
  153. const int LO = 0x000f000f;
  154. const int HI = 0x00f000f0;
  155. const int EX = 0x64006400;
  156. // Guarantee that the `(a & b) | c` operations are LOP3s.
  157. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  158. int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  159. // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  160. // directly into `SUB` and `ADD`.
  161. const int SUB = 0x64086408;
  162. const int MUL = 0x2c002c00;
  163. const int ADD = 0xd480d480;
  164. typename ScalarType<half>::FragB frag_b;
  165. frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
  166. *reinterpret_cast<const half2*>(&SUB));
  167. frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
  168. *reinterpret_cast<const half2*>(&MUL),
  169. *reinterpret_cast<const half2*>(&ADD));
  170. return frag_b;
  171. }
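// Worked example for the fp16 4-bit path: masking a nibble v into the low
// mantissa bits of 0x6400 (fp16 1024.0) yields 1024 + v, and SUB = 0x6408
// (1032.0) turns that into v - 8. The high nibbles come out as 1024 + 16*v,
// so the fused multiply-add with MUL = 0x2c00 (1/16) and ADD = 0xd480 (-72.0)
// gives (1024 + 16*v)/16 - 72 = v - 8 as well.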
  172. template <>
  173. __device__ inline typename ScalarType<nv_bfloat16>::FragB
  174. dequant<nv_bfloat16, aphrodite::kU4B8.id()>(int q) {
  175. static constexpr uint32_t MASK = 0x000f000f;
  176. static constexpr uint32_t EX = 0x43004300;
  177. // Guarantee that the `(a & b) | c` operations are LOP3s.
  178. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  179. q >>= 4;
  180. int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  181. typename ScalarType<nv_bfloat16>::FragB frag_b;
  182. static constexpr uint32_t MUL = 0x3F803F80;
  183. static constexpr uint32_t ADD = 0xC308C308;
  184. frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
  185. *reinterpret_cast<const nv_bfloat162*>(&MUL),
  186. *reinterpret_cast<const nv_bfloat162*>(&ADD));
  187. frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
  188. *reinterpret_cast<const nv_bfloat162*>(&MUL),
  189. *reinterpret_cast<const nv_bfloat162*>(&ADD));
  190. return frag_b;
  191. }
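// Same idea in bf16: EX = 0x4300 is 128.0f, so each masked nibble becomes
// 128 + v; with MUL = 0x3F80 (1.0) and ADD = 0xC308 (-136.0) the FMA yields
// v - 8. Note that q is shifted right by 4 between the two extractions, so a
// single 32-bit word of nibbles feeds both halves of the fragment.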
  192. template <>
  193. __device__ inline typename ScalarType<half>::FragB
  194. dequant<half, aphrodite::kU4.id()>(int q) {
  195. const int LO = 0x000f000f;
  196. const int HI = 0x00f000f0;
  197. const int EX = 0x64006400;
  198. // Guarantee that the `(a & b) | c` operations are LOP3s.
  199. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  200. int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  201. const int SUB = 0x64006400;
  202. const int MUL = 0x2c002c00;
  203. const int ADD = 0xd400d400;
  204. typename ScalarType<half>::FragB frag_b;
  205. frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
  206. *reinterpret_cast<const half2*>(&SUB));
  207. frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
  208. *reinterpret_cast<const half2*>(&MUL),
  209. *reinterpret_cast<const half2*>(&ADD));
  210. return frag_b;
  211. }
  212. template <>
  213. __device__ inline typename ScalarType<nv_bfloat16>::FragB
  214. dequant<nv_bfloat16, aphrodite::kU4.id()>(int q) {
  215. static constexpr uint32_t MASK = 0x000f000f;
  216. static constexpr uint32_t EX = 0x43004300;
  217. // Guarantee that the `(a & b) | c` operations are LOP3s.
  218. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  219. q >>= 4;
  220. int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  221. typename ScalarType<nv_bfloat16>::FragB frag_b;
  222. static constexpr uint32_t MUL = 0x3F803F80;
  223. static constexpr uint32_t ADD = 0xC300C300;
  224. frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
  225. *reinterpret_cast<const nv_bfloat162*>(&MUL),
  226. *reinterpret_cast<const nv_bfloat162*>(&ADD));
  227. frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
  228. *reinterpret_cast<const nv_bfloat162*>(&MUL),
  229. *reinterpret_cast<const nv_bfloat162*>(&ADD));
  230. return frag_b;
  231. }
  232. //
  233. // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
  234. // bf16. Reference:
  235. // - FP16:
  236. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
  237. // - BF16:
  238. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
  239. //
  240. template <>
  241. __device__ inline typename ScalarType<half>::FragB
  242. dequant<half, aphrodite::kU8B128.id()>(int q) {
  243. static constexpr uint32_t mask_for_elt_01 = 0x5250;
  244. static constexpr uint32_t mask_for_elt_23 = 0x5351;
  245. static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
  246. uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  247. uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
  248. static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
  249. typename ScalarType<half>::FragB frag_b;
  250. frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
  251. *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  252. frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
  253. *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  254. return frag_b;
  255. }
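// Here each byte b of q is packed under a 0x64 high byte (see prmt above),
// giving fp16 values 1024 + b; subtracting the magic number 0x6480 (1152.0 =
// 1024 + 128) applies the symmetric 128 zero-point, so the result is b - 128.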
  256. template <>
  257. __device__ inline typename ScalarType<nv_bfloat16>::FragB
  258. dequant<nv_bfloat16, aphrodite::kU8B128.id()>(int q) {
  259. typename ScalarType<nv_bfloat16>::FragB frag_b;
  260. float fp32_intermediates[4];
  261. uint32_t* fp32_intermediates_casted =
  262. reinterpret_cast<uint32_t*>(fp32_intermediates);
  263. static constexpr uint32_t fp32_base = 0x4B000000;
  264. fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  265. fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  266. fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  267. fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
  268. fp32_intermediates[0] -= 8388736.f;
  269. fp32_intermediates[1] -= 8388736.f;
  270. fp32_intermediates[2] -= 8388736.f;
  271. fp32_intermediates[3] -= 8388736.f;
  272. uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
  273. bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
  274. fp32_intermediates_casted[1], 0x7632);
  275. bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
  276. fp32_intermediates_casted[3], 0x7632);
  277. return frag_b;
  278. }
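// bf16 has too few mantissa bits for the trick above, so each byte is first
// embedded into an fp32: 0x4B000000 is 2^23, so __byte_perm produces exactly
// 2^23 + b. Subtracting 8388736.f (2^23 + 128) leaves b - 128, and the final
// __byte_perm(..., 0x7632) keeps the upper 16 bits of each float, i.e. a
// truncation to bf16 (exact here, since the results are small integers).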
  279. template <>
  280. __device__ inline typename ScalarType<half>::FragB
  281. dequant<half, aphrodite::kU8.id()>(int q) {
  282. static constexpr uint32_t mask_for_elt_01 = 0x5250;
  283. static constexpr uint32_t mask_for_elt_23 = 0x5351;
  284. static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
  285. uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  286. uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
  287. static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
  288. typename ScalarType<half>::FragB frag_b;
  289. frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
  290. *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  291. frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
  292. *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  293. return frag_b;
  294. }
  295. template <>
  296. __device__ inline typename ScalarType<nv_bfloat16>::FragB
  297. dequant<nv_bfloat16, aphrodite::kU8.id()>(int q) {
  298. typename ScalarType<nv_bfloat16>::FragB frag_b;
  299. float fp32_intermediates[4];
  300. uint32_t* fp32_intermediates_casted =
  301. reinterpret_cast<uint32_t*>(fp32_intermediates);
  302. static constexpr uint32_t fp32_base = 0x4B000000;
  303. fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  304. fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  305. fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  306. fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
  307. fp32_intermediates[0] -= 8388608.f;
  308. fp32_intermediates[1] -= 8388608.f;
  309. fp32_intermediates[2] -= 8388608.f;
  310. fp32_intermediates[3] -= 8388608.f;
  311. uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
  312. bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
  313. fp32_intermediates_casted[1], 0x7632);
  314. bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
  315. fp32_intermediates_casted[3], 0x7632);
  316. return frag_b;
  317. }
  318. // Multiply dequantized values by the corresponding quantization scale; used
  319. // only for grouped quantization.
  320. template <typename scalar_t>
  321. __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
  322. typename ScalarType<scalar_t>::FragS& frag_s,
  323. int i) {
  324. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  325. scalar_t2 s =
  326. ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
  327. frag_b[0] = __hmul2(frag_b[0], s);
  328. frag_b[1] = __hmul2(frag_b[1], s);
  329. }
  330. template <typename scalar_t>
  331. __device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
  332. typename ScalarType<scalar_t>::scalar_t2& frag_zp,
  333. int i) {
  334. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  335. scalar_t2 zp =
  336. ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
  337. frag_b[0] = __hsub2(frag_b[0], zp);
  338. frag_b[1] = __hsub2(frag_b[1], zp);
  339. }
  340. template <typename scalar_t>
  341. __device__ inline void sub_zpf(typename ScalarType<scalar_t>::FragB& frag_b,
  342. typename ScalarType<scalar_t>::FragZPF& frag_zpf,
  343. int i) {
  344. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  345. scalar_t2 zp =
  346. ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zpf)[i]);
  347. frag_b[0] = __hsub2(frag_b[0], zp);
  348. frag_b[1] = __hsub2(frag_b[1], zp);
  349. }
  350. // Same as above, but for act_order (each K is multiplied individually)
  351. template <typename scalar_t>
  352. __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
  353. typename ScalarType<scalar_t>::FragS& frag_s_1,
  354. typename ScalarType<scalar_t>::FragS& frag_s_2,
  355. typename ScalarType<scalar_t>::FragS& frag_s_3,
  356. typename ScalarType<scalar_t>::FragS& frag_s_4,
  357. int i) {
  358. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  359. scalar_t2 s_val_1_2;
  360. s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
  361. s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];
  362. scalar_t2 s_val_3_4;
  363. s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];
  364. s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];
  365. frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
  366. frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
  367. }
  368. // Given 2 floats, multiply them by 2 scales (halves)
  369. template <typename scalar_t>
  370. __device__ inline void scale_float(float* c,
  371. typename ScalarType<scalar_t>::FragS& s) {
  372. scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
  373. c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
  374. c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
  375. }
  376. // Given 2 floats, subtract 2 zero points (halves) from them
  377. template <typename scalar_t>
  378. __device__ inline void sub_zpf_float(
  379. float* c, typename ScalarType<scalar_t>::FragZPF& zp) {
  380. scalar_t* zp_ptr = reinterpret_cast<scalar_t*>(&zp);
  381. c[0] = __fsub_rn(c[0], ScalarType<scalar_t>::num2float(zp_ptr[0]));
  382. c[1] = __fsub_rn(c[1], ScalarType<scalar_t>::num2float(zp_ptr[1]));
  383. }
  384. // Wait until barrier reaches `count`, then lock for current threadblock.
  385. __device__ inline void barrier_acquire(int* lock, int count) {
  386. if (threadIdx.x == 0) {
  387. int state = -1;
  388. do
  389. // Guarantee that subsequent writes by this threadblock will be visible
  390. // globally.
  391. asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
  392. : "=r"(state)
  393. : "l"(lock));
  394. while (state != count);
  395. }
  396. __syncthreads();
  397. }
  398. // Release barrier and increment visitation count.
  399. __device__ inline void barrier_release(int* lock, bool reset = false) {
  400. __syncthreads();
  401. if (threadIdx.x == 0) {
  402. if (reset) {
  403. lock[0] = 0;
  404. return;
  405. }
  406. int val = 1;
  407. // Make sure that all writes since acquiring this barrier are visible
  408. // globally, while releasing the barrier.
  409. asm volatile("fence.acq_rel.gpu;\n");
  410. asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
  411. :
  412. : "l"(lock), "r"(val));
  413. }
  414. }
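// Note: together these two helpers implement a global spin-lock barrier used
// for the serialized cross-threadblock reduction of a slice: barrier_acquire
// spins on an acquire-load of the counter until `count` producers have
// finished, and barrier_release publishes this block's writes with an acq_rel
// fence before atomically bumping (or resetting) the counter.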
  415. // For a given "a" of size [M,K], performs a permutation of the K columns based
  416. // on the given "perm" indices.
  417. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
  418. int const* __restrict__ perm_int_ptr,
  419. int4* __restrict__ out_int4_ptr, int size_m,
  420. int size_k, int block_rows) {
  421. int start_row = block_rows * blockIdx.x;
  422. int finish_row = start_row + block_rows;
  423. if (finish_row > size_m) {
  424. finish_row = size_m;
  425. }
  426. int cur_block_rows = finish_row - start_row;
  427. int row_stride = size_k * sizeof(half) / 16;
  428. auto permute_row = [&](int row) {
  429. int iters = size_k / default_threads;
  430. int rest = size_k % default_threads;
  431. int offset = row * row_stride;
  432. half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
  433. half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
  434. int base_k = 0;
  435. for (int i = 0; i < iters; i++) {
  436. int cur_k = base_k + threadIdx.x;
  437. int src_pos = perm_int_ptr[cur_k];
  438. out_half[cur_k] = a_row_half[src_pos];
  439. base_k += default_threads;
  440. }
  441. if (rest) {
  442. if (threadIdx.x < rest) {
  443. int cur_k = base_k + threadIdx.x;
  444. int src_pos = perm_int_ptr[cur_k];
  445. out_half[cur_k] = a_row_half[src_pos];
  446. }
  447. }
  448. };
  449. for (int i = 0; i < cur_block_rows; i++) {
  450. int cur_row = start_row + i;
  451. if (cur_row < size_m) {
  452. permute_row(cur_row);
  453. }
  454. }
  455. }
  456. template <typename scalar_t, // compute dtype, half or nv_bfloat16
  457. const aphrodite::ScalarTypeId w_type_id, // weight ScalarType id
  458. const int threads, // number of threads in a threadblock
  459. const int thread_m_blocks, // number of 16x16 blocks in the m
  460. // dimension (batchsize) of the
  461. // threadblock
  462. const int thread_n_blocks, // same for n dimension (output)
  463. const int thread_k_blocks, // same for k dimension (reduction)
  464. const int stages, // number of stages for the async global->shared
  465. // fetch pipeline
  466. const bool has_act_order, // whether act_order is enabled
  467. const bool has_zp, // whether zero-points are enabled
  468. const int group_blocks = -1, // number of consecutive 16x16 blocks
  469. // with a separate quantization scale
  470. const bool is_zp_float // whether the zero point is stored as float16
  471. >
  472. __global__ void Marlin(
  473. const int4* __restrict__ A, // fp16 input matrix of shape mxk
  474. const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
  475. int4* __restrict__ C, // fp16 output buffer of shape mxn
  476. int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
  477. const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
  478. // (k/groupsize)xn
  479. const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
  480. // (k/groupsize)x(n/pack_factor)
  481. const int* __restrict__ g_idx, // int32 group indices of shape k
  482. int num_groups, // number of scale groups per output channel
  483. int prob_m, // batch dimension m
  484. int prob_n, // output dimension n
  485. int prob_k, // reduction dimension k
  486. int* locks, // extra global storage for barrier synchronization
  487. bool use_fp32_reduce // whether to use fp32 global reduce
  488. ) {
  489. // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  490. // same size, which might involve multiple column "slices" (of width 16 *
  491. // `thread_n_blocks`). Stripes are defined as in the following example with
  492. // a 3x3 tile matrix and 5 SMs:
  493. // 0 1 3
  494. // 0 2 3
  495. // 1 2 4
  496. // While this kind of partitioning makes things somewhat more complicated, it
  497. // ensures good utilization of all SMs for many kinds of shape and GPU
  498. // configurations, while requiring as few slow global cross-threadblock
  499. // reductions as possible.
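// For the picture above: with k_tiles = 3, n_tiles = 3 and gridDim.x = 5,
// iters = ceil(3*3/5) = 2, so block 0 gets the first two k-tiles of column 0,
// block 1 the last k-tile of column 0 plus the first of column 1, and so on;
// a block whose stripe spans two columns simply finishes one slice and rolls
// over into the next.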
  500. using Dtype = ScalarType<scalar_t>;
  501. using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  502. using FragA = typename ScalarType<scalar_t>::FragA;
  503. using FragB = typename ScalarType<scalar_t>::FragB;
  504. using FragC = typename ScalarType<scalar_t>::FragC;
  505. using FragS = typename ScalarType<scalar_t>::FragS;
  506. using FragZP = typename ScalarType<scalar_t>::FragZP;
  507. using FragZPF = typename ScalarType<scalar_t>::FragZPF;
  508. static constexpr auto w_type = aphrodite::ScalarType::from_id(w_type_id);
  509. constexpr int pack_factor = 32 / w_type.size_bits();
  510. // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  511. // better partitioning with fewer reductions
  512. int parallel = 1;
  513. if (prob_m > 16 * thread_m_blocks) {
  514. parallel = prob_m / (16 * thread_m_blocks);
  515. prob_m = 16 * thread_m_blocks;
  516. }
  517. int k_tiles = prob_k / 16 / thread_k_blocks;
  518. int n_tiles = prob_n / 16 / thread_n_blocks;
  519. int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
  520. if constexpr (!has_act_order && group_blocks != -1) {
  521. if (group_blocks >= thread_k_blocks) {
  522. // Ensure that the number of tiles in each stripe is a multiple of the
  523. // groupsize; this avoids an annoying special case where a stripe starts
  524. // in the middle of a group.
  525. iters = (group_blocks / thread_k_blocks) *
  526. div_ceil(iters, (group_blocks / thread_k_blocks));
  527. }
  528. }
  529. int slice_row = (iters * blockIdx.x) % k_tiles;
  530. int slice_col_par = (iters * blockIdx.x) / k_tiles;
  531. int slice_col = slice_col_par;
  532. int slice_iters; // number of threadblock tiles in the current slice
  533. int slice_count =
  534. 0; // total number of active threadblocks in the current slice
  535. int slice_idx; // index of threadblock in current slice; numbered bottom to
  536. // top
  537. int par_id = 0;
  538. // We can easily implement parallel problem execution by just remapping
  539. // indices and advancing global pointers
  540. if (slice_col_par >= n_tiles) {
  541. A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
  542. C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
  543. locks += (slice_col_par / n_tiles) * n_tiles;
  544. slice_col = slice_col_par % n_tiles;
  545. par_id = slice_col_par / n_tiles;
  546. }
  547. // Compute all information about the current slice which is required for
  548. // synchronization.
  549. auto init_slice = [&]() {
  550. slice_iters =
  551. iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
  552. if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
  553. if (slice_iters == 0) return;
  554. if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
  555. slice_count = 1;
  556. slice_idx = 0;
  557. int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
  558. if (col_first <= k_tiles * (slice_col_par + 1)) {
  559. int col_off = col_first - k_tiles * slice_col_par;
  560. slice_count = div_ceil(k_tiles - col_off, iters);
  561. if (col_off > 0) slice_count++;
  562. int delta_first = iters * blockIdx.x - col_first;
  563. if (delta_first < 0 || (col_off == 0 && delta_first == 0))
  564. slice_idx = slice_count - 1;
  565. else {
  566. slice_idx = slice_count - 1 - delta_first / iters;
  567. if (col_off > 0) slice_idx--;
  568. }
  569. }
  570. if (slice_col == n_tiles) {
  571. A += 16 * thread_m_blocks * prob_k / 8;
  572. C += 16 * thread_m_blocks * prob_n / 8;
  573. locks += n_tiles;
  574. slice_col = 0;
  575. par_id++;
  576. }
  577. };
  578. init_slice();
  579. // A sizes/strides
  580. // stride of the A matrix in global memory
  581. int a_gl_stride = prob_k / 8;
  582. // stride of an A matrix tile in shared memory
  583. constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
  584. // delta between subsequent A tiles in global memory
  585. constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
  586. // between subsequent accesses within a tile
  587. int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
  588. // between shared memory writes
  589. constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
  590. // between shared memory tile reads
  591. constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
  592. // within a shared memory tile
  593. constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
  594. // overall size of a tile
  595. constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
  596. // number of shared write iterations for a tile
  597. constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
  598. // B sizes/strides
  599. int b_gl_stride = 16 * prob_n / (pack_factor * 4);
  600. constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
  601. constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
  602. constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
  603. int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  604. int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
  605. constexpr int b_sh_wr_delta = threads * b_thread_vecs;
  606. constexpr int b_sh_rd_delta = threads * b_thread_vecs;
  607. constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  608. constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
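// Note: pack_factor = 32 / w_type.size_bits(), so 8-bit weights occupy twice
// as many bytes per tile as 4-bit ones; b_thread_vecs reflects this by having
// each thread move two consecutive int4 vectors per B fetch instead of one,
// which keeps the same number of threads covering one b_sh_stride row.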
  609. // Scale sizes/strides without act_order
  610. int s_gl_stride = prob_n / 8;
  611. constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  612. constexpr int s_tb_groups =
  613. !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
  614. ? thread_k_blocks / group_blocks
  615. : 1;
  616. constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
  617. int s_gl_rd_delta = s_gl_stride;
  618. // Scale size/strides with act_order
  619. constexpr int tb_k = 16 * thread_k_blocks;
  620. constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
  621. // constexpr int act_s_row_stride = 1;
  622. // int act_s_col_stride = act_s_row_stride * num_groups;
  623. int act_s_col_stride = 1;
  624. int act_s_col_warp_stride = act_s_col_stride * 8;
  625. int tb_n_warps = thread_n_blocks / 4;
  626. int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
  627. // Zero-points sizes/strides
  628. int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;
  629. constexpr int zp_sh_stride = is_zp_float
  630. ? 16 * thread_n_blocks / 8
  631. : ((16 * thread_n_blocks) / pack_factor) / 4;
  632. constexpr int zp_tb_groups = s_tb_groups;
  633. constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
  634. int zp_gl_rd_delta = zp_gl_stride;
  635. // Global A read index of current thread.
  636. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  637. (threadIdx.x % a_gl_rd_delta_o);
  638. a_gl_rd += a_gl_rd_delta_o * slice_row;
  639. // Shared write index of current thread.
  640. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
  641. (threadIdx.x % a_gl_rd_delta_o);
  642. // Shared read index.
  643. int a_sh_rd =
  644. a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  645. a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
  646. int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
  647. (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
  648. b_gl_rd += b_sh_stride * slice_col;
  649. b_gl_rd += b_gl_rd_delta_o * slice_row;
  650. int b_sh_wr = threadIdx.x * b_thread_vecs;
  651. int b_sh_rd = threadIdx.x * b_thread_vecs;
  652. // For act_order
  653. constexpr int k_iter_size = tb_k / b_sh_wr_iters;
  654. int slice_k_start = tb_k * slice_row;
  655. int slice_k_finish = slice_k_start + tb_k * slice_iters;
  656. int slice_k_start_shared_fetch = slice_k_start;
  657. int slice_n_offset = act_s_col_tb_stride * slice_col;
  658. // No act_order
  659. int s_gl_rd;
  660. if constexpr (!has_act_order) {
  661. if constexpr (group_blocks == -1) {
  662. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  663. } else {
  664. s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
  665. s_sh_stride * slice_col + threadIdx.x;
  666. }
  667. }
  668. int s_sh_wr = threadIdx.x;
  669. bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
  670. // Zero-points
  671. int zp_gl_rd;
  672. if constexpr (has_zp) {
  673. if constexpr (group_blocks == -1) {
  674. zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
  675. } else {
  676. zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
  677. zp_sh_stride * slice_col + threadIdx.x;
  678. }
  679. }
  680. int zp_sh_wr = threadIdx.x;
  681. bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
  682. // We use a different scale layout for grouped and column-wise quantization as
  683. // we scale a `half2` tile in column-major layout in the former and in
  684. // row-major in the latter case.
  685. int s_sh_rd;
  686. if constexpr (group_blocks != -1)
  687. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  688. (threadIdx.x % 32) / 4;
  689. else
  690. s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  691. (threadIdx.x % 32) % 4;
  692. // Zero-points have the same read layout as the scales
  693. // (excluding the column-wise case)
  694. constexpr int num_col_threads = 8;
  695. constexpr int num_row_threads = 4;
  696. constexpr int num_ints_per_thread = 8 / pack_factor;
  697. int zp_sh_rd;
  698. if constexpr (has_zp) {
  699. if constexpr (is_zp_float) {
  700. if constexpr (group_blocks != -1)
  701. zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  702. (threadIdx.x % 32) / 4;
  703. } else {
  704. zp_sh_rd = num_ints_per_thread * num_col_threads *
  705. ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
  706. num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
  707. }
  708. }
  709. // Precompute which thread should not read memory in which iterations; this is
  710. // needed if there are more threads than required for a certain tilesize or
  711. // when the batchsize is not a multiple of 16.
  712. bool a_sh_wr_pred[a_sh_wr_iters];
  713. #pragma unroll
  714. for (int i = 0; i < a_sh_wr_iters; i++)
  715. a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  716. // To ensure that writing and reading A tiles to/from shared memory (the
  717. // latter in fragment format) is fully bank conflict free, we need to use a
  718. // rather fancy XOR-based layout. The key here is that neither reads nor
  719. // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  720. // same shared memory banks. Further, it seems (based on NSight-Compute) that
  721. // each warp must also write a consecutive memory segment?
  722. auto transform_a = [&](int i) {
  723. int row = i / a_gl_rd_delta_o;
  724. return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  725. };
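// Sketch of the swizzle: the tile-local column index is XOR-ed with the row
// index (in int4 units), so logically consecutive columns of different rows
// land on different 16-byte shared-memory "lanes"; this is what keeps the
// cp.async writes and the ldmatrix reads of 8 consecutive threads on
// disjoint banks, as described above.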
  726. // Since the computation of this remapping is non-trivial and, due to our main
  727. // loop unrolls, all shared memory accesses are static, we simply precompute
  728. // both transformed reads and writes.
  729. int a_sh_wr_trans[a_sh_wr_iters];
  730. #pragma unroll
  731. for (int i = 0; i < a_sh_wr_iters; i++)
  732. a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  733. int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  734. #pragma unroll
  735. for (int i = 0; i < b_sh_wr_iters; i++) {
  736. #pragma unroll
  737. for (int j = 0; j < thread_m_blocks; j++)
  738. a_sh_rd_trans[i][j] =
  739. transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  740. }
  741. // Since B-accesses have non-constant stride they have to be computed at
  742. // runtime; we break dependencies between subsequent accesses within a tile by
  743. // maintaining multiple pointers (we have enough registers), a tiny
  744. // optimization.
  745. const int4* B_ptr[b_sh_wr_iters];
  746. #pragma unroll
  747. for (int i = 0; i < b_sh_wr_iters; i++)
  748. B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  749. extern __shared__ int4 sh[];
  750. // Shared memory storage for global fetch pipelines.
  751. int4* sh_a = sh;
  752. int4* sh_b = sh_a + (stages * a_sh_stage);
  753. int4* sh_g_idx = sh_b + (stages * b_sh_stage);
  754. int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
  755. int4* sh_s = sh_zp + (stages * zp_sh_stage);
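// Shared-memory layout, in order: A tiles | B tiles | g_idx slabs |
// zero-point slabs | scale slabs, each region replicated `stages` times.
// g_idx_stage and zp_sh_stage are 0 when act_order / zero-points are
// disabled, so those regions collapse and the scales start right after B.
// The total must fit the dynamic shared-memory size passed at launch time.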
  756. // Register storage for double buffer of shared memory reads.
  757. FragA frag_a[2][thread_m_blocks];
  758. I4 frag_b_quant[2][b_thread_vecs];
  759. FragC frag_c[thread_m_blocks][4][2];
  760. FragS frag_s[2][4]; // No act-order
  761. FragS act_frag_s[2][4][4]; // For act-order
  762. int frag_qzp[2][num_ints_per_thread]; // Zero-points
  763. FragZP frag_zp; // Zero-points in fp16
  764. FragZPF frag_zpf[2][4]; // float16 zero-points
  765. // Zero accumulators.
  766. auto zero_accums = [&]() {
  767. #pragma unroll
  768. for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
  769. reinterpret_cast<float*>(frag_c)[i] = 0;
  770. };
  771. int sh_first_group_id = -1;
  772. int sh_num_groups = -1;
  773. constexpr int sh_max_num_groups = 32;
  774. auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
  775. int last_group_id) {
  776. sh_first_group_id = first_group_id;
  777. sh_num_groups = last_group_id - first_group_id + 1;
  778. if (sh_num_groups < sh_max_num_groups) {
  779. sh_num_groups = sh_max_num_groups;
  780. }
  781. if (sh_first_group_id + sh_num_groups > num_groups) {
  782. sh_num_groups = num_groups - sh_first_group_id;
  783. }
  784. int row_offset = first_group_id * s_gl_stride;
  785. if (is_async) {
  786. for (int i = 0; i < sh_num_groups; i++) {
  787. if (threadIdx.x < s_sh_stride) {
  788. cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
  789. &scales_ptr[row_offset + (i * s_gl_stride) +
  790. slice_n_offset + threadIdx.x]);
  791. }
  792. }
  793. } else {
  794. for (int i = 0; i < sh_num_groups; i++) {
  795. if (threadIdx.x < s_sh_stride) {
  796. sh_s[(i * s_sh_stride) + threadIdx.x] =
  797. scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
  798. threadIdx.x];
  799. }
  800. }
  801. }
  802. };
  803. // Asynchronously fetch the next A, B and s tile from global to the next
  804. // shared memory pipeline location.
  805. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
  806. if (pred) {
  807. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  808. #pragma unroll
  809. for (int i = 0; i < a_sh_wr_iters; i++) {
  810. cp_async4_pred(
  811. &sh_a_stage[a_sh_wr_trans[i]],
  812. &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
  813. a_sh_wr_pred[i]);
  814. }
  815. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  816. #pragma unroll
  817. for (int i = 0; i < b_sh_wr_iters; i++) {
  818. #pragma unroll
  819. for (int j = 0; j < b_thread_vecs; j++) {
  820. cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
  821. }
  822. B_ptr[i] += b_gl_rd_delta_o;
  823. }
  824. if constexpr (has_act_order) {
  825. // Fetch g_idx thread-block portion
  826. int full_pipe = a_off;
  827. int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
  828. if (cur_k < prob_k && cur_k < slice_k_finish) {
  829. int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
  830. int4 const* cur_g_idx_stage_ptr =
  831. reinterpret_cast<int4 const*>(&g_idx[cur_k]);
  832. if (threadIdx.x < g_idx_stage) {
  833. cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
  834. &cur_g_idx_stage_ptr[threadIdx.x]);
  835. }
  836. }
  837. } else {
  838. if constexpr (group_blocks != -1) {
  839. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  840. if constexpr (group_blocks >= thread_k_blocks) {
  841. // Only fetch scales if this tile starts a new group
  842. if (pipe % (group_blocks / thread_k_blocks) == 0) {
  843. if (s_sh_wr_pred) {
  844. cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
  845. }
  846. s_gl_rd += s_gl_rd_delta;
  847. }
  848. } else {
  849. for (int i = 0; i < s_tb_groups; i++) {
  850. if (s_sh_wr_pred) {
  851. cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
  852. &scales_ptr[s_gl_rd]);
  853. }
  854. s_gl_rd += s_gl_rd_delta;
  855. }
  856. }
  857. }
  858. if constexpr (has_zp && group_blocks != -1) {
  859. int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
  860. if constexpr (group_blocks >= thread_k_blocks) {
  861. // Only fetch zero-points if this tile starts a new group
  862. if (pipe % (group_blocks / thread_k_blocks) == 0) {
  863. if (zp_sh_wr_pred) {
  864. cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
  865. }
  866. zp_gl_rd += zp_gl_rd_delta;
  867. }
  868. } else {
  869. for (int i = 0; i < zp_tb_groups; i++) {
  870. if (zp_sh_wr_pred) {
  871. cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
  872. &zp_ptr[zp_gl_rd]);
  873. }
  874. zp_gl_rd += zp_gl_rd_delta;
  875. }
  876. }
  877. }
  878. }
  879. }
  880. // Insert a fence even when we are winding down the pipeline to ensure that
  881. // waiting is also correct at this point.
  882. cp_async_fence();
  883. };
  884. auto fetch_zp_to_shared = [&]() {
  885. if (zp_sh_wr_pred) {
  886. cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
  887. }
  888. };
  889. // Wait until the next thread tile has been loaded to shared memory.
  890. auto wait_for_stage = [&]() {
  891. // We only have `stages - 2` active fetches since we are double buffering
  892. // and can only issue the next fetch when it is guaranteed that the previous
  893. // shared memory load is fully complete (as it may otherwise be
  894. // overwritten).
  895. cp_async_wait<stages - 2>();
  896. __syncthreads();
  897. };
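// Note: cp_async_wait<stages - 2> waits until at most `stages - 2` cp.async
// groups remain in flight, i.e. the oldest outstanding stage has landed in
// shared memory; combined with the __syncthreads() this is what makes the
// double-buffered register loads below safe.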
  898. // Load the next sub-tile from the current location in the shared memory pipe
  899. // into the current register buffer.
  900. auto fetch_to_registers = [&](int k, int pipe) {
  901. int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  902. #pragma unroll
  903. for (int i = 0; i < thread_m_blocks; i++)
  904. ldsm4<scalar_t>(frag_a[k % 2][i],
  905. &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
  906. int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  907. #pragma unroll
  908. for (int i = 0; i < b_thread_vecs; i++) {
  909. frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
  910. &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
  911. }
  912. };
  913. bool is_same_group[stages];
  914. int same_group_id[stages];
  915. auto init_same_group = [&](int pipe) {
  916. if constexpr (!has_act_order) {
  917. is_same_group[pipe] = false;
  918. same_group_id[pipe] = 0;
  919. return;
  920. }
  921. int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
  922. int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
  923. int group_id_1 = sh_g_idx_int_ptr[0];
  924. int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
  925. is_same_group[pipe] = group_id_1 == group_id_2;
  926. same_group_id[pipe] = group_id_1;
  927. };
  928. auto fetch_scales_to_registers = [&](int k, int full_pipe) {
  929. int pipe = full_pipe % stages;
  930. if constexpr (!has_act_order) {
  931. // No act-order case
  932. if constexpr (group_blocks != -1) {
  933. if constexpr (group_blocks >= thread_k_blocks) {
  934. int4* sh_s_stage =
  935. sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
  936. (pipe / (group_blocks / thread_k_blocks)));
  937. reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
  938. } else {
  939. int warp_id = threadIdx.x / 32;
  940. int n_warps = thread_n_blocks / 4;
  941. int warp_row = warp_id / n_warps;
  942. int cur_k = warp_row * 16;
  943. cur_k += k_iter_size * (k % b_sh_wr_iters);
  944. int k_blocks = cur_k / 16;
  945. int cur_group_id = k_blocks / group_blocks;
  946. int4* sh_s_stage = sh_s + s_sh_stage * pipe;
  947. reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
  948. sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
  949. }
  950. }
  951. return;
  952. }
  953. // Act-order case
  954. // Determine K of the "current" thread-block
  955. int cur_k = slice_k_start + tb_k * full_pipe;
  956. if (cur_k >= prob_k || cur_k >= slice_k_finish) {
  957. return;
  958. }
  959. // Reset (to current thread-block) since we read g_idx portion from the
  960. // shared memory
  961. cur_k = 0;
  962. // Progress to current iteration
  963. cur_k += k_iter_size * (k % b_sh_wr_iters);
  964. // Determine "position" inside the thread-block (based on warp and
  965. // thread-id)
  966. int warp_id = threadIdx.x / 32;
  967. int n_warps =
  968. thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
  969. int warp_row = warp_id / n_warps;
  970. int warp_col = warp_id % n_warps;
  971. cur_k += warp_row * 16;
  972. int th_id = threadIdx.x % 32;
  973. cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
  974. int s_col_shift =
  975. /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
  976. (th_id / 4) * act_s_col_stride;
  977. if (is_same_group[pipe]) {
  978. if (k % 2 == 0) {
  979. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
  980. sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
  981. s_col_shift];
  982. } else {
  983. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
  984. *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
  985. }
  986. for (int i = 1; i < 4; i++) {
  987. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
  988. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
  989. }
  990. return;
  991. }
  992. int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
  993. int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
  994. constexpr int k_frag_offsets[4] = {0, 1, 8,
  995. 9}; // Tensor core offsets per thread
  996. #pragma unroll
  997. for (int i = 0; i < 4; i++) {
  998. int actual_k = cur_k + k_frag_offsets[i];
  999. int group_id = sh_g_idx_int_ptr[actual_k];
  1000. int rel_group_id = group_id - sh_first_group_id;
  1001. *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
  1002. sh_s[rel_group_id * s_sh_stride + s_col_shift];
  1003. }
  1004. };
  1005. auto fetch_zp_to_registers = [&](int k, int full_pipe) {
  1006. // This code does not handle group_blocks == 0,
  1007. // which signifies act_order.
  1008. // has_zp implies AWQ, which doesn't have act_order.
  1009. static_assert(!has_zp || group_blocks != 0);
  1010. if constexpr (has_zp && !is_zp_float) {
  1011. int pipe = full_pipe % stages;
  1012. if constexpr (group_blocks == -1) {
  1013. for (int i = 0; i < num_ints_per_thread; i++) {
  1014. frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
  1015. }
  1016. } else if constexpr (group_blocks >= thread_k_blocks) {
  1017. int4* sh_zp_stage =
  1018. sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
  1019. (pipe / (group_blocks / thread_k_blocks)));
  1020. for (int i = 0; i < num_ints_per_thread; i++) {
  1021. frag_qzp[k % 2][i] =
  1022. (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
  1023. }
  1024. } else {
  1025. int warp_id = threadIdx.x / 32;
  1026. int n_warps = thread_n_blocks / 4;
  1027. int warp_row = warp_id / n_warps;
  1028. int cur_k = warp_row * 16;
  1029. cur_k += k_iter_size * (k % b_sh_wr_iters);
  1030. int k_blocks = cur_k / 16;
  1031. int cur_group_id = 0;
  1032. // Suppress bogus and persistent divide-by-zero warning
  1033. #pragma nv_diagnostic push
  1034. #pragma nv_diag_suppress divide_by_zero
  1035. cur_group_id = k_blocks / group_blocks;
  1036. #pragma nv_diagnostic pop
  1037. int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
  1038. sh_zp_stage += cur_group_id * zp_sh_stride;
  1039. for (int i = 0; i < num_ints_per_thread; i++) {
  1040. frag_qzp[k % 2][i] =
  1041. (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
  1042. }
  1043. }
  1044. }
  1045. if constexpr (has_zp && is_zp_float) {
  1046. int pipe = full_pipe % stages;
  1047. if constexpr (group_blocks != -1) {
  1048. if constexpr (group_blocks >= thread_k_blocks) {
  1049. int4* sh_zp_stage =
  1050. sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
  1051. (pipe / (group_blocks / thread_k_blocks)));
  1052. reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
  1053. } else {
  1054. int warp_id = threadIdx.x / 32;
  1055. int n_warps = thread_n_blocks / 4;
  1056. int warp_row = warp_id / n_warps;
  1057. int cur_k = warp_row * 16;
  1058. cur_k += k_iter_size * (k % b_sh_wr_iters);
  1059. int k_blocks = cur_k / 16;
  1060. int cur_group_id = k_blocks / group_blocks;
  1061. int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
  1062. reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
  1063. sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];
  1064. }
  1065. }
  1066. }
  1067. };
  1068. // Execute the actual tensor core matmul of a sub-tile.
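// For each of the four B sub-fragments of the current sub-tile, two packed
// quantized values are dequantized, the zero-point (if any) is subtracted,
// the per-group or act-order scales are applied, and the results are
// accumulated into frag_c via mma over all thread_m_blocks.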
  1069. auto matmul = [&](int k) {
  1070. if constexpr (has_zp && !is_zp_float) {
  1071. FragB frag_zp_0;
  1072. FragB frag_zp_1;
  1073. int zp_quant_0, zp_quant_1;
  1074. if constexpr (w_type.size_bits() == 4) {
  1075. zp_quant_0 = frag_qzp[k % 2][0];
  1076. zp_quant_1 = zp_quant_0 >> 8;
  1077. } else {
  1078. static_assert(w_type.size_bits() == 8);
  1079. zp_quant_0 = frag_qzp[k % 2][0];
  1080. zp_quant_1 = frag_qzp[k % 2][1];
  1081. }
  1082. frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0);
  1083. frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1);
  1084. frag_zp[0] = frag_zp_0[0];
  1085. frag_zp[1] = frag_zp_0[1];
  1086. frag_zp[2] = frag_zp_1[0];
  1087. frag_zp[3] = frag_zp_1[1];
  1088. }
  1089. // We have the m dimension as the inner loop in order to encourage overlapping
  1090. // dequantization and matmul operations.
  1091. #pragma unroll
  1092. for (int j = 0; j < 4; j++) {
  1093. FragB frag_b0;
  1094. FragB frag_b1;
  1095. int b_quant_0, b_quant_1;
  1096. if constexpr (w_type.size_bits() == 4) {
  1097. b_quant_0 = frag_b_quant[k % 2][0][j];
  1098. b_quant_1 = b_quant_0 >> 8;
  1099. } else {
  1100. static_assert(w_type.size_bits() == 8);
  1101. int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
  1102. b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
  1103. b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
  1104. }
  1105. frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0);
  1106. frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
  1107. // Apply zero-point to frag_b0
  1108. if constexpr (has_zp && !is_zp_float) {
  1109. sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
  1110. }
  1111. if constexpr (has_zp && is_zp_float && group_blocks != -1) {
  1112. sub_zpf<scalar_t>(frag_b0, frag_zpf[k % 2][j], 0);
  1113. }
  1114. // Apply scale to frag_b0
  1115. if constexpr (has_act_order) {
  1116. scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
  1117. act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
  1118. act_frag_s[k % 2][3][j], 0);
  1119. } else {
  1120. if constexpr (group_blocks != -1) {
  1121. scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
  1122. }
  1123. }
  1124. // Apply zero-point to frag_b1
  1125. if constexpr (has_zp && !is_zp_float) {
  1126. sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
  1127. }
  1128. if constexpr (has_zp && is_zp_float && group_blocks != -1) {
  1129. sub_zpf<scalar_t>(frag_b1, frag_zpf[k % 2][j], 1);
  1130. }
  1131. // Apply scale to frag_b1
  1132. if constexpr (has_act_order) {
  1133. scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
  1134. act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
  1135. act_frag_s[k % 2][3][j], 1);
  1136. } else {
  1137. if constexpr (group_blocks != -1) {
  1138. scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);
  1139. }
  1140. }
  1141. #pragma unroll
  1142. for (int i = 0; i < thread_m_blocks; i++) {
  1143. mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
  1144. mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
  1145. }
  1146. }
  1147. };
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums for the same output
// location, which we have to reduce over in the end. We do this in shared memory.
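// red_off is the number of warp groups along k holding partial sums for the
// same output tile. Each step of the loop below halves the number of active
// groups: the upper half of the remaining groups folds staged partials into
// its registers and stages the result for the lower half, until group 0
// accumulates the final sums into its register fragments.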
  1152. auto thread_block_reduce = [&]() {
  1153. constexpr int red_off = threads / b_sh_stride_threads / 2;
  1154. if (red_off >= 1) {
  1155. int red_idx = threadIdx.x / b_sh_stride_threads;
  1156. constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
  1157. constexpr int red_sh_delta = b_sh_stride_threads;
  1158. int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
  1159. (threadIdx.x % b_sh_stride_threads);
  1160. // Parallel logarithmic shared memory reduction. We make sure to avoid any
  1161. // unnecessary read or write iterations, e.g., for two warps we write only
  1162. // once by warp 1 and read only once by warp 0.
  1163. #pragma unroll
  1164. for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  1165. #pragma unroll
  1166. for (int i = red_off; i > 0; i /= 2) {
  1167. if (i <= red_idx && red_idx < 2 * i) {
  1168. #pragma unroll
  1169. for (int j = 0; j < 4 * 2; j++) {
  1170. int red_sh_wr =
  1171. red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
  1172. if (i < red_off) {
  1173. float* c_rd =
  1174. reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
  1175. float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  1176. #pragma unroll
  1177. for (int k = 0; k < 4; k++)
  1178. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
  1179. c_rd[k] + c_wr[k];
  1180. }
  1181. sh[red_sh_wr] =
  1182. reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
  1183. }
  1184. }
  1185. __syncthreads();
  1186. }
  1187. if (red_idx == 0) {
  1188. #pragma unroll
  1189. for (int i = 0; i < 4 * 2; i++) {
  1190. float* c_rd =
  1191. reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  1192. #pragma unroll
  1193. for (int j = 0; j < 4; j++)
  1194. reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
  1195. c_rd[j];
  1196. }
  1197. }
  1198. __syncthreads();
  1199. }
  1200. }
  1201. };
  1202. // Since multiple threadblocks may process parts of the same column slice, we
  1203. // finally have to globally reduce over the results. As the striped
  1204. // partitioning minimizes the number of such reductions and our outputs are
  1205. // usually rather small, we perform this reduction serially in L2 cache.
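// `first` indicates there are no earlier partial results in C to accumulate,
// and `last` indicates this is the final threadblock for the slice: the last
// block keeps the accumulated values in registers for write_result() instead
// of spilling them back to C.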
  1206. auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
  1207. // We are very careful here to reduce directly in the output buffer to
  1208. // maximize L2 cache utilization in this step. To do this, we write out
  1209. // results in FP16 (but still reduce with FP32 compute).
  1210. constexpr int active_threads = 32 * thread_n_blocks / 4;
  1211. if (threadIdx.x < active_threads) {
  1212. int c_gl_stride = prob_n / 8;
  1213. int c_gl_wr_delta_o = 8 * c_gl_stride;
  1214. int c_gl_wr_delta_i = 4 * (active_threads / 32);
  1215. int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
  1216. 4 * (threadIdx.x / 32) + threadIdx.x % 4;
  1217. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  1218. constexpr int c_sh_wr_delta = active_threads;
  1219. int c_sh_wr = threadIdx.x;
  1220. int row = (threadIdx.x % 32) / 4;
  1221. if (!first) {
  1222. // Interestingly, doing direct global accesses here really seems to mess up
  1223. // the compiler and lead to slowdowns, hence we also use async-copies even
  1224. // though these fetches are not actually asynchronous.
  1225. #pragma unroll
  1226. for (int i = 0; i < thread_m_blocks * 4; i++) {
  1227. cp_async4_pred(
  1228. &sh[c_sh_wr + c_sh_wr_delta * i],
  1229. &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
  1230. c_gl_wr_delta_i * (i % 2)],
  1231. i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
  1232. }
  1233. cp_async_fence();
  1234. cp_async_wait<0>();
  1235. }
  1236. #pragma unroll
  1237. for (int i = 0; i < thread_m_blocks * 4; i++) {
  1238. if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
  1239. if (!first) {
  1240. int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  1241. #pragma unroll
  1242. for (int j = 0; j < 2 * 4; j++) {
  1243. reinterpret_cast<float*>(
  1244. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
  1245. Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
  1246. }
  1247. }
  1248. if (!last) {
  1249. int4 c;
  1250. #pragma unroll
  1251. for (int j = 0; j < 2 * 4; j++) {
  1252. reinterpret_cast<scalar_t*>(&c)[j] =
  1253. Dtype::float2num(reinterpret_cast<float*>(
  1254. &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
  1255. }
  1256. C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
  1257. c;
  1258. }
  1259. }
  1260. }
  1261. }
  1262. };
  1263. // Globally reduce over threadblocks that compute the same column block.
  1264. // We use a tmp C buffer to reduce in full fp32 precision.
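// Each (par_id, slice_col) pair owns a c_size chunk of the fp32 buffer C_tmp,
// so partial sums are exchanged between threadblocks at full precision and
// only converted to the output dtype in write_result().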
  1265. auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
  1266. constexpr int tb_m = thread_m_blocks * 16;
  1267. constexpr int tb_n = thread_n_blocks * 16;
  1268. constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
  1269. constexpr int active_threads = 32 * thread_n_blocks / 4;
  1270. bool is_th_active = threadIdx.x < active_threads;
  1271. int par_offset = c_size * n_tiles * par_id;
  1272. int slice_offset = c_size * slice_col;
  1273. constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
  1274. constexpr int th_size = num_floats * sizeof(float) / 16;
  1275. int c_cur_offset = par_offset + slice_offset;
  1276. if (!is_th_active) {
  1277. return;
  1278. }
  1279. if (!first) {
  1280. float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
  1281. #pragma unroll
  1282. for (int k = 0; k < th_size; k++) {
  1283. sh[threadIdx.x] =
  1284. C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
  1285. float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
  1286. #pragma unroll
  1287. for (int f = 0; f < 4; f++) {
  1288. frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
  1289. }
  1290. }
  1291. }
  1292. if (!last) {
  1293. int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
  1294. #pragma unroll
  1295. for (int k = 0; k < th_size; k++) {
  1296. C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
  1297. }
  1298. }
  1299. };
// Write out the final reduction result in the correct layout. We only actually
// reshuffle matrix fragments in this step; the reduction above is performed
// in fragment layout.
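// Fragments are first staged into shared memory with stride c_sh_stride
// (2 * thread_n_blocks + 1; the extra element of padding helps avoid bank
// conflicts), then streamed to C row by row with coalesced global writes.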
  1303. auto write_result = [&]() {
  1304. int c_gl_stride = prob_n / 8;
  1305. constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
  1306. int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
  1307. constexpr int c_sh_rd_delta =
  1308. c_sh_stride * (threads / (2 * thread_n_blocks));
  1309. int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  1310. (threadIdx.x % (2 * thread_n_blocks));
  1311. c_gl_wr += (2 * thread_n_blocks) * slice_col;
  1312. int c_sh_wr =
  1313. (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
  1314. c_sh_wr += 32 * (threadIdx.x / 32);
  1315. int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
  1316. (threadIdx.x % (2 * thread_n_blocks));
  1317. int c_gl_wr_end = c_gl_stride * prob_m;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns.
  1320. auto write = [&](int idx, float c0, float c1, FragS& s) {
  1321. scalar_t2 res =
  1322. Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
  1323. // For per-column quantization we finally apply the scale here (only for
  1324. // 4-bit)
  1325. if constexpr (!has_act_order && group_blocks == -1 &&
  1326. w_type.size_bits() == 4) {
  1327. res = __hmul2(res, s[0]);
  1328. }
  1329. ((scalar_t2*)sh)[idx] = res;
  1330. };
  1331. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  1332. #pragma unroll
  1333. for (int i = 0; i < thread_m_blocks; i++) {
  1334. #pragma unroll
  1335. for (int j = 0; j < 4; j++) {
  1336. int wr = c_sh_wr + 8 * j;
  1337. write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
  1338. frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
  1339. write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
  1340. frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
  1341. write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
  1342. frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
  1343. write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
  1344. frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
  1345. }
  1346. c_sh_wr += 16 * (4 * c_sh_stride);
  1347. }
  1348. }
  1349. __syncthreads();
  1350. #pragma unroll
  1351. for (int i = 0;
  1352. i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
  1353. i++) {
  1354. if (c_gl_wr < c_gl_wr_end) {
  1355. C[c_gl_wr] = sh[c_sh_rd];
  1356. c_gl_wr += c_gl_wr_delta;
  1357. c_sh_rd += c_sh_rd_delta;
  1358. }
  1359. }
  1360. };
  1361. // Start global fetch and register load pipelines.
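// start_pipes prefetches the first (stages - 1) shared-memory stages, zeroes
// the accumulators, waits for the first stage to arrive, and primes the
// k % 2 register double buffers for the first main-loop iteration.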
  1362. auto start_pipes = [&]() {
  1363. #pragma unroll
  1364. for (int i = 0; i < stages - 1; i++) {
  1365. if (has_act_order && i == 0) {
  1366. int last_g_idx = slice_k_start + stages * tb_k * 2;
  1367. if (last_g_idx >= prob_k) {
  1368. last_g_idx = prob_k - 1;
  1369. }
  1370. fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
  1371. }
  1372. if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
  1373. if (i == 0) {
  1374. fetch_zp_to_shared();
  1375. }
  1376. }
  1377. fetch_to_shared(i, i, i < slice_iters);
  1378. }
  1379. zero_accums();
  1380. wait_for_stage();
  1381. init_same_group(0);
  1382. fetch_to_registers(0, 0);
  1383. fetch_scales_to_registers(0, 0);
  1384. fetch_zp_to_registers(0, 0);
  1385. a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  1386. slice_k_start_shared_fetch += tb_k * (stages - 1);
  1387. };
  1388. if (slice_iters) {
  1389. start_pipes();
  1390. }
  1391. // Main loop.
  1392. while (slice_iters) {
  1393. // We unroll over both the global fetch and the register load pipeline to
  1394. // ensure all shared memory accesses are static. Note that both pipelines
  1395. // have even length meaning that the next iteration will always start at
  1396. // index 0.
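// Register loads for sub-iteration k + 1 are issued before the matmul of
// sub-iteration k so that dequantization and shared-memory reads overlap with
// tensor-core work; on the second-to-last sub-iteration of each stage, the
// oldest shared-memory slot is refilled with a new asynchronous global fetch
// before the pipeline index advances.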
  1397. #pragma unroll
  1398. for (int pipe = 0; pipe < stages;) {
  1399. #pragma unroll
  1400. for (int k = 0; k < b_sh_wr_iters; k++) {
  1401. fetch_to_registers(k + 1, pipe % stages);
  1402. fetch_scales_to_registers(k + 1, pipe);
  1403. fetch_zp_to_registers(k + 1, pipe);
  1404. if (k == b_sh_wr_iters - 2) {
  1405. fetch_to_shared((pipe + stages - 1) % stages, pipe,
  1406. slice_iters >= stages);
  1407. pipe++;
  1408. wait_for_stage();
  1409. init_same_group(pipe % stages);
  1410. }
  1411. matmul(k);
  1412. }
  1413. slice_iters--;
  1414. if (slice_iters == 0) {
  1415. break;
  1416. }
  1417. }
  1418. a_gl_rd += a_gl_rd_delta_o * stages;
  1419. slice_k_start += tb_k * stages;
  1420. slice_k_start_shared_fetch += tb_k * stages;
  1421. if constexpr (has_act_order) {
  1422. int first_group_id = g_idx[slice_k_start];
  1423. int last_g_idx = slice_k_start + stages * tb_k * 2;
  1424. if (last_g_idx >= prob_k) {
  1425. last_g_idx = prob_k - 1;
  1426. }
  1427. int last_group_id = g_idx[last_g_idx];
  1428. if (last_group_id >= sh_first_group_id + sh_num_groups) {
  1429. fetch_scales_to_shared(false, first_group_id, last_group_id);
  1430. __syncthreads();
  1431. }
  1432. }
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to result in noticeably worse performance after compilation.
  1436. if (slice_iters == 0) {
  1437. cp_async_wait<0>();
  1438. bool last = slice_idx == slice_count - 1;
  1439. // For per-column scales, we only fetch them here in the final step before
  1440. // write-out
  1441. if constexpr (!has_act_order && group_blocks == -1) {
  1442. if constexpr (w_type.size_bits() == 8) {
  1443. if (s_sh_wr_pred) {
  1444. cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
  1445. }
  1446. cp_async_fence();
  1447. } else {
  1448. if (last) {
  1449. if (s_sh_wr_pred) {
  1450. cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
  1451. }
  1452. cp_async_fence();
  1453. }
  1454. }
  1455. }
  1456. thread_block_reduce();
  1457. if constexpr (!has_act_order && group_blocks == -1) {
  1458. if constexpr (w_type.size_bits() == 8) {
  1459. cp_async_wait<0>();
  1460. __syncthreads();
  1461. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  1462. reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
  1463. reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
  1464. }
  1465. } else {
  1466. if (last) {
  1467. cp_async_wait<0>();
  1468. __syncthreads();
  1469. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  1470. reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
  1471. reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
  1472. }
  1473. }
  1474. }
  1475. }
  1476. // For 8-bit channelwise, we apply the scale before the global reduction
  1477. // that converts the fp32 results to fp16 (so that we avoid possible
  1478. // overflow in fp16)
  1479. if constexpr (!has_act_order && group_blocks == -1 &&
  1480. w_type.size_bits() == 8) {
  1481. if (threadIdx.x / 32 < thread_n_blocks / 4) {
  1482. #pragma unroll
  1483. for (int i = 0; i < thread_m_blocks; i++) {
  1484. #pragma unroll
  1485. for (int j = 0; j < 4; j++) {
  1486. scale_float<scalar_t>(
  1487. reinterpret_cast<float*>(&frag_c[i][j][0][0]),
  1488. frag_s[j / 2][2 * (j % 2) + 0]);
  1489. scale_float<scalar_t>(
  1490. reinterpret_cast<float*>(&frag_c[i][j][0][2]),
  1491. frag_s[j / 2][2 * (j % 2) + 0]);
  1492. scale_float<scalar_t>(
  1493. reinterpret_cast<float*>(&frag_c[i][j][1][0]),
  1494. frag_s[j / 2][2 * (j % 2) + 1]);
  1495. scale_float<scalar_t>(
  1496. reinterpret_cast<float*>(&frag_c[i][j][1][2]),
  1497. frag_s[j / 2][2 * (j % 2) + 1]);
  1498. }
  1499. }
  1500. }
  1501. }
  1502. if (slice_count > 1) { // only globally reduce if there is more than one
  1503. // block in a slice
  1504. barrier_acquire(&locks[slice_col], slice_idx);
  1505. if (use_fp32_reduce) {
  1506. global_reduce_fp32(slice_idx == 0, last);
  1507. } else {
  1508. global_reduce_fp16(slice_idx == 0, last);
  1509. }
  1510. barrier_release(&locks[slice_col], last);
  1511. }
  1512. if (last) // only the last block in a slice actually writes the result
  1513. write_result();
  1514. slice_row = 0;
  1515. slice_col_par++;
  1516. slice_col++;
  1517. init_slice();
  1518. if (slice_iters) {
  1519. a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
  1520. (threadIdx.x % a_gl_rd_delta_o);
  1521. #pragma unroll
  1522. for (int i = 0; i < b_sh_wr_iters; i++)
  1523. B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
  1524. if (slice_col == 0) {
  1525. #pragma unroll
  1526. for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
  1527. }
  1528. // Update slice k/n for scales loading
  1529. if constexpr (has_act_order) {
  1530. slice_k_start = tb_k * slice_row;
  1531. slice_k_finish = slice_k_start + tb_k * slice_iters;
  1532. slice_k_start_shared_fetch = slice_k_start;
  1533. slice_n_offset = act_s_col_tb_stride * slice_col;
  1534. } else {
  1535. s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  1536. zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
  1537. }
  1538. start_pipes();
  1539. }
  1540. }
  1541. }
  1542. }
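// Host-side dispatch: __CALL_IF compares the runtime configuration against one
// set of template parameters and, on a match, raises the kernel's dynamic
// shared-memory limit and launches that Marlin instantiation. The IS_ZP_FLOAT
// variants are only compiled for half precision.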
  1543. #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
  1544. HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \
  1545. IS_ZP_FLOAT) \
  1546. if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
  1547. thread_n_blocks == THREAD_N_BLOCKS && \
  1548. thread_k_blocks == THREAD_K_BLOCKS && \
  1549. has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
  1550. group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
  1551. is_zp_float == IS_ZP_FLOAT) { \
  1552. if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) { \
  1553. cudaFuncSetAttribute( \
  1554. Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
  1555. THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
  1556. HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
  1557. cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
  1558. Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
  1559. THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
  1560. HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
  1561. <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
  1562. A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
  1563. num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
  1564. } \
  1565. }
  1566. typedef struct {
  1567. int thread_k;
  1568. int thread_n;
  1569. int num_threads;
  1570. } thread_config_t;
  1571. typedef struct {
  1572. int max_m_blocks;
  1573. thread_config_t tb_cfg;
  1574. } exec_config_t;
  1575. thread_config_t small_batch_thread_configs[] = {
  1576. // Ordered by priority
  1577. // thread_k, thread_n, num_threads
  1578. {128, 128, 256},
  1579. {64, 128, 128},
  1580. {128, 64, 128},
  1581. };
  1582. thread_config_t large_batch_thread_configs[] = {
  1583. // Ordered by priority
  1584. // thread_k, thread_n, num_threads
  1585. {64, 256, 256},
  1586. {64, 128, 128},
  1587. {128, 64, 128},
  1588. };
  1589. int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
  1590. int prob_n, int prob_k, int num_bits, int group_size,
  1591. bool has_act_order, bool is_k_full) {
  1592. bool cache_scales_chunk = has_act_order && !is_k_full;
  1593. int tb_n = th_config.thread_n;
  1594. int tb_k = th_config.thread_k;
  1595. // Get max scale groups per thread-block
  1596. int tb_groups;
  1597. if (group_size == -1) {
  1598. tb_groups = 1;
  1599. } else if (group_size == 0) {
  1600. tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
  1601. } else {
  1602. tb_groups = div_ceil(tb_k, group_size);
  1603. }
  1604. if (cache_scales_chunk) {
  1605. int load_groups =
  1606. tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
  1607. load_groups = max(load_groups, 32); // We load at least 32 scale groups
  1608. return load_groups * tb_n * 2;
  1609. } else {
  1610. int tb_scales = tb_groups * tb_n * 2;
  1611. return tb_scales * pipe_stages;
  1612. }
  1613. }
  1614. bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
  1615. int prob_m, int prob_n, int prob_k, int num_bits,
  1616. int scales_cache_size, int max_shared_mem) {
  1617. int pack_factor = 32 / num_bits;
  1618. // Get B size
  1619. int tb_k = th_config.thread_k;
  1620. int tb_n = th_config.thread_n;
  1621. int b_size = (tb_k * tb_n / pack_factor) * 4;
  1622. // Get A size
  1623. int m_blocks = div_ceil(prob_m, 16);
  1624. int tb_max_m = 16;
  1625. while (true) {
  1626. if (m_blocks >= max_m_blocks) {
  1627. tb_max_m *= max_m_blocks;
  1628. break;
  1629. }
  1630. max_m_blocks--;
  1631. if (max_m_blocks == 0) {
  1632. TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
  1633. }
  1634. }
  1635. int a_size = (tb_max_m * tb_k) * 2;
  1636. float pipe_size = (a_size + b_size) * pipe_stages;
  1637. TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
  1638. return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
  1639. }
  1640. bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
  1641. int prob_m, int prob_n, int prob_k, int num_bits,
  1642. int group_size, bool has_act_order, bool is_k_full,
  1643. int max_shared_mem) {
  1644. // Sanity
  1645. if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
  1646. th_config.num_threads == -1) {
  1647. return false;
  1648. }
  1649. // Verify K/N are divisible by thread K/N
  1650. if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
  1651. return false;
  1652. }
  1653. // Verify min for thread K/N
  1654. if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
  1655. return false;
  1656. }
  1657. // num_threads must be at least 128 (= 4 warps)
  1658. if (th_config.num_threads < 128) {
  1659. return false;
  1660. }
  1661. // Determine cache for scales
  1662. int scales_cache_size =
  1663. get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
  1664. group_size, has_act_order, is_k_full);
  1665. // Check that pipeline fits into cache
  1666. if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
  1667. num_bits, scales_cache_size, max_shared_mem)) {
  1668. return false;
  1669. }
  1670. return true;
  1671. }
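// Height of the fp32 reduction buffer: prob_m rounded up to whole 16-row
// tiles, and for larger M to 64 * cur_par rows, where cur_par is capped by
// max_par (e.g. prob_m = 100 gives 64 * 2 = 128 rows when max_par >= 2).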
  1672. int determine_reduce_max_m(int prob_m, int max_par) {
  1673. constexpr int tile_m_size = 16;
  1674. if (prob_m <= tile_m_size) {
  1675. return tile_m_size;
  1676. } else if (prob_m <= tile_m_size * 2) {
  1677. return tile_m_size * 2;
  1678. } else if (prob_m <= tile_m_size * 3) {
  1679. return tile_m_size * 3;
  1680. } else if (prob_m <= tile_m_size * 4) {
  1681. return tile_m_size * 4;
  1682. } else {
  1683. int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
  1684. return tile_m_size * 4 * cur_par;
  1685. }
  1686. }
  1687. exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
  1688. int num_bits, int group_size,
  1689. bool has_act_order, bool is_k_full,
  1690. int max_shared_mem) {
  1691. int max_m_blocks = 4;
  1692. while (max_m_blocks > 0) {
  1693. if (prob_m <= 16) {
  1694. for (auto th_config : small_batch_thread_configs) {
  1695. if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
  1696. num_bits, group_size, has_act_order, is_k_full,
  1697. max_shared_mem)) {
  1698. return exec_config_t{max_m_blocks, th_config};
  1699. }
  1700. }
  1701. } else {
  1702. for (auto th_config : large_batch_thread_configs) {
  1703. if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
  1704. num_bits, group_size, has_act_order, is_k_full,
  1705. max_shared_mem)) {
  1706. return exec_config_t{max_m_blocks, th_config};
  1707. }
  1708. }
  1709. }
max_m_blocks--; // Process fewer M blocks per invocation to reduce cache usage
  1712. }
  1713. return exec_config_t{0, {-1, -1, -1}};
  1714. }
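// GROUP_BLOCKS encodes the scale granularity: -1 = per-column (channelwise),
// 0 = act_order (groups resolved through g_idx), and 2 / 4 / 8 = group sizes
// of 32 / 64 / 128 (group_blocks == group_size / 16). For example, a GPTQ
// model with group_size = 128 and no act_order dispatches with GROUP_BLOCKS == 8.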
  1715. #define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
  1716. __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
  1717. false) \
  1718. __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
  1719. false) \
  1720. __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
  1721. false) \
  1722. __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
  1723. false) \
  1724. \
  1725. __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
  1726. false) \
  1727. __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
  1728. false) \
  1729. __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
  1730. false) \
  1731. __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
  1732. false) \
  1733. \
  1734. __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
  1735. false) \
  1736. __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
  1737. false) \
  1738. __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
  1739. false) \
  1740. __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
  1741. false) \
  1742. \
  1743. __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
  1744. false) \
  1745. __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
  1746. false) \
  1747. __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
  1748. false) \
  1749. __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
  1750. false) \
  1751. \
  1752. __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
  1753. false) \
  1754. __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
  1755. false) \
  1756. __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
  1757. false) \
  1758. __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
  1759. false)
  1760. #define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
  1761. __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
  1762. false) \
  1763. __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
  1764. false) \
  1765. __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
  1766. false) \
  1767. __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
  1768. false) \
  1769. \
  1770. __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
  1771. false) \
  1772. __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
  1773. false) \
  1774. __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
  1775. false) \
  1776. __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
  1777. false) \
  1778. \
  1779. __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
  1780. false) \
  1781. __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
  1782. false) \
  1783. __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
  1784. false) \
  1785. __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
  1786. false) \
  1787. \
  1788. __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
  1789. false) \
  1790. __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
  1791. false) \
  1792. __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
  1793. false) \
  1794. __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false)
// We currently only have 4-bit models with group_blocks == 4 (i.e. group_size == 64)
  1796. #define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
  1797. __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
  1798. true) \
  1799. __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
  1800. true) \
  1801. __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
  1802. true) \
  1803. __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true)
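// marlin_mm: host-side launcher. It validates the quantization type, picks or
// accepts a threadblock configuration, derives group_blocks from group_size,
// optionally permutes A for act_order, and then dispatches the matching kernel
// instantiation for chunks of up to max_m_blocks * 16 rows of A at a time.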
  1804. template <typename scalar_t>
  1805. void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
  1806. void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
  1807. int prob_n, int prob_k, void* workspace,
  1808. aphrodite::ScalarType const& q_type, bool has_act_order,
  1809. bool is_k_full, bool has_zp, int num_groups, int group_size,
  1810. int dev, cudaStream_t stream, int thread_k, int thread_n,
  1811. int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
  1812. if (has_zp) {
  1813. TORCH_CHECK(
  1814. q_type == aphrodite::kU4 || q_type == aphrodite::kU8,
  1815. "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
  1816. } else {
  1817. TORCH_CHECK(
  1818. q_type == aphrodite::kU4B8 || q_type == aphrodite::kU8B128,
  1819. "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
  1820. q_type.str());
  1821. }
  1822. TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
  1823. ", ", prob_n, ", ", prob_k, "]");
  1824. // TODO: remove alias when we start supporting other 8bit types
  1825. int num_bits = q_type.size_bits();
  1826. int tot_m = prob_m;
  1827. int tot_m_blocks = div_ceil(tot_m, 16);
  1828. int pad = 16 * tot_m_blocks - tot_m;
  1829. if (sms == -1) {
  1830. cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  1831. }
  1832. int max_shared_mem = 0;
  1833. cudaDeviceGetAttribute(&max_shared_mem,
  1834. cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  1835. TORCH_CHECK(max_shared_mem > 0);
  1836. // Set thread config
  1837. exec_config_t exec_cfg;
  1838. if (thread_k != -1 && thread_n != -1) {
  1839. // User-defined config
  1840. exec_cfg =
  1841. exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
  1842. } else {
  1843. // Auto config
  1844. exec_cfg =
  1845. determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
  1846. has_act_order, is_k_full, max_shared_mem);
  1847. }
  1848. TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
  1849. is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
  1850. prob_m, prob_n, prob_k, num_bits, group_size,
  1851. has_act_order, is_k_full, max_shared_mem),
  1852. "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
  1853. ", thread_k = ", exec_cfg.tb_cfg.thread_k,
  1854. ", thread_n = ", exec_cfg.tb_cfg.thread_n,
  1855. ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
  1856. prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
  1857. ", group_size = ", group_size,
  1858. ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
  1859. ", max_shared_mem = ", max_shared_mem);
  1860. int num_threads = exec_cfg.tb_cfg.num_threads;
  1861. thread_k = exec_cfg.tb_cfg.thread_k;
  1862. thread_n = exec_cfg.tb_cfg.thread_n;
  1863. int thread_k_blocks = thread_k / 16;
  1864. int thread_n_blocks = thread_n / 16;
  1865. int blocks = sms;
  1866. TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
  1867. " is not divisible by thread_n = ", thread_n);
  1868. TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
  1869. " is not divisible by thread_k = ", thread_k);
  1870. int group_blocks = 0;
  1871. if (has_act_order) {
  1872. if (is_k_full) {
  1873. TORCH_CHECK(group_size != -1);
  1874. group_blocks = group_size / 16;
  1875. TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
  1876. " is not divisible by group_blocks = ", group_blocks);
  1877. } else {
  1878. TORCH_CHECK(group_size == 0);
  1879. group_blocks = 0;
  1880. }
  1881. } else {
  1882. if (group_size == -1) {
  1883. group_blocks = -1;
  1884. } else {
  1885. group_blocks = group_size / 16;
  1886. TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
  1887. " is not divisible by group_blocks = ", group_blocks);
  1888. }
  1889. }
  1890. const int4* A_ptr = (const int4*)A;
  1891. const int4* B_ptr = (const int4*)B;
  1892. int4* C_ptr = (int4*)C;
  1893. int4* C_tmp_ptr = (int4*)C_tmp;
  1894. const int4* s_ptr = (const int4*)s;
  1895. const int4* zp_ptr = (const int4*)zp;
  1896. const int* g_idx_ptr = (const int*)g_idx;
  1897. const int* perm_ptr = (const int*)perm;
  1898. int4* a_tmp_ptr = (int4*)a_tmp;
  1899. int* locks = (int*)workspace;
  1900. if (has_act_order) {
  1901. // Permute A columns
  1902. int block_rows = div_ceil(prob_m, blocks);
  1903. permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
  1904. A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
  1905. A_ptr = a_tmp_ptr;
  1906. }
  1907. // If we have a full K, then we can run the non-act-order version of Marlin
  1908. // (since the weight rows are reordered by increasing group ids, and by having
  1909. // a full K, we have full original groups)
  1910. if (is_k_full) {
  1911. has_act_order = false;
  1912. }
  1913. // Main loop
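// Each iteration handles up to exec_cfg.max_m_blocks 16-row tiles of A. When
// the input has no padding, `par` batches several such chunks into a single
// launch (bounded by max_par), and the pointers are advanced accordingly below.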
  1914. for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
  1915. int thread_m_blocks = tot_m_blocks - i;
  1916. prob_m = tot_m - 16 * i;
  1917. int par = 1;
  1918. if (thread_m_blocks > exec_cfg.max_m_blocks) {
  1919. // Note that parallel > 1 currently only works for inputs without any
  1920. // padding
  1921. par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
  1922. if (par > max_par) par = max_par;
  1923. prob_m = (16 * exec_cfg.max_m_blocks) * par;
  1924. i += exec_cfg.max_m_blocks * (par - 1);
  1925. thread_m_blocks = exec_cfg.max_m_blocks;
  1926. }
  1927. if (false) {
  1928. }
  1929. GPTQ_CALL_IF(aphrodite::kU4B8, 16, 4, 256)
  1930. GPTQ_CALL_IF(aphrodite::kU4B8, 8, 8, 256)
  1931. GPTQ_CALL_IF(aphrodite::kU4B8, 8, 4, 128)
  1932. GPTQ_CALL_IF(aphrodite::kU4B8, 4, 8, 128)
  1933. GPTQ_CALL_IF(aphrodite::kU8B128, 16, 4, 256)
  1934. GPTQ_CALL_IF(aphrodite::kU8B128, 8, 8, 256)
  1935. GPTQ_CALL_IF(aphrodite::kU8B128, 8, 4, 128)
  1936. GPTQ_CALL_IF(aphrodite::kU8B128, 4, 8, 128)
  1937. AWQ_CALL_IF(aphrodite::kU4, 16, 4, 256)
  1938. AWQ_CALL_IF(aphrodite::kU4, 8, 8, 256)
  1939. AWQ_CALL_IF(aphrodite::kU4, 8, 4, 128)
  1940. AWQ_CALL_IF(aphrodite::kU4, 4, 8, 128)
  1941. AWQ_CALL_IF(aphrodite::kU8, 16, 4, 256)
  1942. AWQ_CALL_IF(aphrodite::kU8, 8, 8, 256)
  1943. AWQ_CALL_IF(aphrodite::kU8, 8, 4, 128)
  1944. AWQ_CALL_IF(aphrodite::kU8, 4, 8, 128)
  1945. HQQ_CALL_IF(aphrodite::kU4, 16, 4, 256)
  1946. HQQ_CALL_IF(aphrodite::kU4, 8, 8, 256)
  1947. HQQ_CALL_IF(aphrodite::kU4, 8, 4, 128)
  1948. HQQ_CALL_IF(aphrodite::kU4, 4, 8, 128)
  1949. HQQ_CALL_IF(aphrodite::kU8, 16, 4, 256)
  1950. HQQ_CALL_IF(aphrodite::kU8, 8, 8, 256)
  1951. HQQ_CALL_IF(aphrodite::kU8, 8, 4, 128)
  1952. HQQ_CALL_IF(aphrodite::kU8, 4, 8, 128)
  1953. A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
  1954. C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  1955. }
  1956. }
  1957. } // namespace marlin
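// PyTorch entry point: validates shapes, contiguity and quantization metadata,
// allocates the output and the temporary fp32 reduction buffer, detects
// group_size / act_order from b_scales and g_idx, and forwards everything to
// marlin::marlin_mm for half or bfloat16 inputs.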
  1958. torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  1959. torch::Tensor& b_scales, torch::Tensor& b_zeros,
  1960. torch::Tensor& g_idx, torch::Tensor& perm,
  1961. torch::Tensor& workspace,
  1962. aphrodite::ScalarTypeTorchPtr const& b_q_type,
  1963. int64_t size_m, int64_t size_n, int64_t size_k,
  1964. bool is_k_full, bool has_zp,
  1965. bool use_fp32_reduce, bool is_zp_float) {
  1966. if (has_zp) {
  1967. TORCH_CHECK(*b_q_type == aphrodite::kU4 || *b_q_type == aphrodite::kU8,
  1968. "b_q_type must be u4 or u8 when has_zp = True. Got = ",
  1969. b_q_type->str());
  1970. } else {
  1971. TORCH_CHECK(
  1972. *b_q_type == aphrodite::kU4B8 || *b_q_type == aphrodite::kU8B128,
  1973. "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
  1974. b_q_type->str());
  1975. }
  1976. if (has_zp && is_zp_float) {
  1977. TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
  1978. "Computation type must be float16 (half) when using float zero "
  1979. "points.");
  1980. }
  1981. int pack_factor = 32 / b_q_type->size_bits();
  1982. // Verify A
  1983. TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
  1984. ", size_m = ", size_m);
  1985. TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
  1986. ", size_k = ", size_k);
  1987. // Verify B
  1988. TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
  1989. " is not divisible by tile_size = ", marlin::tile_size);
  1990. TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
  1991. "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
  1992. ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
  1993. TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
  1994. "b_q_weight.size(1) = ", b_q_weight.size(1),
  1995. " is not divisible by tile_size = ", marlin::tile_size);
  1996. int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
  1997. TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
  1998. ", actual_size_n = ", actual_size_n);
  1999. // Verify device and strides
  2000. TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  2001. TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
  2002. TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  2003. TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  2004. TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  2005. TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
  2006. TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
  2007. TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
  2008. TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
  2009. TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
  2010. TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  2011. TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
  2012. // Alloc buffers
  2013. const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  2014. auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  2015. torch::Tensor c = torch::empty({size_m, size_n}, options);
  2016. torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
  2017. // Alloc C tmp buffer that is going to be used for the global reduce
  2018. int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
  2019. int reduce_n = size_n;
  2020. auto options_fp32 =
  2021. torch::TensorOptions().dtype(at::kFloat).device(a.device());
  2022. if (!use_fp32_reduce) {
  2023. reduce_max_m = 0;
  2024. reduce_n = 0;
  2025. }
  2026. torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
  2027. // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  2028. // auto -1)
  2029. int thread_k = -1;
  2030. // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  2031. // auto -1)
  2032. int thread_n = -1;
  2033. // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  2034. int sms = -1;
  2035. // Verify g_idx and perm
  2036. TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
  2037. (g_idx.size(0) == size_k && perm.size(0) == size_k),
  2038. "Unexpected g_idx.size(0) = ", g_idx.size(0),
  2039. " and perm.size(0) = ", perm.size(0),
  2040. ", where size_k = ", size_k);
  2041. // Detect groupsize and act_order
  2042. int num_groups = -1;
  2043. int group_size = -1;
  2044. bool has_act_order = g_idx.size(0) != 0;
  2045. int rank = b_scales.sizes().size();
  2046. TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
  2047. TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
  2048. " is not size_n = ", size_n);
  2049. num_groups = b_scales.size(0);
  2050. if (has_act_order) {
  2051. if (is_k_full) {
  2052. TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
  2053. TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
  2054. ", is not divisible by num_groups = ", num_groups);
  2055. group_size = size_k / num_groups;
  2056. } else {
  2057. group_size = 0;
  2058. }
  2059. } else {
  2060. if (num_groups > 1) {
  2061. TORCH_CHECK(
  2062. size_k % num_groups == 0, "size_k = ", size_k,
  2063. ", is not divisible by b_scales.size(0) = ", b_scales.size(0));
  2064. group_size = size_k / num_groups;
  2065. } else {
  2066. group_size = -1;
  2067. }
  2068. }
  2069. // Verify b_zeros
  2070. if (has_zp) {
  2071. int rank = b_zeros.sizes().size();
  2072. TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
  2073. if (is_zp_float) {
  2074. TORCH_CHECK(b_zeros.size(1) == size_n,
  2075. "b_zeros dim 1 = ", b_zeros.size(1),
  2076. " is not size_n = ", size_n);
  2077. TORCH_CHECK(num_groups == b_zeros.size(0),
  2078. "b_zeros dim 0 = ", b_zeros.size(0),
  2079. " is not num_groups = ", num_groups);
  2080. TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
  2081. } else {
  2082. TORCH_CHECK(b_zeros.size(0) == num_groups,
  2083. "b_zeros dim 0 = ", b_zeros.size(0),
  2084. " is not num_groups = ", num_groups);
  2085. TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
  2086. "b_zeros dim 1 = ", b_zeros.size(1),
  2087. " is not size_n / pack_factor = ", size_n / pack_factor);
  2088. }
  2089. }
  2090. // Verify workspace size
  2091. TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
  2092. ", is not divisible by min_thread_n = ", marlin::min_thread_n);
  2093. int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
  2094. TORCH_CHECK(workspace.numel() >= min_workspace_size,
  2095. "workspace.numel = ", workspace.numel(),
  2096. " is below min_workspace_size = ", min_workspace_size);
  2097. int dev = a.get_device();
  2098. if (a.scalar_type() == at::ScalarType::Half) {
  2099. marlin::marlin_mm<half>(
  2100. a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
  2101. c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
  2102. b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
  2103. a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
  2104. workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
  2105. num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
  2106. thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
  2107. } else if (a.scalar_type() == at::ScalarType::BFloat16) {
  2108. marlin::marlin_mm<nv_bfloat16>(
  2109. a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
  2110. c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
  2111. b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
  2112. perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
  2113. workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
  2114. num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
  2115. thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
  2116. } else {
TORCH_CHECK(false, "gptq_marlin_gemm only supports bfloat16 and float16");
  2118. }
  2119. return c;
  2120. }
  2121. #endif