q_gemm.cu 82 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668
  1. /*
  2. Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa
  3. */
  4. #include <cstdint>
  5. #include <cstdio>
  6. #include <torch/extension.h>
  7. #include <c10/cuda/CUDAGuard.h>
  8. #include <ATen/cuda/CUDAContext.h>
  9. #include <cuda_runtime.h>
  10. #include <cuda_fp16.h>
  11. #include "compat.cuh"
  12. #include "matrix_view.cuh"
  13. #include "qdq_2.cuh"
  14. #include "qdq_3.cuh"
  15. #include "qdq_4.cuh"
  16. #include "qdq_8.cuh"
  17. namespace aphrodite {
  18. namespace gptq {
  // Tuning constants. They stay preprocessor macros because BLOCK_M_SIZE_MAX is
  // tested with #if when the kernel-selection table is generated.

  // Number of k elements handled per block (blockIdx.z stride) and size of the
  // per-row shared-memory tile; each block also spans BLOCK_KN_SIZE * 4 columns.
  #define BLOCK_KN_SIZE         128
  // Largest m_count (rows per block) for which kernels are instantiated.
  #define BLOCK_M_SIZE_MAX      8
  // BLOCK_KN_SIZE / 32; presumably assumes a minimum group size of 32 - TODO confirm.
  #define MAX_GROUPS_IN_BLOCK   (BLOCK_KN_SIZE / 32)
  // Row-count thresholds; their use sites are outside the code shown here.
  // NOTE(review): presumably they select between GEMM code paths - confirm.
  #define MAX_Q_GEMM_ROWS       50
  #define MAX_Q_GEMM_ROWS_8BIT  24
  #define MAX_ALT_GEMM_ROWS     8
  // Thread-block dimensions; use sites are outside the code shown here.
  #define THREADS_X             32
  #define THREADS_Y             32
  // Integer division rounding up (for non-negative x and positive size).
  #define DIVIDE(x, size)       (((x) + (size) - 1) / (size))
  #if defined(USE_ROCM)
  #include <hipblas/hipblas.h>

  // ROCm adapter: forwards to hipblasHgemm, reinterpreting the `half` pointers as
  // `hipblasHalf` pointers. All other arguments are passed through unchanged and
  // the hipBLAS status is returned as-is.
  //
  // This function must be defined BEFORE the `hipblasHgemm` macro below so that
  // the call inside it still resolves to the real hipBLAS entry point.
  __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(
      hipblasHandle_t handle,
      hipblasOperation_t transA,
      hipblasOperation_t transB,
      int m,
      int n,
      int k,
      const half* alpha,
      const half* AP,
      int lda,
      const half* BP,
      int ldb,
      const half* beta,
      half* CP,
      int ldc)
  {
      const hipblasHalf* hip_alpha = reinterpret_cast<const hipblasHalf*>(alpha);
      const hipblasHalf* hip_a     = reinterpret_cast<const hipblasHalf*>(AP);
      const hipblasHalf* hip_b     = reinterpret_cast<const hipblasHalf*>(BP);
      const hipblasHalf* hip_beta  = reinterpret_cast<const hipblasHalf*>(beta);
      hipblasHalf*       hip_c     = reinterpret_cast<hipblasHalf*>(CP);

      return hipblasHgemm(handle, transA, transB, m, n, k,
                          hip_alpha,
                          hip_a, lda,
                          hip_b, ldb,
                          hip_beta,
                          hip_c, ldc);
  }

  // From here on, calls spelled `hipblasHgemm` go through the adapter above.
  #define hipblasHgemm __compat_hipblasHgemm

  // Previous versions of PyTorch were converting to rocBLAS instead of hipBLAS.
  #define rocblas_operation_none HIPBLAS_OP_N
  #define rocblas_hgemm __compat_hipblasHgemm
  #endif
  56. __forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
  57. {
  58. half2 result = {};
  59. const half2* a2_ptr = (const half2*)a_ptr;
  60. #pragma unroll
  61. for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  62. return __hadd2(result, g_result);
  63. }
  64. __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
  65. {
  66. half2 result = {};
  67. const half2* a2_ptr = (const half2*)a_ptr;
  68. #pragma unroll
  69. for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  70. return __half2float(__low2half(result)) + __half2float(__high2half(result));
  71. }
  72. __forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
  73. {
  74. half2 result = {};
  75. const half2* a2_ptr = (const half2*)a_ptr;
  76. #pragma unroll
  77. for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  78. return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
  79. }
  80. __forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
  81. {
  82. half2 result = {};
  83. const half2* a2_ptr = (const half2*)a_ptr;
  84. #pragma unroll
  85. for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  86. return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
  87. }
  88. __forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
  89. {
  90. half2 result = {};
  91. const half2* a2_ptr = (const half2*)a_ptr;
  92. #pragma unroll
  93. for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
  94. return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
  95. }
  96. __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
  97. {
  98. half2 result = {};
  99. const half2* a2_ptr = (const half2*)a_ptr;
  100. #pragma unroll
  101. for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  102. float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
  103. return fma(result_f, qs_f, g_result);
  104. }
  105. __forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
  106. {
  107. half2 result = {};
  108. const half2* a2_ptr = (const half2*)a_ptr;
  109. #pragma unroll
  110. for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  111. float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
  112. return fma(result_f, qs_f, g_result);
  113. }
  114. __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
  115. {
  116. half2 result = {};
  117. const half2* a2_ptr = (const half2*)a_ptr;
  118. #pragma unroll
  119. for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
  120. float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
  121. return fma(result_f, qs_f, g_result);
  122. }
  123. __forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
  124. {
  125. // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
  126. float result = {};
  127. #pragma unroll
  128. for (int i = 0; i < 4; i++)
  129. {
  130. half2 w01 = dq[i];
  131. float w0 = __low2float(w01);
  132. float w1 = __high2float(w01);
  133. float x0 = __half2float(*a_ptr++);
  134. float x1 = __half2float(*a_ptr++);
  135. result = fma(w0, x0, result);
  136. result = fma(w1, x1, result);
  137. }
  138. float qs = __half2float(qs_h);
  139. result *= qs;
  140. half result_h = __float2half_rn(result);
  141. return __hadd(result_h, g_result);
  142. }
  143. __forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
  144. {
  145. half2 result = {};
  146. const half2* a2_ptr = (const half2*)a_ptr;
  147. #pragma unroll
  148. for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
  149. half result_h = __hadd(__low2half(result), __high2half(result));
  150. return __hfma(result_h, qs_h, g_result);
  151. }
  152. __forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
  153. {
  154. half2 result = {};
  155. const half2* a2_ptr = (const half2*)a_ptr;
  156. #pragma unroll
  157. for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
  158. half result_h = __hadd(__low2half(result), __high2half(result));
  159. return __hfma(result_h, qs_h, g_result);
  160. }
  161. typedef void (*fp_gemm_half_q_half_gptq_kernel)
  162. (
  163. const half*,
  164. const uint32_t*,
  165. const uint32_t*,
  166. const half*,
  167. half*,
  168. const int,
  169. const int,
  170. const int,
  171. const int,
  172. const int*
  173. );
  174. template <bool first_block, int m_count>
  175. __global__ void gemm_half_q_half_gptq_4bit_kernel
  176. (
  177. const half* __restrict__ a,
  178. const uint32_t* __restrict__ b_q_weight,
  179. const uint32_t* __restrict__ b_gptq_qzeros,
  180. const half* __restrict__ b_gptq_scales,
  181. half* __restrict__ c,
  182. const int size_m,
  183. const int size_n,
  184. const int size_k,
  185. const int groups,
  186. const int* __restrict__ b_q_perm
  187. )
  188. {
  189. MatrixView_half a_(a, size_m, size_k);
  190. MatrixView_half_rw c_(c, size_m, size_n);
  191. MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  192. MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
  193. int t = threadIdx.x;
  194. // Block
  195. int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
  196. int offset_m = blockIdx.y * m_count;
  197. int offset_k = blockIdx.z * BLOCK_KN_SIZE;
  198. int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
  199. int end_m = min(offset_m + m_count, size_m);
  200. int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
  201. int n = offset_n + t * 4;
  202. // Preload block_a
  203. __shared__ half block_a[m_count][BLOCK_KN_SIZE];
  204. if (offset_k + t < end_k)
  205. {
  206. for (int m = 0; m < m_count; ++m)
  207. {
  208. const half* a_ptr = a_.item_ptr(offset_m + m, 0);
  209. half* block_a_ptr = block_a[m];
  210. half a0;
  211. if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
  212. else a0 = a_ptr[offset_k + t];
  213. block_a_ptr[t] = a0;
  214. }
  215. }
  216. // Zero output
  217. if (n >= size_n) return;
  218. if (blockIdx.z == 0)
  219. {
  220. for (int m = 0; m < m_count; m++)
  221. *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
  222. }
  223. __syncthreads();
  224. // Find initial group
  225. int groupsize = size_k / groups;
  226. int group = offset_k / groupsize;
  227. int nextgroup = offset_k + groupsize;
  228. // a, b offset
  229. int qk = offset_k / (32 / 4);
  230. const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
  231. const half* a_ptr = &block_a[0][0];
  232. int a_stride = BLOCK_KN_SIZE;
  233. // Initial group
  234. int zeros[4];
  235. float scales[4];
  236. half2 z1z16[4][2];
  237. half2 y1y16[4][2];
  238. b_gptq_qzeros_.item4(zeros, group, n);
  239. b_gptq_scales_.item4_f(scales, group, n);
  240. dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
  241. dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
  242. dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
  243. dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
  244. // Column result
  245. float block_c[m_count][4] = {};
  246. // Dequantize and multiply
  247. int k = offset_k;
  248. while (k < end_k)
  249. {
  250. if (k == nextgroup)
  251. {
  252. group++;
  253. nextgroup += groupsize;
  254. b_gptq_qzeros_.item4(zeros, group, n);
  255. b_gptq_scales_.item4_f(scales, group, n);
  256. dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
  257. dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
  258. dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
  259. dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
  260. }
  261. #pragma unroll
  262. for (int j = 0; j < 4; j++)
  263. {
  264. const int4* b_ptr4 = (int4*) b_ptr;
  265. int4 load_int4 = *b_ptr4;
  266. half2 dq[4][4];
  267. dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
  268. dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
  269. dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
  270. dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
  271. #pragma unroll
  272. for (int m = 0; m < m_count; m++)
  273. {
  274. block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
  275. block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
  276. block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
  277. block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
  278. }
  279. b_ptr += size_n;
  280. a_ptr += 8;
  281. }
  282. k += 32;
  283. }
  284. for (int m = 0; m < m_count; m++)
  285. {
  286. half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
  287. half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
  288. half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
  289. atomicAdd(out , result01);
  290. atomicAdd(out + 1, result23);
  291. }
  292. }
  293. template <bool first_block, int m_count>
  294. __global__ void gemm_half_q_half_gptq_2bit_kernel
  295. (
  296. const half* __restrict__ a,
  297. const uint32_t* __restrict__ b_q_weight,
  298. const uint32_t* __restrict__ b_gptq_qzeros,
  299. const half* __restrict__ b_gptq_scales,
  300. half* __restrict__ c,
  301. const int size_m,
  302. const int size_n,
  303. const int size_k,
  304. const int groups,
  305. const int* __restrict__ b_q_perm
  306. )
  307. {
  308. MatrixView_half a_(a, size_m, size_k);
  309. MatrixView_half_rw c_(c, size_m, size_n);
  310. MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  311. MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
  312. int t = threadIdx.x;
  313. // Block
  314. int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
  315. int offset_m = blockIdx.y * m_count;
  316. int offset_k = blockIdx.z * BLOCK_KN_SIZE;
  317. int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
  318. int end_m = min(offset_m + m_count, size_m);
  319. int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
  320. int n = offset_n + t * 4;
  321. // Preload block_a
  322. __shared__ half block_a[m_count][BLOCK_KN_SIZE];
  323. if (offset_k + t < end_k)
  324. {
  325. for (int m = 0; m < m_count; ++m)
  326. {
  327. const half* a_ptr = a_.item_ptr(offset_m + m, 0);
  328. half* block_a_ptr = block_a[m];
  329. half a0;
  330. if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
  331. else a0 = a_ptr[offset_k + t];
  332. block_a_ptr[t] = a0;
  333. }
  334. }
  335. // Zero output
  336. if (n >= size_n) return;
  337. if (blockIdx.z == 0)
  338. {
  339. for (int m = 0; m < m_count; m++)
  340. *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
  341. }
  342. __syncthreads();
  343. // Find initial group
  344. int groupsize = size_k / groups;
  345. int group = offset_k / groupsize;
  346. int nextgroup = offset_k + groupsize;
  347. // a, b offset
  348. int qk = offset_k / (32 / 2);
  349. const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
  350. const half* a_ptr = &block_a[0][0];
  351. int a_stride = BLOCK_KN_SIZE;
  352. // Initial group
  353. int zeros[4];
  354. half scales[4];
  355. b_gptq_qzeros_.item4(zeros, group, n);
  356. b_gptq_scales_.item4(scales, group, n);
  357. // Column result
  358. half block_c[m_count][4] = {};
  359. // Dequantize and multiply
  360. int k = offset_k;
  361. while (k < end_k)
  362. {
  363. if (k == nextgroup)
  364. {
  365. group++;
  366. nextgroup += groupsize;
  367. b_gptq_qzeros_.item4(zeros, group, n);
  368. b_gptq_scales_.item4(scales, group, n);
  369. }
  370. #pragma unroll
  371. for (int j = 0; j < 1; j++)
  372. {
  373. const int4* b_ptr4 = (int4*) b_ptr;
  374. int4 load_int4 = *b_ptr4;
  375. half2 dq[4][8];
  376. dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
  377. dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
  378. dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
  379. dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
  380. #pragma unroll
  381. for (int m = 0; m < m_count; m++)
  382. {
  383. block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
  384. block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
  385. block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
  386. block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
  387. }
  388. b_ptr += size_n;
  389. a_ptr += 16;
  390. }
  391. k += 16;
  392. }
  393. for (int m = 0; m < m_count; m++)
  394. {
  395. half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
  396. half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
  397. half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
  398. atomicAdd(out , result01);
  399. atomicAdd(out + 1, result23);
  400. }
  401. }
  402. template <bool first_block, int m_count>
  403. __global__ void gemm_half_q_half_gptq_3bit_kernel
  404. (
  405. const half* __restrict__ a,
  406. const uint32_t* __restrict__ b_q_weight,
  407. const uint32_t* __restrict__ b_gptq_qzeros,
  408. const half* __restrict__ b_gptq_scales,
  409. half* __restrict__ c,
  410. const int size_m,
  411. const int size_n,
  412. const int size_k,
  413. const int groups,
  414. const int* __restrict__ b_q_perm
  415. )
  416. {
  417. MatrixView_half a_(a, size_m, size_k);
  418. MatrixView_half_rw c_(c, size_m, size_n);
  419. MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  420. MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
  421. int t = threadIdx.x;
  422. // Block
  423. int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
  424. int offset_m = blockIdx.y * m_count;
  425. int offset_k = blockIdx.z * BLOCK_KN_SIZE;
  426. int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
  427. int end_m = min(offset_m + m_count, size_m);
  428. int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
  429. int n = offset_n + t * 4;
  430. // Preload block_a
  431. __shared__ half block_a[m_count][BLOCK_KN_SIZE];
  432. if (offset_k + t < end_k)
  433. {
  434. for (int m = 0; m < m_count; ++m)
  435. {
  436. const half* a_ptr = a_.item_ptr(offset_m + m, 0);
  437. half* block_a_ptr = block_a[m];
  438. half a0;
  439. if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
  440. else a0 = a_ptr[offset_k + t];
  441. block_a_ptr[t] = a0;
  442. }
  443. }
  444. // Zero output
  445. if (n >= size_n) return;
  446. if (blockIdx.z == 0)
  447. {
  448. for (int m = 0; m < m_count; m++)
  449. *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
  450. }
  451. __syncthreads();
  452. // Find initial group
  453. int groupsize = size_k / groups;
  454. int group = offset_k / groupsize;
  455. int nextgroup = offset_k + groupsize;
  456. // a, b offset
  457. int qk = offset_k / 32 * 3;
  458. const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
  459. const half* a_ptr = &block_a[0][0];
  460. int a_stride = BLOCK_KN_SIZE;
  461. // Initial group
  462. int zeros[4];
  463. half scales[4];
  464. b_gptq_qzeros_.item4(zeros, group, n);
  465. b_gptq_scales_.item4(scales, group, n);
  466. // Column result
  467. half block_c[m_count][4] = {};
  468. // Dequantize and multiply
  469. int k = offset_k;
  470. while (k < end_k)
  471. {
  472. if (k == nextgroup)
  473. {
  474. group++;
  475. nextgroup += groupsize;
  476. b_gptq_qzeros_.item4(zeros, group, n);
  477. b_gptq_scales_.item4(scales, group, n);
  478. }
  479. #pragma unroll
  480. for (int j = 0; j < 1; j++)
  481. {
  482. int4 load_int4[3];
  483. load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
  484. load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
  485. load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
  486. half2 dq[4][16];
  487. dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1);
  488. dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1);
  489. dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1);
  490. dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1);
  491. #pragma unroll
  492. for (int m = 0; m < m_count; m++)
  493. {
  494. block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
  495. block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
  496. block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
  497. block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
  498. }
  499. a_ptr += 32;
  500. }
  501. k += 32;
  502. }
  503. for (int m = 0; m < m_count; m++)
  504. {
  505. half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
  506. half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
  507. half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
  508. atomicAdd(out , result01);
  509. atomicAdd(out + 1, result23);
  510. }
  511. }
  512. template <bool first_block, int m_count>
  513. __global__ void gemm_half_q_half_gptq_8bit_kernel
  514. (
  515. const half* __restrict__ a,
  516. const uint32_t* __restrict__ b_q_weight,
  517. const uint32_t* __restrict__ b_gptq_qzeros,
  518. const half* __restrict__ b_gptq_scales,
  519. half* __restrict__ c,
  520. const int size_m,
  521. const int size_n,
  522. const int size_k,
  523. const int groups,
  524. const int* __restrict__ b_q_perm
  525. )
  526. {
  527. MatrixView_half a_(a, size_m, size_k);
  528. MatrixView_half_rw c_(c, size_m, size_n);
  529. MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  530. MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
  531. int t = threadIdx.x;
  532. // Block
  533. int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
  534. int offset_m = blockIdx.y * m_count;
  535. int offset_k = blockIdx.z * BLOCK_KN_SIZE;
  536. int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
  537. int end_m = min(offset_m + m_count, size_m);
  538. int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
  539. int n = offset_n + t * 4;
  540. // Preload block_a
  541. __shared__ half block_a[m_count][BLOCK_KN_SIZE];
  542. if (offset_k + t < end_k)
  543. {
  544. for (int m = 0; m < m_count; ++m)
  545. {
  546. const half* a_ptr = a_.item_ptr(offset_m + m, 0);
  547. half* block_a_ptr = block_a[m];
  548. half a0;
  549. if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
  550. else a0 = a_ptr[offset_k + t];
  551. block_a_ptr[t] = a0;
  552. }
  553. }
  554. // Zero output
  555. if (n >= size_n) return;
  556. if (blockIdx.z == 0)
  557. {
  558. for (int m = 0; m < m_count; m++)
  559. *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
  560. }
  561. __syncthreads();
  562. // Find initial group
  563. int groupsize = size_k / groups;
  564. int group = offset_k / groupsize;
  565. int nextgroup = offset_k + groupsize;
  566. // a, b offset
  567. int qk = offset_k / (32 / 8);
  568. const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
  569. const half* a_ptr = &block_a[0][0];
  570. int a_stride = BLOCK_KN_SIZE;
  571. // Initial group
  572. int zeros[4];
  573. half scales[4];
  574. b_gptq_qzeros_.item4(zeros, group, n);
  575. b_gptq_scales_.item4(scales, group, n);
  576. // Column result
  577. half block_c[m_count][4] = {};
  578. // Dequantize and multiply
  579. int k = offset_k;
  580. while (k < end_k)
  581. {
  582. if (k == nextgroup)
  583. {
  584. group++;
  585. nextgroup += groupsize;
  586. b_gptq_qzeros_.item4(zeros, group, n);
  587. b_gptq_scales_.item4(scales, group, n);
  588. }
  589. #pragma unroll
  590. for (int j = 0; j < 4; j++)
  591. {
  592. int4 load_int4[2];
  593. load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
  594. load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
  595. half2 dq[4][4];
  596. dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1);
  597. dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1);
  598. dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1);
  599. dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1);
  600. for (int m = 0; m < m_count; m++)
  601. {
  602. block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
  603. block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
  604. block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
  605. block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
  606. }
  607. a_ptr += 8;
  608. }
  609. k += 32;
  610. }
  611. for (int m = 0; m < m_count; m++)
  612. {
  613. half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
  614. half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
  615. half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
  616. atomicAdd(out , result01);
  617. atomicAdd(out + 1, result23);
  618. }
  619. }
  620. fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(
  621. bool first_block, const int m_count, const int bit)
  622. {
  623. #define SELECT_KERNEL(M_COUNT) \
  624. if (m_count == M_COUNT) { \
  625. if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel<true, M_COUNT>; \
  626. if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel<true, M_COUNT>; \
  627. if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel<true, M_COUNT>; \
  628. if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel<true, M_COUNT>; \
  629. }
  630. #if BLOCK_M_SIZE_MAX >= 1
  631. SELECT_KERNEL(1);
  632. #endif
  633. #if BLOCK_M_SIZE_MAX >= 2
  634. SELECT_KERNEL(2);
  635. #endif
  636. #if BLOCK_M_SIZE_MAX >= 3
  637. SELECT_KERNEL(3);
  638. #endif
  639. #if BLOCK_M_SIZE_MAX >= 4
  640. SELECT_KERNEL(4);
  641. #endif
  642. #if BLOCK_M_SIZE_MAX >= 5
  643. SELECT_KERNEL(5);
  644. #endif
  645. #if BLOCK_M_SIZE_MAX >= 6
  646. SELECT_KERNEL(6);
  647. #endif
  648. #if BLOCK_M_SIZE_MAX >= 7
  649. SELECT_KERNEL(7);
  650. #endif
  651. #if BLOCK_M_SIZE_MAX >= 8
  652. SELECT_KERNEL(8);
  653. #endif
  654. return NULL;
  655. }
  656. void gemm_half_q_half_cuda_part
  657. (
  658. const half* a,
  659. const uint32_t* b_q_weight,
  660. const uint32_t* b_gptq_qzeros,
  661. const half* b_gptq_scales,
  662. const int* b_q_perm,
  663. half* c,
  664. int size_m,
  665. int size_n,
  666. int size_k,
  667. int m_count,
  668. int groups,
  669. int bit
  670. )
  671. {
  672. dim3 blockDim, gridDim;
  673. blockDim.x = BLOCK_KN_SIZE;
  674. blockDim.y = 1;
  675. blockDim.z = 1;
  676. gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
  677. gridDim.y = DIVIDE(size_m, m_count);
  678. gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
  679. fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
  680. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  681. kernel<<<gridDim, blockDim, 0, stream>>>
  682. (
  683. a,
  684. b_q_weight,
  685. b_gptq_qzeros,
  686. b_gptq_scales,
  687. c,
  688. size_m,
  689. size_n,
  690. size_k,
  691. groups,
  692. b_q_perm
  693. );
  694. }
  695. __global__ void reconstruct_exllama_8bit_kernel
  696. (
  697. const uint32_t* __restrict__ b_q_weight,
  698. const int* __restrict__ b_q_perm,
  699. const uint32_t* __restrict__ b_gptq_qzeros,
  700. const half* __restrict__ b_gptq_scales,
  701. const int size_k,
  702. const int size_n,
  703. const int groups,
  704. half* __restrict__ b
  705. )
  706. {
  707. MatrixView_half_rw b_(b, size_k, size_n);
  708. MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  709. MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
  710. int offset_k = BLOCK_KN_SIZE * blockIdx.y;
  711. int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
  712. int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
  713. // Preload remapping table
  714. __shared__ int perm[BLOCK_KN_SIZE];
  715. int t = threadIdx.x;
  716. if (b_q_perm)
  717. {
  718. if (offset_k + t < size_k)
  719. perm[t] = b_q_perm[offset_k + t];
  720. }
  721. // Column
  722. int n = offset_n + t * 4;
  723. if (n >= size_n) return;
  724. // Find initial group
  725. int groupsize = size_k / groups;
  726. int group = offset_k / groupsize;
  727. int nextgroup = offset_k + groupsize;
  728. // b offset
  729. int qk = offset_k / (32 / 8);
  730. const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
  731. // Initial zeros/scale
  732. int zeros[4];
  733. half2 scales[4];
  734. b_gptq_qzeros_.item4(zeros, group, n);
  735. b_gptq_scales_.item4_h2(scales, group, n);
  736. __syncthreads();
  737. int k = offset_k;
  738. int lk = 0;
  739. while (k < end_k)
  740. {
  741. if (k == nextgroup)
  742. {
  743. group++;
  744. nextgroup += groupsize;
  745. b_gptq_qzeros_.item4(zeros, group, n);
  746. b_gptq_scales_.item4_h2(scales, group, n);
  747. }
  748. for (int p = 0; p < 4; p++)
  749. {
  750. int4 load_int4[2];
  751. load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
  752. load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
  753. half2 dq[4][4];
  754. dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1);
  755. dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1);
  756. dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1);
  757. dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1);
  758. //half* dqh = (half*)dq;
  759. if (b_q_perm)
  760. {
  761. for (int j = 0; j < 4; j++)
  762. {
  763. for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
  764. b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
  765. b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
  766. }
  767. }
  768. else
  769. {
  770. for (int j = 0; j < 4; j++)
  771. {
  772. for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
  773. b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
  774. b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
  775. }
  776. }
  777. }
  778. k += 32;
  779. }
  780. }
  781. __global__ void reconstruct_exllama_4bit_kernel
  782. (
  783. const uint32_t* __restrict__ b_q_weight,
  784. const int* __restrict__ b_q_perm,
  785. const uint32_t* __restrict__ b_gptq_qzeros,
  786. const half* __restrict__ b_gptq_scales,
  787. const int size_k,
  788. const int size_n,
  789. const int groups,
  790. half* __restrict__ b
  791. )
  792. {
  793. if (blockIdx.z > 0){
  794. b_q_weight = b_q_weight + blockIdx.z * size_k * size_n / 8;
  795. b_gptq_scales = b_gptq_scales + blockIdx.z * groups * size_n;
  796. b_gptq_qzeros = b_gptq_qzeros + blockIdx.z * groups * size_n / 8;
  797. if (b_q_perm) b_q_perm = b_q_perm + blockIdx.z * size_k;
  798. b = b + blockIdx.z * size_k * size_n;
  799. }
  800. MatrixView_half_rw b_(b, size_k, size_n);
  801. MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  802. MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
  803. int offset_k = BLOCK_KN_SIZE * blockIdx.y;
  804. int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
  805. int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
  806. // Preload remapping table
  807. __shared__ int perm[BLOCK_KN_SIZE];
  808. int t = threadIdx.x;
  809. if (b_q_perm)
  810. {
  811. if (offset_k + t < size_k)
  812. perm[t] = b_q_perm[offset_k + t];
  813. }
  814. // Column
  815. int n = offset_n + t * 4;
  816. if (n >= size_n) return;
  817. // Find initial group
  818. int groupsize = size_k / groups;
  819. int group = offset_k / groupsize;
  820. int nextgroup = offset_k + groupsize;
  821. // b offset
  822. int qk = offset_k / (32 / 4);
  823. const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
  824. // Initial zeros/scale
  825. int zeros[4];
  826. half2 scales[4];
  827. half2 z1z16[4][2];
  828. half2 y1y16[4][2];
  829. b_gptq_qzeros_.item4(zeros, group, n);
  830. b_gptq_scales_.item4_h2(scales, group, n);
  831. dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
  832. dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
  833. dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
  834. dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
  835. __syncthreads();
  836. int k = offset_k;
  837. int lk = 0;
  838. while (k < end_k)
  839. {
  840. if (k == nextgroup)
  841. {
  842. group++;
  843. nextgroup += groupsize;
  844. b_gptq_qzeros_.item4(zeros, group, n);
  845. b_gptq_scales_.item4_h2(scales, group, n);
  846. dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
  847. dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
  848. dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
  849. dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
  850. }
  851. for (int p = 0; p < 4; p++)
  852. {
  853. half2 dq[4][4];
  854. const int4* b_ptr4 = (int4*) b_ptr;
  855. int4 load_int4 = *b_ptr4;
  856. dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
  857. dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
  858. dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
  859. dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
  860. b_ptr += size_n;
  861. //half* dqh = (half*)dq;
  862. if (b_q_perm)
  863. {
  864. for (int j = 0; j < 4; j++)
  865. {
  866. for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
  867. b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
  868. b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
  869. }
  870. }
  871. else
  872. {
  873. for (int j = 0; j < 4; j++)
  874. {
  875. for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
  876. b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
  877. b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
  878. }
  879. }
  880. }
  881. k += 32;
  882. }
  883. }
  884. __global__ void reconstruct_exllama_3bit_kernel
  885. (
  886. const uint32_t* __restrict__ b_q_weight,
  887. const int* __restrict__ b_q_perm,
  888. const uint32_t* __restrict__ b_gptq_qzeros,
  889. const half* __restrict__ b_gptq_scales,
  890. const int size_k,
  891. const int size_n,
  892. const int groups,
  893. half* __restrict__ b
  894. )
  895. {
  896. MatrixView_half_rw b_(b, size_k, size_n);
  897. MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  898. MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
  899. int offset_k = BLOCK_KN_SIZE * blockIdx.y;
  900. int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
  901. int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
  902. // Preload remapping table
  903. __shared__ int perm[BLOCK_KN_SIZE];
  904. int t = threadIdx.x;
  905. if (b_q_perm)
  906. {
  907. if (offset_k + t < size_k)
  908. perm[t] = b_q_perm[offset_k + t];
  909. }
  910. // Column
  911. int n = offset_n + t * 4;
  912. if (n >= size_n) return;
  913. // Find initial group
  914. int groupsize = size_k / groups;
  915. int group = offset_k / groupsize;
  916. int nextgroup = offset_k + groupsize;
  917. // b offset
  918. int qk = offset_k / 32* 3;
  919. const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
  920. // Initial zeros/scale
  921. int zeros[4];
  922. half2 scales[4];
  923. b_gptq_qzeros_.item4(zeros, group, n);
  924. b_gptq_scales_.item4_h2(scales, group, n);
  925. __syncthreads();
  926. int k = offset_k;
  927. int lk = 0;
  928. while (k < end_k)
  929. {
  930. if (k == nextgroup)
  931. {
  932. group++;
  933. nextgroup += groupsize;
  934. b_gptq_qzeros_.item4(zeros, group, n);
  935. b_gptq_scales_.item4_h2(scales, group, n);
  936. }
  937. for (int p = 0; p < 1; p++)
  938. {
  939. int4 load_int4[3];
  940. load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
  941. load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
  942. load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
  943. half2 dq[4][16];
  944. dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1);
  945. dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1);
  946. dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1);
  947. dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1);
  948. if (b_q_perm)
  949. {
  950. for (int j = 0; j < 16; j++)
  951. {
  952. for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
  953. b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
  954. b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
  955. }
  956. }
  957. else
  958. {
  959. for (int j = 0; j < 16; j++)
  960. {
  961. for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
  962. b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
  963. b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
  964. }
  965. }
  966. }
  967. k += 32;
  968. }
  969. }
  970. __global__ void reconstruct_exllama_2bit_kernel
  971. (
  972. const uint32_t* __restrict__ b_q_weight,
  973. const int* __restrict__ b_q_perm,
  974. const uint32_t* __restrict__ b_gptq_qzeros,
  975. const half* __restrict__ b_gptq_scales,
  976. const int size_k,
  977. const int size_n,
  978. const int groups,
  979. half* __restrict__ b
  980. )
  981. {
  982. MatrixView_half_rw b_(b, size_k, size_n);
  983. MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  984. MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
  985. int offset_k = BLOCK_KN_SIZE * blockIdx.y;
  986. int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
  987. int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
  988. // Preload remapping table
  989. __shared__ int perm[BLOCK_KN_SIZE];
  990. int t = threadIdx.x;
  991. if (b_q_perm)
  992. {
  993. if (offset_k + t < size_k)
  994. perm[t] = b_q_perm[offset_k + t];
  995. }
  996. // Column
  997. int n = offset_n + t * 4;
  998. if (n >= size_n) return;
  999. // Find initial group
  1000. int groupsize = size_k / groups;
  1001. int group = offset_k / groupsize;
  1002. int nextgroup = offset_k + groupsize;
  1003. // b offset
  1004. int qk = offset_k / (32 / 2);
  1005. const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
  1006. // Initial zeros/scale
  1007. int zeros[4];
  1008. half2 scales[4];
  1009. b_gptq_qzeros_.item4(zeros, group, n);
  1010. b_gptq_scales_.item4_h2(scales, group, n);
  1011. __syncthreads();
  1012. int k = offset_k;
  1013. int lk = 0;
  1014. while (k < end_k)
  1015. {
  1016. if (k == nextgroup)
  1017. {
  1018. group++;
  1019. nextgroup += groupsize;
  1020. b_gptq_qzeros_.item4(zeros, group, n);
  1021. b_gptq_scales_.item4_h2(scales, group, n);
  1022. }
  1023. for (int p = 0; p < 2; p++)
  1024. {
  1025. const int4* b_ptr4 = (int4*) b_ptr;
  1026. int4 load_int4 = *b_ptr4;
  1027. half2 dq[4][8];
  1028. dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
  1029. dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
  1030. dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
  1031. dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
  1032. b_ptr += size_n;
  1033. //half* dqh = (half*)dq;
  1034. if (b_q_perm)
  1035. {
  1036. for (int j = 0; j < 8; j++)
  1037. {
  1038. for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
  1039. b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
  1040. b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
  1041. }
  1042. }
  1043. else
  1044. {
  1045. for (int j = 0; j < 8; j++)
  1046. {
  1047. for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
  1048. b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
  1049. b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
  1050. }
  1051. }
  1052. }
  1053. k += 32;
  1054. }
  1055. }
  1056. void reconstruct_exllama
  1057. (
  1058. const uint32_t* b_q_weight,
  1059. const uint32_t* b_gptq_qzeros,
  1060. const half* b_gptq_scales,
  1061. const int* b_q_perm,
  1062. half* out,
  1063. int height,
  1064. int width,
  1065. int groups,
  1066. int num_experts,
  1067. int bit
  1068. )
  1069. {
  1070. dim3 blockDim, gridDim;
  1071. blockDim.x = BLOCK_KN_SIZE;
  1072. blockDim.y = 1;
  1073. gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
  1074. gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
  1075. gridDim.z = num_experts;
  1076. auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel;
  1077. if (bit == 2) {
  1078. reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel;
  1079. } else if (bit == 3) {
  1080. reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel;
  1081. } else if (bit == 8) {
  1082. reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel;
  1083. }
  1084. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  1085. reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
  1086. (
  1087. b_q_weight,
  1088. b_q_perm,
  1089. b_gptq_qzeros,
  1090. b_gptq_scales,
  1091. height,
  1092. width,
  1093. groups,
  1094. out
  1095. );
  1096. }
  1097. __global__ void gemm_half_q_half_alt_4bit_kernel(
  1098. const half2* __restrict__ vec,
  1099. const uint32_t* __restrict__ mat,
  1100. half* __restrict__ mul,
  1101. const half* __restrict__ scales,
  1102. const uint32_t* __restrict__ zeros,
  1103. const int* __restrict__ g_idx,
  1104. int batch,
  1105. int height,
  1106. int width
  1107. )
  1108. {
  1109. int zero_width = width / 8;
  1110. int vec_height = height * 4;
  1111. const int blockwidth2 = BLOCK_KN_SIZE / 2;
  1112. int b = blockIdx.y * BLOCK_M_SIZE_MAX;
  1113. int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
  1114. int h = BLOCK_KN_SIZE * blockIdx.z / 8;
  1115. int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
  1116. int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
  1117. __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
  1118. if (threadIdx.x < h_end) {
  1119. for (int m = 0; m < b_end; ++m) {
  1120. blockvec[m][threadIdx.x] =
  1121. vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
  1122. threadIdx.x];
  1123. }
  1124. }
  1125. __shared__ half2 deq2[256][8];
  1126. int val = threadIdx.x / 8;
  1127. int off = threadIdx.x % 8;
  1128. for (; val < 256; val += BLOCK_KN_SIZE / 8) {
  1129. deq2[val][off] = __halves2half2(
  1130. __int2half_rn(val & 0xF), __int2half_rn(val >> 4)
  1131. );
  1132. }
  1133. if (blockIdx.z == 0)
  1134. {
  1135. for (int m = 0; m < b_end; m++)
  1136. mul[(b + m) * width + w] = __int2half_rn(0);
  1137. }
  1138. __syncthreads();
  1139. int i = width * h + w;
  1140. int g_h = h * 8;
  1141. int k = 0;
  1142. int z_w = w / 8;
  1143. int z_mod = (w % 8) * 4;
  1144. half2 res2;
  1145. half res[BLOCK_M_SIZE_MAX] = {};
  1146. unsigned int tmp;
  1147. while (k < h_end) {
  1148. tmp = mat[i];
  1149. half2 scales_tmp[4];
  1150. half2 zeros_tmp[4];
  1151. for (int tmp_k = 0; tmp_k < 4; tmp_k++) {
  1152. int g = g_idx[g_h + (k + tmp_k) * 2];
  1153. int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
  1154. half scale_f = scales[g * width + w];
  1155. half scale_f2 = scales[g2 * width + w];
  1156. half2 scale = __halves2half2(scale_f, scale_f2);
  1157. half2 zero = __halves2half2(
  1158. __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)),
  1159. __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))
  1160. );
  1161. scales_tmp[tmp_k] = scale;
  1162. zeros_tmp[tmp_k] = zero;
  1163. }
  1164. for (int m = 0; m < b_end; m++) {
  1165. #ifndef USE_ROCM
  1166. res2 = {};
  1167. #else
  1168. res2.x = __half_as_ushort(__float2half(0));
  1169. res2.y = __half_as_ushort(__float2half(0));
  1170. #endif
  1171. res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
  1172. res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
  1173. res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
  1174. res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
  1175. #ifndef USE_ROCM
  1176. res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
  1177. #else
  1178. res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
  1179. #endif
  1180. }
  1181. i += width;
  1182. k += 4;
  1183. }
  1184. for (int m = 0; m < b_end; m++) {
  1185. atomicAdd(&mul[(b + m) * width + w], res[m]);
  1186. }
  1187. }
  1188. __global__ void gemm_half_q_half_alt_8bit_kernel(
  1189. const half2* __restrict__ vec,
  1190. const uint32_t* __restrict__ mat,
  1191. half* __restrict__ mul,
  1192. const half* __restrict__ scales,
  1193. const uint32_t* __restrict__ zeros,
  1194. const int* __restrict__ g_idx,
  1195. int batch,
  1196. int height,
  1197. int width
  1198. )
  1199. {
  1200. int zero_width = width / 4;
  1201. int vec_height = height * 2;
  1202. const int blockwidth2 = BLOCK_KN_SIZE / 2;
  1203. int b = blockIdx.y * BLOCK_M_SIZE_MAX;
  1204. int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
  1205. int h = BLOCK_KN_SIZE * blockIdx.z / 4;
  1206. int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
  1207. int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
  1208. __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
  1209. if (threadIdx.x < h_end) {
  1210. for (int m = 0; m < b_end; ++m) {
  1211. blockvec[m][threadIdx.x] =
  1212. vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
  1213. threadIdx.x];
  1214. }
  1215. }
  1216. if (blockIdx.z == 0)
  1217. {
  1218. for (int m = 0; m < b_end; m++)
  1219. mul[(b + m) * width + w] = __int2half_rn(0);
  1220. }
  1221. __syncthreads();
  1222. int i = width * h + w;
  1223. int g_h = h * 4;
  1224. int k = 0;
  1225. int z_w = w / 4;
  1226. int z_mod = (w % 4) * 8;
  1227. half2 res2;
  1228. half res[BLOCK_M_SIZE_MAX] = {};
  1229. unsigned int tmp;
  1230. while (k < h_end) {
  1231. tmp = mat[i];
  1232. half2 scales_tmp[2];
  1233. half2 zeros_tmp[2];
  1234. for (int tmp_k = 0; tmp_k < 2; tmp_k++) {
  1235. int g = g_idx[g_h + (k + tmp_k) * 2];
  1236. int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
  1237. half scale_f = scales[g * width + w];
  1238. half scale_f2 = scales[g2 * width + w];
  1239. half2 scale = __halves2half2(scale_f, scale_f2);
  1240. half2 zero = __halves2half2(
  1241. __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)),
  1242. __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))
  1243. );
  1244. scales_tmp[tmp_k] = scale;
  1245. zeros_tmp[tmp_k] = zero;
  1246. }
  1247. for (int m = 0; m < b_end; m++) {
  1248. #ifndef USE_ROCM
  1249. res2 = {};
  1250. #else
  1251. res2.x = __half_as_ushort(__float2half(0));
  1252. res2.y = __half_as_ushort(__float2half(0));
  1253. #endif
  1254. half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF));
  1255. res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
  1256. half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF));
  1257. res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
  1258. #ifndef USE_ROCM
  1259. res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
  1260. #else
  1261. res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
  1262. #endif
  1263. }
  1264. i += width;
  1265. k += 2;
  1266. }
  1267. for (int m = 0; m < b_end; m++) {
  1268. atomicAdd(&mul[(b + m) * width + w], res[m]);
  1269. }
  1270. }
  1271. void gemm_half_q_half_alt
  1272. (
  1273. const half* a,
  1274. const uint32_t* b_q_weight,
  1275. const uint32_t* b_gptq_qzeros,
  1276. const half* b_gptq_scales,
  1277. const int* b_g_idx,
  1278. half* c,
  1279. int size_m,
  1280. int size_n,
  1281. int size_k,
  1282. int bit
  1283. )
  1284. {
  1285. dim3 blockDim, gridDim;
  1286. blockDim.x = BLOCK_KN_SIZE;
  1287. blockDim.y = 1;
  1288. blockDim.z = 1;
  1289. gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
  1290. gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
  1291. gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
  1292. auto kernel = gemm_half_q_half_alt_4bit_kernel;
  1293. if (bit == 8) {
  1294. kernel = gemm_half_q_half_alt_8bit_kernel;
  1295. }
  1296. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  1297. kernel<<<gridDim, blockDim, 0, stream>>>
  1298. (
  1299. (const half2*) a,
  1300. b_q_weight,
  1301. c,
  1302. b_gptq_scales,
  1303. b_gptq_qzeros,
  1304. b_g_idx,
  1305. size_m,
  1306. size_k / 32 * bit,
  1307. size_n
  1308. );
  1309. }
  1310. template<class T, int bit>
  1311. __global__ void reconstruct_gptq_kernel
  1312. (
  1313. const uint32_t* __restrict__ w,
  1314. const half* __restrict__ w_scales,
  1315. const uint32_t* __restrict__ w_zeros,
  1316. const int* __restrict__ g_idx,
  1317. const int height,
  1318. const int width,
  1319. const int group,
  1320. half* __restrict__ out
  1321. )
  1322. {
  1323. if (blockIdx.z > 0){
  1324. w = w + blockIdx.z * height * width / 8;
  1325. w_scales = w_scales + blockIdx.z * group * width;
  1326. w_zeros = w_zeros + blockIdx.z * group * width / 8;
  1327. g_idx = g_idx + blockIdx.z * height;
  1328. out = out + blockIdx.z * height * width;
  1329. }
  1330. // Start of block
  1331. int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
  1332. int row = blockIdx.y * 32 / bit;
  1333. if (column >= width) return;
  1334. // Views
  1335. MatrixView_half_rw out_(out, height, width);
  1336. MatrixView_half w_scales_(w_scales, group, width);
  1337. T w_zeros_(w_zeros, group, width);
  1338. uint32_t w_read = w[blockIdx.y * width + column];
  1339. half* out_ptr = out_.item_ptr(row, column);
  1340. #pragma unroll
  1341. for (int s = 0; s < 32; s += bit)
  1342. {
  1343. int group = g_idx[row + s / bit];
  1344. half w_scale = w_scales_.item(group, column);
  1345. uint32_t w_zero = w_zeros_.item(group, column) + 1;
  1346. half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale);
  1347. *out_ptr = w_item; out_ptr += out_.width;
  1348. }
  1349. }
  1350. __global__ void reconstruct_gptq_3bit_kernel
  1351. (
  1352. const uint32_t* __restrict__ w,
  1353. const half* __restrict__ w_scales,
  1354. const uint32_t* __restrict__ w_zeros,
  1355. const int* __restrict__ g_idx,
  1356. const int height,
  1357. const int width,
  1358. const int group,
  1359. half* __restrict__ out
  1360. )
  1361. {
  1362. // Start of block
  1363. int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
  1364. int row = blockIdx.y * 32;
  1365. if (column >= width) return;
  1366. // Views
  1367. MatrixView_half_rw out_(out, height, width);
  1368. MatrixView_half w_scales_(w_scales, group, width);
  1369. MatrixView_q3_row w_zeros_(w_zeros, group, width);
  1370. uint32_t w1 = w[(blockIdx.y * 3) * width + column];
  1371. uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
  1372. uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
  1373. half* out_ptr = out_.item_ptr(row, column);
  1374. #pragma unroll
  1375. for (int i = 0; i < 32; i += 1)
  1376. {
  1377. int group = g_idx[row + i];
  1378. half w_scale = w_scales_.item(group, column);
  1379. uint32_t w_zero = w_zeros_.item(group, column) + 1;
  1380. int w_item;
  1381. if (i == 10) {
  1382. w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
  1383. } else if (i == 21) {
  1384. w_item = (w2 >> 31) | ((w3 << 1) & 0x6);
  1385. } else if (i < 10) {
  1386. w_item = ((w1 >> (i * 3)) & 0x7);
  1387. } else if (i < 21) {
  1388. w_item = ((w2 >> (i * 3 - 32)) & 0x7);
  1389. } else {
  1390. w_item = ((w3 >> (i * 3 - 64)) & 0x7);
  1391. }
  1392. *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale);
  1393. out_ptr += out_.width;
  1394. }
  1395. }
  1396. void reconstruct_gptq
  1397. (
  1398. const uint32_t* b_q_weight,
  1399. const uint32_t* b_gptq_qzeros,
  1400. const half* b_gptq_scales,
  1401. const int* b_g_idx,
  1402. half* out,
  1403. int height,
  1404. int width,
  1405. int groups,
  1406. int num_experts,
  1407. int bit
  1408. )
  1409. {
  1410. dim3 blockDim, gridDim;
  1411. blockDim.x = BLOCK_KN_SIZE;
  1412. blockDim.y = 1;
  1413. gridDim.y = DIVIDE(height, 32 / bit);
  1414. gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
  1415. gridDim.z = num_experts;
  1416. auto kernel = reconstruct_gptq_kernel<MatrixView_q4_row, 4>;
  1417. if (bit == 2) {
  1418. kernel = reconstruct_gptq_kernel<MatrixView_q2_row, 2>;
  1419. } else if (bit == 8) {
  1420. kernel = reconstruct_gptq_kernel<MatrixView_q8_row, 8>;
  1421. } else if (bit == 3) {
  1422. kernel = reconstruct_gptq_3bit_kernel;
  1423. gridDim.y = DIVIDE(height, 32);
  1424. }
  1425. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  1426. kernel<<<gridDim, blockDim, 0, stream>>>
  1427. (
  1428. b_q_weight,
  1429. b_gptq_scales,
  1430. b_gptq_qzeros,
  1431. b_g_idx,
  1432. height,
  1433. width,
  1434. groups,
  1435. out
  1436. );
  1437. }
  1438. void dequant_gptq_cuda
  1439. (
  1440. const uint32_t* b_q_weight,
  1441. const uint32_t* b_gptq_qzeros,
  1442. const half* b_gptq_scales,
  1443. const int* b_g_idx,
  1444. half* temp_dq,
  1445. int size_k,
  1446. int size_n,
  1447. int groups,
  1448. int num_experts,
  1449. int bits,
  1450. bool use_exllama
  1451. )
  1452. {
  1453. if (use_exllama) {
  1454. reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
  1455. size_k, size_n, groups, num_experts, bits);
  1456. }
  1457. else
  1458. {
  1459. reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
  1460. temp_dq, size_k, size_n, groups, num_experts, bits);
  1461. }
  1462. }
  1463. void gemm_half_q_half_cuda
  1464. (
  1465. cublasHandle_t cublas_handle,
  1466. const half* a,
  1467. const uint32_t* b_q_weight,
  1468. const uint32_t* b_gptq_qzeros,
  1469. const half* b_gptq_scales,
  1470. const int* b_g_idx,
  1471. half* c,
  1472. half* temp_dq,
  1473. int size_m,
  1474. int size_n,
  1475. int size_k,
  1476. int groups,
  1477. bool use_exllama,
  1478. int bit
  1479. )
  1480. {
  1481. bool use_reconstruct;
  1482. if (use_exllama) {
  1483. use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS));
  1484. } else {
  1485. // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so we disabled them for now.
  1486. use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS);
  1487. }
  1488. if (use_reconstruct) {
  1489. // Reconstruct FP16 matrix, then cuBLAS
  1490. dequant_gptq_cuda(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
  1491. size_k, size_n, groups, 1, bit, use_exllama);
  1492. const half alpha = __float2half(1.0f);
  1493. const half beta = __float2half(0.0f);
  1494. cublasHgemm(cublas_handle,
  1495. CUBLAS_OP_N,
  1496. CUBLAS_OP_N,
  1497. size_n, size_m, size_k,
  1498. &alpha, temp_dq, size_n,
  1499. a, size_k,
  1500. &beta, c, size_n);
  1501. }
  1502. else if (use_exllama)
  1503. {
  1504. // Quantized matmul
  1505. int max_chunks = size_m / BLOCK_M_SIZE_MAX;
  1506. int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
  1507. int last_chunk_size = size_m - last_chunk;
  1508. if (max_chunks)
  1509. {
  1510. gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
  1511. c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
  1512. groups, bit);
  1513. }
  1514. if (last_chunk_size)
  1515. {
  1516. gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
  1517. b_gptq_scales, b_g_idx, c + last_chunk * size_n,
  1518. last_chunk_size, size_n, size_k, last_chunk_size,
  1519. groups, bit);
  1520. }
  1521. }
  1522. else
  1523. {
  1524. gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
  1525. c, size_m, size_n, size_k, bit);
  1526. }
  1527. }
  1528. __global__ void shuffle_4bit_kernel
  1529. (
  1530. uint32_t* __restrict__ b_q_weight,
  1531. const int size_k,
  1532. const int size_n
  1533. )
  1534. {
  1535. int n = blockIdx.x * THREADS_X + threadIdx.x;
  1536. if (n >= size_n) return;
  1537. int k = 0;
  1538. uint32_t* b_ptr = b_q_weight + n;
  1539. while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
  1540. }
  1541. __global__ void shuffle_8bit_kernel
  1542. (
  1543. uint32_t* __restrict__ b_q_weight,
  1544. const int size_k,
  1545. const int size_n
  1546. )
  1547. {
  1548. int n = blockIdx.x * THREADS_X + threadIdx.x;
  1549. if (n >= size_n) return;
  1550. int k = 0;
  1551. uint32_t* b_ptr = b_q_weight + n;
  1552. while (k < size_k) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
  1553. }
  1554. __global__ void shuffle_2bit_kernel
  1555. (
  1556. uint32_t* __restrict__ b_q_weight,
  1557. const int size_k,
  1558. const int size_n
  1559. )
  1560. {
  1561. int n = blockIdx.x * THREADS_X + threadIdx.x;
  1562. if (n >= size_n) return;
  1563. int k = 0;
  1564. uint32_t* b_ptr = b_q_weight + n;
  1565. while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
  1566. }
  1567. __global__ void shuffle_3bit_kernel
  1568. (
  1569. uint32_t* __restrict__ b_q_weight,
  1570. const int size_k,
  1571. const int size_n
  1572. )
  1573. {
  1574. int n = blockIdx.x * THREADS_X + threadIdx.x;
  1575. if (n >= size_n) return;
  1576. int k = 0;
  1577. uint32_t* b_ptr = b_q_weight + n;
  1578. while (k < size_k) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
  1579. }
  1580. __global__ void make_sequential_4bit_kernel
  1581. (
  1582. const uint32_t* __restrict__ w,
  1583. uint32_t* __restrict__ w_new,
  1584. const int* __restrict__ q_perm,
  1585. const int w_height,
  1586. const int w_width
  1587. )
  1588. {
  1589. if (blockIdx.z > 0){
  1590. w = w + blockIdx.z * w_height * w_width;
  1591. w_new = w_new + blockIdx.z * w_height * w_width;
  1592. q_perm = q_perm + blockIdx.z * w_height * 8;
  1593. }
  1594. const uint64_t* w2 = (uint64_t*) w;
  1595. uint64_t* w_new2 = (uint64_t*) w_new;
  1596. int w2_stride = w_width >> 1;
  1597. int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
  1598. if (w2_column >= w2_stride) return;
  1599. int w_new2_row = blockIdx.y;
  1600. int q_perm_idx = w_new2_row << 3;
  1601. uint64_t dst = 0;
  1602. #pragma unroll
  1603. for (int i = 0; i < 8; i++)
  1604. {
  1605. int source_row = q_perm[q_perm_idx++];
  1606. int w2_row = source_row >> 3;
  1607. int w2_subrow = source_row & 0x07;
  1608. int w2_row_shift = w2_subrow << 2;
  1609. int wnew2_row_shift = i << 2;
  1610. uint64_t src = w2[w2_row * w2_stride + w2_column];
  1611. src >>= w2_row_shift;
  1612. src &= 0x0000000f0000000f;
  1613. src <<= wnew2_row_shift;
  1614. dst |= src;
  1615. }
  1616. w_new2[w_new2_row * w2_stride + w2_column] = dst;
  1617. }
  1618. __global__ void make_sequential_2bit_kernel
  1619. (
  1620. const uint32_t* __restrict__ w,
  1621. uint32_t* __restrict__ w_new,
  1622. const int* __restrict__ q_perm,
  1623. const int w_height,
  1624. const int w_width
  1625. )
  1626. {
  1627. if (blockIdx.z > 0){
  1628. w = w + blockIdx.z * w_height * w_width;
  1629. w_new = w_new + blockIdx.z * w_height * w_width;
  1630. q_perm = q_perm + blockIdx.z * w_height * 16;
  1631. }
  1632. const uint64_t* w2 = (uint64_t*) w;
  1633. uint64_t* w_new2 = (uint64_t*) w_new;
  1634. int w2_stride = w_width >> 1;
  1635. int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
  1636. if (w2_column >= w2_stride) return;
  1637. int w_new2_row = blockIdx.y;
  1638. int q_perm_idx = w_new2_row << 4;
  1639. uint64_t dst = 0;
  1640. #pragma unroll
  1641. for (int i = 0; i < 16; i++)
  1642. {
  1643. int source_row = q_perm[q_perm_idx++];
  1644. int w2_row = source_row >> 4;
  1645. int w2_subrow = source_row & 0x0f;
  1646. int w2_row_shift = w2_subrow << 1;
  1647. int wnew2_row_shift = i << 1;
  1648. uint64_t src = w2[w2_row * w2_stride + w2_column];
  1649. src >>= w2_row_shift;
  1650. src &= 0x0000000300000003;
  1651. src <<= wnew2_row_shift;
  1652. dst |= src;
  1653. }
  1654. w_new2[w_new2_row * w2_stride + w2_column] = dst;
  1655. }
// Applies the row permutation `q_perm` to 3-bit packed weights, writing the
// re-packed result to `w_new`.
//
// 32 three-bit values occupy three consecutive 32-bit rows (96 bits). Within
// such a triple, value index v is located as follows (see the branches below):
//   v in [0, 9]   : row 0, bit offset 3*v
//   v == 10       : straddles rows 0/1 (bits 30-31 of row 0, bit 0 of row 1)
//   v in [11, 20] : row 1, bit offset 3*v - 32
//   v == 21       : straddles rows 1/2 (bit 31 of row 1, bits 0-1 of row 2)
//   v in [22, 31] : row 2, bit offset 3*v - 64
//
// Grid layout: blockIdx.z selects the expert slice, blockIdx.y the output
// row triple, and each thread handles one 32-bit column.
//
// w_height : number of packed rows per expert
// w_width  : number of 32-bit columns
__global__ void make_sequential_3bit_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const int* __restrict__ q_perm,
    const int w_height,
    const int w_width
)
{
    // Advance to this expert's slice. Three packed rows hold 32 values, so an
    // expert owns w_height * 32 / 3 permutation entries.
    if (blockIdx.z > 0){
        w = w + blockIdx.z * w_height * w_width;
        w_new = w_new + blockIdx.z * w_height * w_width;
        q_perm = q_perm + blockIdx.z * w_height * 32 / 3;
    }
    int w_column = THREADS_X * blockIdx.x + threadIdx.x;
    if (w_column >= w_width) return;
    // First of the three output rows written by this thread, and the first of
    // its 32 permutation entries.
    int w_new_row = blockIdx.y * 3;
    int q_perm_idx = blockIdx.y << 5;
    uint32_t dst[3] = {0, 0, 0};
    #pragma unroll
    for (int i = 0; i < 32; i++)
    {
        // --- Locate the source value -------------------------------------
        int source_row = q_perm[q_perm_idx++];
        int z_w = (source_row / 32) * 3;   // first row of the source triple
        int z_mod = source_row % 32;       // index of the value inside the triple
        int z_bit;                         // left unset for the straddling cases (10, 21)
        if (z_mod != 10){
            if (z_mod != 21){
                z_bit = z_mod;
                if (z_bit > 21){
                    z_bit *= 3;
                    z_bit -= 64;
                    z_w += 2;
                } else if (z_bit > 10){
                    z_bit *= 3;
                    z_bit -= 32;
                    z_w += 1;
                } else {
                    z_bit *= 3;
                }
            } else {
                // Value 21 starts in the second row of the triple.
                z_w += 1;
            }
        }
        // --- Extract the 3-bit source value ------------------------------
        uint64_t src;
        if (z_mod == 10) {
            // Low two bits from the top of row z_w, high bit from bit 0 of the next row.
            src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4);
        } else if (z_mod == 21){
            // Low bit from the top of row z_w, high two bits from bits 0-1 of the next row.
            src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6);
        } else {
            src = w[z_w * w_width + w_column];
            src >>= z_bit;
            src &= 0x07;
        }
        // --- Locate the destination slot i inside dst[0..2] ---------------
        // Same mapping as above, now applied to the output index i; z_w is
        // reused as the index into dst.
        z_w = 0;
        if (i != 10){
            if (i != 21){
                z_bit = i;
                if (z_bit > 21){
                    z_bit *= 3;
                    z_bit -= 64;
                    z_w += 2;
                } else if (z_bit > 10){
                    z_bit *= 3;
                    z_bit -= 32;
                    z_w += 1;
                } else {
                    z_bit *= 3;
                }
            } else {
                z_w += 1;
            }
        }
        // --- Deposit the value -------------------------------------------
        if (i == 10) {
            // Split across dst[0] (top two bits) and dst[1] (bit 0).
            dst[z_w] |= (src & 0x03) << 30;
            dst[z_w + 1] |= ((src & 0x4) >> 2);
        } else if (i == 21) {
            // Split across dst[1] (top bit) and dst[2] (bits 0-1).
            dst[z_w] |= (src & 0x01) << 31;
            dst[z_w + 1] |= ((src & 0x6) >> 1);
        } else {
            dst[z_w] |= (src << z_bit);
        }
    }
    // Store the three re-packed rows for this column.
    w_new[w_new_row * w_width + w_column] = dst[0];
    w_new[(w_new_row + 1) * w_width + w_column] = dst[1];
    w_new[(w_new_row + 2) * w_width + w_column] = dst[2];
}
  1743. __global__ void make_sequential_8bit_kernel
  1744. (
  1745. const uint32_t* __restrict__ w,
  1746. uint32_t* __restrict__ w_new,
  1747. const int* __restrict__ q_perm,
  1748. const int w_height,
  1749. const int w_width
  1750. )
  1751. {
  1752. if (blockIdx.z > 0){
  1753. w = w + blockIdx.z * w_height * w_width;
  1754. w_new = w_new + blockIdx.z * w_height * w_width;
  1755. q_perm = q_perm + blockIdx.z * w_height * 4;
  1756. }
  1757. const uint64_t* w2 = (uint64_t*) w;
  1758. uint64_t* w_new2 = (uint64_t*) w_new;
  1759. int w2_stride = w_width >> 1;
  1760. int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
  1761. if (w2_column >= w2_stride) return;
  1762. int w_new2_row = blockIdx.y;
  1763. int q_perm_idx = w_new2_row << 2;
  1764. uint64_t dst = 0;
  1765. #pragma unroll
  1766. for (int i = 0; i < 4; i++)
  1767. {
  1768. int source_row = q_perm[q_perm_idx++];
  1769. int w2_row = source_row >> 2;
  1770. int w2_subrow = source_row & 0x03;
  1771. int w2_row_shift = w2_subrow << 3;
  1772. int wnew2_row_shift = i << 3;
  1773. uint64_t src = w2[w2_row * w2_stride + w2_column];
  1774. src >>= w2_row_shift;
  1775. src &= 0x000000ff000000ff;
  1776. src <<= wnew2_row_shift;
  1777. dst |= src;
  1778. }
  1779. w_new2[w_new2_row * w2_stride + w2_column] = dst;
  1780. }
  1781. void shuffle_exllama_weight
  1782. (
  1783. uint32_t* q_weight,
  1784. int* q_perm,
  1785. int height,
  1786. int width,
  1787. int num_experts,
  1788. int bit
  1789. )
  1790. {
  1791. if (q_perm)
  1792. {
  1793. uint32_t* new_qweight = NULL;
  1794. cudaMalloc(&new_qweight, num_experts * height / 32 * bit * width * sizeof(uint32_t));
  1795. dim3 blockDim, gridDim;
  1796. blockDim.x = THREADS_X;
  1797. blockDim.y = 1;
  1798. gridDim.x = DIVIDE(width, THREADS_X);
  1799. gridDim.y = height / 32 * bit;
  1800. gridDim.z = num_experts;
  1801. auto kernel = make_sequential_4bit_kernel;
  1802. if (bit == 2) {
  1803. kernel = make_sequential_2bit_kernel;
  1804. } else if (bit == 3) {
  1805. kernel = make_sequential_3bit_kernel;
  1806. gridDim.y = height / 32;
  1807. } else if (bit == 8) {
  1808. kernel = make_sequential_8bit_kernel;
  1809. }
  1810. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  1811. kernel<<<gridDim, blockDim, 0, stream>>>
  1812. (
  1813. q_weight,
  1814. new_qweight,
  1815. q_perm,
  1816. height / 32 * bit,
  1817. width
  1818. );
  1819. // Replace qweights
  1820. cudaMemcpyAsync(q_weight, new_qweight, num_experts * height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
  1821. // Cleanup
  1822. cudaDeviceSynchronize();
  1823. cudaFree(new_qweight);
  1824. }
  1825. dim3 blockDim, gridDim;
  1826. blockDim.x = THREADS_X;
  1827. blockDim.y = 1;
  1828. gridDim.x = DIVIDE(width, THREADS_X);
  1829. gridDim.y = 1;
  1830. auto shuffle_kernel = shuffle_4bit_kernel;
  1831. if (bit == 2) {
  1832. shuffle_kernel = shuffle_2bit_kernel;
  1833. } else if (bit == 3) {
  1834. shuffle_kernel = shuffle_3bit_kernel;
  1835. } else if (bit == 8) {
  1836. shuffle_kernel = shuffle_8bit_kernel;
  1837. }
  1838. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  1839. shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height * num_experts, width);
  1840. }
// Grouped (mixture-of-experts) 4-bit GPTQ matmul, exllama layout.
//
// Grid layout:
//   blockIdx.x : tile of BLOCK_KN_SIZE * 4 output columns (4 per thread)
//   blockIdx.y : tile of m_count entries of sorted_token_ids_ptr; the whole
//                tile uses the expert given by expert_ids_ptr[blockIdx.y]
//   blockIdx.z : slice of BLOCK_KN_SIZE values along k
// Each block accumulates a partial product over its k-slice in float and adds
// it to `c` with half2 atomicAdd, so `c` has to be zeroed before the launch
// (the kernel itself never clears it).
//
// a                      : activations, viewed as size_m x size_k; the row read
//                          for a token is token_id / top_k
// b_q_weight/qzeros/scales : per-expert tensors, indexed by expert_id
// c                      : output, viewed as size_m x size_n; the row written
//                          for a token is token_id
// b_q_perm               : optional per-expert permutation of k (may be NULL)
// topk_weights           : optional per-token scale applied to the result (may be NULL)
// num_tokens_post_padded : device scalar; tiles starting at or beyond it exit early
// num_valid_tokens       : entries >= this value in sorted_token_ids_ptr end the tile
template <int m_count>
__global__ void group_gemm_half_q_half_gptq_kernel
(
    const half* __restrict__ a,
    const uint32_t* __restrict__ b_q_weight,
    const uint32_t* __restrict__ b_gptq_qzeros,
    const half* __restrict__ b_gptq_scales,
    half* __restrict__ c,
    const int size_m,
    const int size_n,
    const int size_k,
    const int groups,
    const int* __restrict__ b_q_perm,
    const float* __restrict__ topk_weights,
    const int* __restrict__ sorted_token_ids_ptr,
    const int* __restrict__ expert_ids_ptr,
    const int* __restrict__ num_tokens_post_padded,
    const int num_valid_tokens,
    const int top_k
)
{
    // Skip tiles that lie entirely in the padding region.
    int num_tokens = *num_tokens_post_padded;
    int offset_m = blockIdx.y * m_count;
    if (offset_m >= num_tokens) return;
    // Select this tile's expert and advance the per-expert tensors to it.
    // The "/ 8" reflects eight 4-bit values per 32-bit word.
    int expert_id = expert_ids_ptr[blockIdx.y];
    b_q_weight = b_q_weight + size_k * size_n / 8 * expert_id;
    b_gptq_qzeros = b_gptq_qzeros + groups * size_n / 8 * expert_id;
    b_gptq_scales = b_gptq_scales + groups * size_n * expert_id;
    if (b_q_perm) b_q_perm = b_q_perm + size_k * expert_id;
    MatrixView_half a_(a, size_m, size_k);
    MatrixView_half_rw c_(c, size_m, size_n);
    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
    int t = threadIdx.x;
    // Block
    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
    int offset_k = blockIdx.z * BLOCK_KN_SIZE;
    // NOTE(review): end_n and end_m are computed but not used below.
    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
    int end_m = min(offset_m + m_count, size_m);
    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
    // First of the four consecutive output columns owned by this thread.
    int n = offset_n + t * 4;
    // Preload block_a
    __shared__ half block_a[m_count][BLOCK_KN_SIZE];
    // Collect the token ids of this tile; the first id that is not a valid
    // token terminates the tile and fixes valid_count.
    int token_a[m_count];
    int valid_count = m_count;
    for (int m = 0; m < m_count; ++m) {
        int token_id = sorted_token_ids_ptr[offset_m + m];
        if (token_id >= num_valid_tokens) {
            valid_count = m;
            break;
        }
        token_a[m] = token_id;
    }
    // Each thread stages one k-element per valid token into shared memory,
    // applying the k permutation when one is supplied.
    if (offset_k + t < end_k)
    {
        for (int m = 0; m < valid_count; ++m)
        {
            const half* a_ptr = a_.item_ptr(token_a[m] / top_k, 0);
            half* block_a_ptr = block_a[m];
            half a0;
            if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
            else a0 = a_ptr[offset_k + t];
            block_a_ptr[t] = a0;
        }
    }
    // Threads whose columns lie outside the matrix leave here, after having
    // contributed to the shared-memory preload but before the barrier.
    // (Nothing is zeroed at this point despite the historical "Zero output" label.)
    if (n >= size_n) return;
    __syncthreads();
    // Find initial group
    int groupsize = size_k / groups;
    int group = offset_k / groupsize;
    int nextgroup = offset_k + groupsize;
    // a, b offset
    int qk = offset_k / (32 / 4);
    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
    const half* a_ptr = &block_a[0][0];
    int a_stride = BLOCK_KN_SIZE;
    // Initial group
    int zeros[4];
    float scales[4];
    half2 z1z16[4][2];
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_f(scales, group, n);
    // The stored zero point is incremented by one before being used.
    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
    // Column result
    float block_c[m_count][4] = {};
    // Dequantize and multiply
    int k = offset_k;
    while (k < end_k)
    {
        // Reload zeros/scales whenever k crosses into the next quantization group.
        if (k == nextgroup)
        {
            group++;
            nextgroup += groupsize;
            b_gptq_qzeros_.item4(zeros, group, n);
            b_gptq_scales_.item4_f(scales, group, n);
            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
        }
        // Four packed rows per outer iteration, 8 k-values each (32 in total).
        #pragma unroll
        for (int j = 0; j < 4; j++)
        {
            // One 128-bit load fetches the packed words of all four columns.
            const int4* b_ptr4 = (int4*) b_ptr;
            int4 load_int4 = *b_ptr4;
            half2 dq[4][4];
            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
            for (int m = 0; m < valid_count; m++)
            {
                // acc += dot(dequantized weights, activations) * scale
                block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
                block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
                block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
                block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
            }
            b_ptr += size_n;
            a_ptr += 8;
        }
        k += 32;
    }
    // Scale by the routing weight (when requested) and accumulate this
    // k-slice's partial sums into the output row of each token.
    for (int m = 0; m < valid_count; m++)
    {
        if (topk_weights) {
            #pragma unroll
            for (int j = 0; j < 4; ++j) {
                block_c[m][j] = block_c[m][j] * topk_weights[token_a[m]];
            }
        }
        half2 *out = (half2*) c_.item_ptr(token_a[m], n);
        half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
        half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
        // Atomic because blocks with different blockIdx.z add into the same cells.
        atomicAdd(out , result01);
        atomicAdd(out + 1, result23);
    }
}
  1983. void group_gemm_half_q_half
  1984. (
  1985. const half* a,
  1986. const uint32_t* b_q_weight,
  1987. const uint32_t* b_gptq_qzeros,
  1988. const half* b_gptq_scales,
  1989. const int* b_q_perm,
  1990. half* c,
  1991. const float* __restrict__ topk_weights,
  1992. const int* __restrict__ sorted_token_ids_ptr,
  1993. const int* __restrict__ expert_ids_ptr,
  1994. const int* __restrict__ num_tokens_post_padded,
  1995. const int num_valid_tokens,
  1996. const int top_k,
  1997. int size_m,
  1998. int size_n,
  1999. int size_k,
  2000. int pad_size_m,
  2001. int groups
  2002. )
  2003. {
  2004. dim3 blockDim, gridDim;
  2005. blockDim.x = BLOCK_KN_SIZE;
  2006. blockDim.y = 1;
  2007. blockDim.z = 1;
  2008. gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
  2009. gridDim.y = DIVIDE(pad_size_m, BLOCK_M_SIZE_MAX);
  2010. gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
  2011. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  2012. group_gemm_half_q_half_gptq_kernel<BLOCK_M_SIZE_MAX><<<gridDim, blockDim, 0, stream>>>
  2013. (
  2014. a,
  2015. b_q_weight,
  2016. b_gptq_qzeros,
  2017. b_gptq_scales,
  2018. c,
  2019. size_m,
  2020. size_n,
  2021. size_k,
  2022. groups,
  2023. b_q_perm,
  2024. topk_weights,
  2025. sorted_token_ids_ptr,
  2026. expert_ids_ptr,
  2027. num_tokens_post_padded,
  2028. num_valid_tokens,
  2029. top_k
  2030. );
  2031. }
// Grouped (mixture-of-experts) 4-bit GPTQ matmul, "alt" (non-exllama) layout.
//
// Grid layout:
//   blockIdx.x : tile of BLOCK_KN_SIZE output columns (one per thread)
//   blockIdx.y : tile of BLOCK_M_SIZE_MAX entries of sorted_token_ids_ptr; the
//                whole tile uses the expert given by expert_ids_ptr[blockIdx.y]
//   blockIdx.z : slice of BLOCK_KN_SIZE values along k
// Each block adds its partial result to `mul` with atomicAdd, so `mul` has to
// be zeroed before the launch.
//
// vec     : activations viewed as half2; the row read for a token is token_id / top_k
// mat     : packed 4-bit weights, `height` packed rows x `width` columns per expert
// mul     : output, the row written for a token is token_id
// scales  : per-group scales, `groups` x `width` per expert
// zeros   : packed 4-bit zero points, `groups` x `width` / 8 words per expert
// g_idx   : group index of every unpacked k row, height * 8 entries per expert
// batch   : not referenced in the body
// height  : number of packed weight rows
// width   : number of output columns
// NOTE(review): g_idx is dereferenced unconditionally, and `w` is never
// checked against `width` - presumably callers guarantee a non-null g_idx and
// a width that is a multiple of BLOCK_KN_SIZE; confirm.
__global__ void group_gemm_half_q_half_alt_kernel(
    const half2* __restrict__ vec,
    const uint32_t* __restrict__ mat,
    half* __restrict__ mul,
    const half* __restrict__ scales,
    const uint32_t* __restrict__ zeros,
    const int* __restrict__ g_idx,
    int batch,
    int height,
    int width,
    int groups,
    const float* __restrict__ topk_weights,
    const int* __restrict__ sorted_token_ids_ptr,
    const int* __restrict__ expert_ids_ptr,
    const int* __restrict__ num_tokens_post_padded,
    const int num_valid_tokens,
    const int top_k
)
{
    // Skip tiles that lie entirely in the padding region.
    int num_tokens = *num_tokens_post_padded;
    int b = blockIdx.y * BLOCK_M_SIZE_MAX;
    if (b >= num_tokens) return;
    // Select this tile's expert and advance the per-expert tensors to it.
    int expert_id = expert_ids_ptr[blockIdx.y];
    mat = mat + height * width * expert_id;
    scales = scales + groups * width * expert_id;
    zeros = zeros + groups * width / 8 * expert_id;
    g_idx = g_idx + height * 8 * expert_id;
    int zero_width = width / 8;    // packed zero-point words per group row
    int vec_height = height * 4;   // half2 elements per activation row
    const int blockwidth2 = BLOCK_KN_SIZE / 2;
    int b_end = BLOCK_M_SIZE_MAX;
    // First packed weight row of this k-slice, and the number of half2
    // activation elements the slice covers (clamped at the end of the matrix).
    int h = BLOCK_KN_SIZE * blockIdx.z / 8;
    int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
    int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
    // Collect the token ids of this tile; the first id that is not a valid
    // token terminates the tile and shrinks b_end.
    int token_a[BLOCK_M_SIZE_MAX];
    for (int m = 0; m < b_end; ++m) {
        int token_id = sorted_token_ids_ptr[b + m];
        if (token_id >= num_valid_tokens) {
            b_end = m;
            break;
        }
        token_a[m] = token_id;
    }
    // Stage this k-slice of every valid token's activations in shared memory.
    __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
    if (threadIdx.x < h_end) {
        for (int m = 0; m < b_end; ++m) {
            blockvec[m][threadIdx.x] =
                vec[token_a[m] / top_k * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
                    threadIdx.x];
        }
    }
    // Lookup table: for every byte value, the half2 (low nibble, high nibble).
    // The same entry is stored in all 8 `off` slots; each thread later reads
    // the slot given by threadIdx.x % 8.
    __shared__ half2 deq2[256][8];
    int val = threadIdx.x / 8;
    int off = threadIdx.x % 8;
    for (; val < 256; val += BLOCK_KN_SIZE / 8) {
        deq2[val][off] = __halves2half2(
            __int2half_rn(val & 0xF), __int2half_rn(val >> 4)
        );
    }
    __syncthreads();
    int i = width * h + w;         // index of the current packed weight word
    int g_h = h * 8;               // first unpacked k row of this slice
    int k = 0;                     // position in blockvec, in half2 units
    int z_w = w / 8;               // zero-point word holding column w
    int z_mod = (w % 8) * 4;       // bit offset of column w inside that word
    half2 res2;
    half res[BLOCK_M_SIZE_MAX] = {};
    unsigned int tmp;
    while (k < h_end) {
        // One packed word = 8 weights = 4 (low, high) nibble pairs.
        tmp = mat[i];
        half2 scales_tmp[4];
        half2 zeros_tmp[4];
        for (int tmp_k = 0; tmp_k < 4; tmp_k++) {
            // Each of the two weights in a pair may belong to a different group.
            int g = g_idx[g_h + (k + tmp_k) * 2];
            int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
            half scale_f = scales[g * width + w];
            half scale_f2 = scales[g2 * width + w];
            half2 scale = __halves2half2(scale_f, scale_f2);
            // Pre-multiplied offset: scale * -(zero + 1), so that
            // q * scale + offset == (q - zero - 1) * scale.
            half2 zero = __halves2half2(
                __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)),
                __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))
            );
            scales_tmp[tmp_k] = scale;
            zeros_tmp[tmp_k] = zero;
        }
        for (int m = 0; m < b_end; m++) {
#ifndef USE_ROCM
            res2 = {};
#else
            res2.x = __half_as_ushort(__float2half(0));
            res2.y = __half_as_ushort(__float2half(0));
#endif
            // Dequantize each byte through the lookup table and accumulate
            // weight * activation for the four half2 pairs of this word.
            res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
            // Fold the two half2 lanes into the running per-token sum.
#ifndef USE_ROCM
            res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
#else
            res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
#endif
        }
        i += width;
        k += 4;
    }
    // Scale by the routing weight (when requested) and accumulate this
    // k-slice's partial sum into the output.
    for (int m = 0; m < b_end; m++) {
        if (topk_weights) {
            res[m] = __float2half(__half2float(res[m]) * topk_weights[token_a[m]]);
        }
        // Atomic because blocks with different blockIdx.z add into the same cell.
        atomicAdd(&mul[token_a[m] * width + w], res[m]);
    }
}
  2144. void group_gemm_half_q_half_alt
  2145. (
  2146. const half* a,
  2147. const uint32_t* b_q_weight,
  2148. const uint32_t* b_gptq_qzeros,
  2149. const half* b_gptq_scales,
  2150. const int* b_g_idx,
  2151. half* c,
  2152. const float* __restrict__ topk_weights,
  2153. const int* __restrict__ sorted_token_ids_ptr,
  2154. const int* __restrict__ expert_ids_ptr,
  2155. const int* __restrict__ num_tokens_post_padded,
  2156. const int num_valid_tokens,
  2157. const int top_k,
  2158. int size_m,
  2159. int size_n,
  2160. int size_k,
  2161. int pad_size_m,
  2162. int groups
  2163. )
  2164. {
  2165. dim3 blockDim, gridDim;
  2166. blockDim.x = BLOCK_KN_SIZE;
  2167. blockDim.y = 1;
  2168. blockDim.z = 1;
  2169. gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
  2170. gridDim.y = DIVIDE(pad_size_m, BLOCK_M_SIZE_MAX);
  2171. gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
  2172. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  2173. group_gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
  2174. (
  2175. (const half2*) a,
  2176. b_q_weight,
  2177. c,
  2178. b_gptq_scales,
  2179. b_gptq_qzeros,
  2180. b_g_idx,
  2181. size_m,
  2182. size_k / 8,
  2183. size_n,
  2184. groups,
  2185. topk_weights,
  2186. sorted_token_ids_ptr,
  2187. expert_ids_ptr,
  2188. num_tokens_post_padded,
  2189. num_valid_tokens,
  2190. top_k
  2191. );
  2192. }
  2193. // Only support 4-bit so far
  2194. void group_gemm_half_q_half_cuda
  2195. (
  2196. const half* a,
  2197. const uint32_t* b_q_weight,
  2198. const uint32_t* b_gptq_qzeros,
  2199. const half* b_gptq_scales,
  2200. const int* b_g_idx,
  2201. half* c,
  2202. const float* __restrict__ topk_weights,
  2203. const int* __restrict__ sorted_token_ids_ptr,
  2204. const int* __restrict__ expert_ids_ptr,
  2205. const int* __restrict__ num_tokens_post_padded,
  2206. const int num_valid_tokens,
  2207. const int top_k,
  2208. int size_m,
  2209. int size_n,
  2210. int size_k,
  2211. int pad_size_m,
  2212. int groups,
  2213. bool use_exllama
  2214. ) {
  2215. if (use_exllama) {
  2216. group_gemm_half_q_half(
  2217. a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c,
  2218. topk_weights, sorted_token_ids_ptr, expert_ids_ptr,
  2219. num_tokens_post_padded, num_valid_tokens,
  2220. top_k, size_m, size_n, size_k, pad_size_m, groups
  2221. );
  2222. } else {
  2223. group_gemm_half_q_half_alt(
  2224. a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c,
  2225. topk_weights, sorted_token_ids_ptr, expert_ids_ptr,
  2226. num_tokens_post_padded, num_valid_tokens,
  2227. top_k, size_m, size_n, size_k, pad_size_m, groups
  2228. );
  2229. }
  2230. }
  2231. } // namespace gptq
  2232. } // namespace aphrodite
  2233. torch::Tensor gptq_gemm
  2234. (
  2235. torch::Tensor a,
  2236. torch::Tensor b_q_weight,
  2237. torch::Tensor b_gptq_qzeros,
  2238. torch::Tensor b_gptq_scales,
  2239. torch::Tensor b_g_idx,
  2240. bool use_exllama,
  2241. int bit
  2242. )
  2243. {
  2244. const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  2245. auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  2246. at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
  2247. at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
  2248. aphrodite::gptq::gemm_half_q_half_cuda
  2249. (
  2250. at::cuda::getCurrentCUDABlasHandle(),
  2251. (const half*) a.data_ptr(),
  2252. (const uint32_t*) b_q_weight.data_ptr(),
  2253. (const uint32_t*)b_gptq_qzeros.data_ptr(),
  2254. (const half*) b_gptq_scales.data_ptr(),
  2255. b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
  2256. (half*) c.data_ptr(),
  2257. (half*) temp_dq.data_ptr(),
  2258. c.size(0), // m
  2259. c.size(1), // n
  2260. a.size(1), // k
  2261. b_gptq_qzeros.size(0), // group number
  2262. use_exllama,
  2263. bit
  2264. );
  2265. return c;
  2266. }
  2267. void gptq_shuffle
  2268. (
  2269. torch::Tensor q_weight,
  2270. torch::Tensor q_perm,
  2271. int bit
  2272. )
  2273. {
  2274. const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
  2275. int num_experts = q_weight.dim() == 3 ? q_weight.size(0) : 1;
  2276. int size_k = q_weight.dim() == 3 ? q_weight.size(1) * 32 / bit : q_weight.size(0) * 32 / bit;
  2277. int size_n = q_weight.dim() == 3 ? q_weight.size(2) : q_weight.size(1);
  2278. aphrodite::gptq::shuffle_exllama_weight(
  2279. (uint32_t*) q_weight.data_ptr(),
  2280. q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
  2281. size_k,
  2282. size_n,
  2283. num_experts,
  2284. bit
  2285. );
  2286. }
  2287. // Only support 4-bit
  2288. // todo: extend support to other bits
  2289. torch::Tensor group_gptq_gemm
  2290. (
  2291. torch::Tensor a,
  2292. torch::Tensor b_q_weight,
  2293. torch::Tensor b_gptq_qzeros,
  2294. torch::Tensor b_gptq_scales,
  2295. torch::Tensor b_g_idx,
  2296. torch::Tensor topk_weights,
  2297. torch::Tensor sorted_token_ids_ptr,
  2298. torch::Tensor expert_ids_ptr,
  2299. torch::Tensor num_tokens_post_padded,
  2300. bool mul_weights,
  2301. bool use_exllama
  2302. )
  2303. {
  2304. const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  2305. auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  2306. at::Tensor c = torch::zeros({a.size(0), topk_weights.size(1), b_q_weight.size(2)}, options);
  2307. aphrodite::gptq::group_gemm_half_q_half_cuda
  2308. (
  2309. (const half*) a.data_ptr(),
  2310. (const uint32_t*) b_q_weight.data_ptr(),
  2311. (const uint32_t*)b_gptq_qzeros.data_ptr(),
  2312. (const half*) b_gptq_scales.data_ptr(),
  2313. b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
  2314. (half*) c.data_ptr(),
  2315. mul_weights ? (const float*) topk_weights.data_ptr() : NULL,
  2316. (const int*) sorted_token_ids_ptr.data_ptr(),
  2317. (const int*) expert_ids_ptr.data_ptr(),
  2318. (const int*) num_tokens_post_padded.data_ptr(),
  2319. topk_weights.numel(), // num tokens
  2320. topk_weights.size(1) / a.size(1), // top_k
  2321. a.size(0) * a.size(1), // m
  2322. c.size(2), // n
  2323. a.size(2), // k
  2324. sorted_token_ids_ptr.size(0),
  2325. b_gptq_qzeros.size(1), // group number
  2326. use_exllama
  2327. );
  2328. return c;
  2329. }
  2330. torch::Tensor dequant_gptq
  2331. (
  2332. torch::Tensor b_q_weight,
  2333. torch::Tensor b_gptq_qzeros,
  2334. torch::Tensor b_gptq_scales,
  2335. torch::Tensor b_g_idx,
  2336. int bits,
  2337. bool use_exllama
  2338. ) {
  2339. const at::cuda::OptionalCUDAGuard device_guard(device_of(b_gptq_scales));
  2340. auto options = torch::TensorOptions().dtype(b_gptq_scales.dtype()).device(b_gptq_scales.device());
  2341. at::Tensor temp_dq;
  2342. int num_experts;
  2343. int size_k;
  2344. int size_n;
  2345. int groups;
  2346. // moe
  2347. if (b_q_weight.dim() == 3) {
  2348. temp_dq = torch::empty({b_q_weight.size(0), b_q_weight.size(1) * 32 / bits, b_q_weight.size(2)}, options);
  2349. num_experts = b_q_weight.size(0);
  2350. size_k = b_q_weight.size(1) * 32 / bits;
  2351. size_n = b_q_weight.size(2);
  2352. groups = b_gptq_scales.size(1);
  2353. } else
  2354. {
  2355. temp_dq = torch::empty({b_q_weight.size(0) * 32 / bits, b_q_weight.size(1)}, options);
  2356. num_experts = 1;
  2357. size_k = b_q_weight.size(0) * 32 / bits;
  2358. size_n = b_q_weight.size(1);
  2359. groups = b_gptq_scales.size(0);
  2360. }
  2361. aphrodite::gptq::dequant_gptq_cuda(
  2362. (const uint32_t*) b_q_weight.data_ptr(),
  2363. (const uint32_t*)b_gptq_qzeros.data_ptr(),
  2364. (const half*) b_gptq_scales.data_ptr(),
  2365. b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
  2366. (half*) temp_dq.data_ptr(),
  2367. size_k, size_n, groups,
  2368. num_experts, bits, use_exllama);
  2369. return temp_dq;
  2370. }