- /*
- Adapted from https://github.com/turboderp/exllamav2 and
- https://github.com/qwopqwop200/GPTQ-for-LLaMa
- */
- #include <cstdint>
- #include <cstdio>
- #include <torch/all.h>
- #include <c10/cuda/CUDAGuard.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <cuda_runtime.h>
- #include <cuda_fp16.h>
- #include "compat.cuh"
- #include "matrix_view.cuh"
- #include "qdq_2.cuh"
- #include "qdq_3.cuh"
- #include "qdq_4.cuh"
- #include "qdq_8.cuh"
- namespace aphrodite {
- namespace gptq {
- #define BLOCK_KN_SIZE 128
- #define BLOCK_M_SIZE_MAX 8
- #define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
- #define MAX_Q_GEMM_ROWS 50
- #define MAX_Q_GEMM_ROWS_8BIT 24
- #define MAX_ALT_GEMM_ROWS 8
- #define THREADS_X 32
- #define THREADS_Y 32
- #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
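- // Tiling parameters: each thread block runs BLOCK_KN_SIZE threads and each
- // thread produces four output columns, so one block covers BLOCK_KN_SIZE * 4
- // columns of N and a BLOCK_KN_SIZE slice of K. BLOCK_M_SIZE_MAX caps how many
- // rows of A are handled per kernel launch, and the MAX_*_ROWS thresholds
- // decide when to fall back to reconstruction + cuBLAS. DIVIDE is ceiling
- // division, e.g. DIVIDE(4097, 128) == 33.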
- #if defined(USE_ROCM)
- #include <hipblas/hipblas.h>
- __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(
- hipblasHandle_t handle, hipblasOperation_t transA,
- hipblasOperation_t transB, int m, int n, int k, const half* alpha,
- const half* AP, int lda, const half* BP, int ldb, const half* beta,
- half* CP, int ldc) {
- return hipblasHgemm(handle, transA, transB, m, n, k,
- reinterpret_cast<const hipblasHalf*>(alpha),
- reinterpret_cast<const hipblasHalf*>(AP), lda,
- reinterpret_cast<const hipblasHalf*>(BP), ldb,
- reinterpret_cast<const hipblasHalf*>(beta),
- reinterpret_cast<hipblasHalf*>(CP), ldc);
- }
- #define hipblasHgemm __compat_hipblasHgemm
- // Previous versions of PyTorch converted to rocBLAS instead of hipBLAS.
- #define rocblas_operation_none HIPBLAS_OP_N
- #define rocblas_hgemm __compat_hipblasHgemm
- #endif
- __forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr,
- const half2 g_result) {
- half2 result = {};
- const half2* a2_ptr = (const half2*)a_ptr;
- #pragma unroll
- for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
- return __hadd2(result, g_result);
- }
- __forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr) {
- half2 result = {};
- const half2* a2_ptr = (const half2*)a_ptr;
- #pragma unroll
- for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
- return __half2float(__low2half(result)) + __half2float(__high2half(result));
- }
- __forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr,
- const half2 g_result,
- const half qs_h) {
- half2 result = {};
- const half2* a2_ptr = (const half2*)a_ptr;
- #pragma unroll
- for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
- return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
- }
- __forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half* a_ptr,
- const half2 g_result,
- const half qs_h) {
- half2 result = {};
- const half2* a2_ptr = (const half2*)a_ptr;
- #pragma unroll
- for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
- return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
- }
- __forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half* a_ptr,
- const half2 g_result,
- const half qs_h) {
- half2 result = {};
- const half2* a2_ptr = (const half2*)a_ptr;
- #pragma unroll
- for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
- return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
- }
- __forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr,
- const float g_result,
- const float qs_f) {
- half2 result = {};
- const half2* a2_ptr = (const half2*)a_ptr;
- #pragma unroll
- for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
- float result_f =
- __half2float(__low2half(result)) + __half2float(__high2half(result));
- return fma(result_f, qs_f, g_result);
- }
- __forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half* a_ptr,
- const float g_result,
- const float qs_f) {
- half2 result = {};
- const half2* a2_ptr = (const half2*)a_ptr;
- #pragma unroll
- for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
- float result_f =
- __half2float(__low2half(result)) + __half2float(__high2half(result));
- return fma(result_f, qs_f, g_result);
- }
- __forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half* a_ptr,
- const float g_result,
- const float qs_f) {
- half2 result = {};
- const half2* a2_ptr = (const half2*)a_ptr;
- #pragma unroll
- for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
- float result_f =
- __half2float(__low2half(result)) + __half2float(__high2half(result));
- return fma(result_f, qs_f, g_result);
- }
- __forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr,
- const half g_result,
- const half qs_h) {
- // Use FP32 accumulator to avoid potential overflow since unscaled weights are
- // in the range -128..127
- float result = {};
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- half2 w01 = dq[i];
- float w0 = __low2float(w01);
- float w1 = __high2float(w01);
- float x0 = __half2float(*a_ptr++);
- float x1 = __half2float(*a_ptr++);
- result = fma(w0, x0, result);
- result = fma(w1, x1, result);
- }
- float qs = __half2float(qs_h);
- result *= qs;
- half result_h = __float2half_rn(result);
- return __hadd(result_h, g_result);
- }
- __forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half* a_ptr,
- const half g_result,
- const half qs_h) {
- half2 result = {};
- const half2* a2_ptr = (const half2*)a_ptr;
- #pragma unroll
- for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
- half result_h = __hadd(__low2half(result), __high2half(result));
- return __hfma(result_h, qs_h, g_result);
- }
- __forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half* a_ptr,
- const half g_result,
- const half qs_h) {
- half2 result = {};
- const half2* a2_ptr = (const half2*)a_ptr;
- #pragma unroll
- for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
- half result_h = __hadd(__low2half(result), __high2half(result));
- return __hfma(result_h, qs_h, g_result);
- }
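- // The dot22_* helpers above all compute a partial dot product between 8, 16
- // or 32 dequantized weights (held as half2 pairs in dq) and the matching
- // activations at a_ptr, i.e. roughly result = g_result + qs * sum_i dq[i]*a[i].
- // Suffix convention: no suffix returns half2, _f accumulates and returns
- // FP32, _h returns a scalar half; dot22_8_h additionally accumulates in FP32
- // because unscaled 8-bit weights span -128..127 and half products can
- // overflow.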
- typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*,
- const uint32_t*, const half*,
- half*, const int, const int,
- const int, const int,
- const int*);
- template <bool first_block, int m_count>
- __global__ void gemm_half_q_half_gptq_4bit_kernel(
- const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight,
- const uint32_t* __restrict__ b_gptq_qzeros,
- const half* __restrict__ b_gptq_scales, half* __restrict__ c,
- const int size_m, const int size_n, const int size_k, const int groups,
- const int* __restrict__ b_q_perm) {
- MatrixView_half a_(a, size_m, size_k);
- MatrixView_half_rw c_(c, size_m, size_n);
- MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
- MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
- int t = threadIdx.x;
- // Block
- int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
- int offset_m = blockIdx.y * m_count;
- int offset_k = blockIdx.z * BLOCK_KN_SIZE;
- int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
- int end_m = min(offset_m + m_count, size_m);
- int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
- int n = offset_n + t * 4;
- // Preload block_a
- __shared__ half block_a[m_count][BLOCK_KN_SIZE];
- if (offset_k + t < end_k) {
- for (int m = 0; m < m_count; ++m) {
- const half* a_ptr = a_.item_ptr(offset_m + m, 0);
- half* block_a_ptr = block_a[m];
- half a0;
- if (b_q_perm)
- a0 = a_ptr[b_q_perm[offset_k + t]];
- else
- a0 = a_ptr[offset_k + t];
- block_a_ptr[t] = a0;
- }
- }
- // Zero output
- if (n >= size_n) return;
- if (blockIdx.z == 0) {
- for (int m = 0; m < m_count; m++)
- *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
- }
- __syncthreads();
- // Find initial group
- int groupsize = size_k / groups;
- int group = offset_k / groupsize;
- int nextgroup = offset_k + groupsize;
- // a, b offset
- int qk = offset_k / (32 / 4);
- const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
- const half* a_ptr = &block_a[0][0];
- int a_stride = BLOCK_KN_SIZE;
- // Initial group
- int zeros[4];
- float scales[4];
- half2 z1z16[4][2];
- half2 y1y16[4][2];
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_f(scales, group, n);
- dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
- dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
- dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
- dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
- // Column result
- float block_c[m_count][4] = {};
- // Dequantize and multiply
- int k = offset_k;
- while (k < end_k) {
- if (k == nextgroup) {
- group++;
- nextgroup += groupsize;
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_f(scales, group, n);
- dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
- dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
- dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
- dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
- }
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- const int4* b_ptr4 = (int4*)b_ptr;
- int4 load_int4 = *b_ptr4;
- half2 dq[4][4];
- dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n,
- false);
- dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n,
- false);
- dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n,
- false);
- dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n,
- false);
- #pragma unroll
- for (int m = 0; m < m_count; m++) {
- block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0],
- block_c[m][0]);
- block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1],
- block_c[m][1]);
- block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2],
- block_c[m][2]);
- block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3],
- block_c[m][3]);
- }
- b_ptr += size_n;
- a_ptr += 8;
- }
- k += 32;
- }
- for (int m = 0; m < m_count; m++) {
- half2* out = (half2*)c_.item_ptr(offset_m + m, n);
- half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]),
- __float2half_rn(block_c[m][1]));
- half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]),
- __float2half_rn(block_c[m][3]));
- atomicAdd(out, result01);
- atomicAdd(out + 1, result23);
- }
- }
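- // Work decomposition for the exllama-style GEMM kernels (this one and the
- // 2/3/8-bit variants below): blockIdx.x picks a tile of BLOCK_KN_SIZE * 4
- // output columns (four per thread), blockIdx.y picks m_count rows of A, and
- // blockIdx.z picks a BLOCK_KN_SIZE slice of K. Every K slice adds its partial
- // sums into C with atomicAdd, which is why the blockIdx.z == 0 pass zeroes
- // the output first. With 4-bit weights one uint32 packs 8 K values, so each
- // int4 load fetches one packed word for each of the thread's four columns and
- // b_ptr advances by size_n per 8 rows of K.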
- template <bool first_block, int m_count>
- __global__ void gemm_half_q_half_gptq_2bit_kernel(
- const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight,
- const uint32_t* __restrict__ b_gptq_qzeros,
- const half* __restrict__ b_gptq_scales, half* __restrict__ c,
- const int size_m, const int size_n, const int size_k, const int groups,
- const int* __restrict__ b_q_perm) {
- MatrixView_half a_(a, size_m, size_k);
- MatrixView_half_rw c_(c, size_m, size_n);
- MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
- MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
- int t = threadIdx.x;
- // Block
- int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
- int offset_m = blockIdx.y * m_count;
- int offset_k = blockIdx.z * BLOCK_KN_SIZE;
- int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
- int end_m = min(offset_m + m_count, size_m);
- int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
- int n = offset_n + t * 4;
- // Preload block_a
- __shared__ half block_a[m_count][BLOCK_KN_SIZE];
- if (offset_k + t < end_k) {
- for (int m = 0; m < m_count; ++m) {
- const half* a_ptr = a_.item_ptr(offset_m + m, 0);
- half* block_a_ptr = block_a[m];
- half a0;
- if (b_q_perm)
- a0 = a_ptr[b_q_perm[offset_k + t]];
- else
- a0 = a_ptr[offset_k + t];
- block_a_ptr[t] = a0;
- }
- }
- // Zero output
- if (n >= size_n) return;
- if (blockIdx.z == 0) {
- for (int m = 0; m < m_count; m++)
- *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
- }
- __syncthreads();
- // Find initial group
- int groupsize = size_k / groups;
- int group = offset_k / groupsize;
- int nextgroup = offset_k + groupsize;
- // a, b offset
- int qk = offset_k / (32 / 2);
- const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
- const half* a_ptr = &block_a[0][0];
- int a_stride = BLOCK_KN_SIZE;
- // Initial group
- int zeros[4];
- half scales[4];
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4(scales, group, n);
- // Column result
- half block_c[m_count][4] = {};
- // Dequantize and multiply
- int k = offset_k;
- while (k < end_k) {
- if (k == nextgroup) {
- group++;
- nextgroup += groupsize;
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4(scales, group, n);
- }
- #pragma unroll
- for (int j = 0; j < 1; j++) {
- const int4* b_ptr4 = (int4*)b_ptr;
- int4 load_int4 = *b_ptr4;
- half2 dq[4][8];
- dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
- dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
- dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
- dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
- #pragma unroll
- for (int m = 0; m < m_count; m++) {
- block_c[m][0] =
- dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
- block_c[m][1] =
- dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
- block_c[m][2] =
- dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
- block_c[m][3] =
- dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
- }
- b_ptr += size_n;
- a_ptr += 16;
- }
- k += 16;
- }
- for (int m = 0; m < m_count; m++) {
- half2* out = (half2*)c_.item_ptr(offset_m + m, n);
- half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
- half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
- atomicAdd(out, result01);
- atomicAdd(out + 1, result23);
- }
- }
- template <bool first_block, int m_count>
- __global__ void gemm_half_q_half_gptq_3bit_kernel(
- const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight,
- const uint32_t* __restrict__ b_gptq_qzeros,
- const half* __restrict__ b_gptq_scales, half* __restrict__ c,
- const int size_m, const int size_n, const int size_k, const int groups,
- const int* __restrict__ b_q_perm) {
- MatrixView_half a_(a, size_m, size_k);
- MatrixView_half_rw c_(c, size_m, size_n);
- MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
- MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
- int t = threadIdx.x;
- // Block
- int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
- int offset_m = blockIdx.y * m_count;
- int offset_k = blockIdx.z * BLOCK_KN_SIZE;
- int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
- int end_m = min(offset_m + m_count, size_m);
- int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
- int n = offset_n + t * 4;
- // Preload block_a
- __shared__ half block_a[m_count][BLOCK_KN_SIZE];
- if (offset_k + t < end_k) {
- for (int m = 0; m < m_count; ++m) {
- const half* a_ptr = a_.item_ptr(offset_m + m, 0);
- half* block_a_ptr = block_a[m];
- half a0;
- if (b_q_perm)
- a0 = a_ptr[b_q_perm[offset_k + t]];
- else
- a0 = a_ptr[offset_k + t];
- block_a_ptr[t] = a0;
- }
- }
- // Zero output
- if (n >= size_n) return;
- if (blockIdx.z == 0) {
- for (int m = 0; m < m_count; m++)
- *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
- }
- __syncthreads();
- // Find initial group
- int groupsize = size_k / groups;
- int group = offset_k / groupsize;
- int nextgroup = offset_k + groupsize;
- // a, b offset
- int qk = offset_k / 32 * 3;
- const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
- const half* a_ptr = &block_a[0][0];
- int a_stride = BLOCK_KN_SIZE;
- // Initial group
- int zeros[4];
- half scales[4];
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4(scales, group, n);
- // Column result
- half block_c[m_count][4] = {};
- // Dequantize and multiply
- int k = offset_k;
- while (k < end_k) {
- if (k == nextgroup) {
- group++;
- nextgroup += groupsize;
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4(scales, group, n);
- }
- #pragma unroll
- for (int j = 0; j < 1; j++) {
- int4 load_int4[3];
- load_int4[0] = *((int4*)b_ptr);
- b_ptr += size_n;
- load_int4[1] = *((int4*)b_ptr);
- b_ptr += size_n;
- load_int4[2] = *((int4*)b_ptr);
- b_ptr += size_n;
- half2 dq[4][16];
- dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
- size_n, zeros[0] + 1);
- dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
- size_n, zeros[1] + 1);
- dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
- size_n, zeros[2] + 1);
- dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
- size_n, zeros[3] + 1);
- #pragma unroll
- for (int m = 0; m < m_count; m++) {
- block_c[m][0] =
- dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
- block_c[m][1] =
- dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
- block_c[m][2] =
- dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
- block_c[m][3] =
- dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
- }
- a_ptr += 32;
- }
- k += 32;
- }
- for (int m = 0; m < m_count; m++) {
- half2* out = (half2*)c_.item_ptr(offset_m + m, n);
- half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
- half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
- atomicAdd(out, result01);
- atomicAdd(out + 1, result23);
- }
- }
- template <bool first_block, int m_count>
- __global__ void gemm_half_q_half_gptq_8bit_kernel(
- const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight,
- const uint32_t* __restrict__ b_gptq_qzeros,
- const half* __restrict__ b_gptq_scales, half* __restrict__ c,
- const int size_m, const int size_n, const int size_k, const int groups,
- const int* __restrict__ b_q_perm) {
- MatrixView_half a_(a, size_m, size_k);
- MatrixView_half_rw c_(c, size_m, size_n);
- MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
- MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
- int t = threadIdx.x;
- // Block
- int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
- int offset_m = blockIdx.y * m_count;
- int offset_k = blockIdx.z * BLOCK_KN_SIZE;
- int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
- int end_m = min(offset_m + m_count, size_m);
- int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
- int n = offset_n + t * 4;
- // Preload block_a
- __shared__ half block_a[m_count][BLOCK_KN_SIZE];
- if (offset_k + t < end_k) {
- for (int m = 0; m < m_count; ++m) {
- const half* a_ptr = a_.item_ptr(offset_m + m, 0);
- half* block_a_ptr = block_a[m];
- half a0;
- if (b_q_perm)
- a0 = a_ptr[b_q_perm[offset_k + t]];
- else
- a0 = a_ptr[offset_k + t];
- block_a_ptr[t] = a0;
- }
- }
- // Zero output
- if (n >= size_n) return;
- if (blockIdx.z == 0) {
- for (int m = 0; m < m_count; m++)
- *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
- }
- __syncthreads();
- // Find initial group
- int groupsize = size_k / groups;
- int group = offset_k / groupsize;
- int nextgroup = offset_k + groupsize;
- // a, b offset
- int qk = offset_k / (32 / 8);
- const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
- const half* a_ptr = &block_a[0][0];
- int a_stride = BLOCK_KN_SIZE;
- // Initial group
- int zeros[4];
- half scales[4];
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4(scales, group, n);
- // Column result
- half block_c[m_count][4] = {};
- // Dequantize and multiply
- int k = offset_k;
- while (k < end_k) {
- if (k == nextgroup) {
- group++;
- nextgroup += groupsize;
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4(scales, group, n);
- }
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- int4 load_int4[2];
- load_int4[0] = *((int4*)b_ptr);
- b_ptr += size_n;
- load_int4[1] = *((int4*)b_ptr);
- b_ptr += size_n;
- half2 dq[4][4];
- dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
- zeros[0] + 1);
- dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
- zeros[1] + 1);
- dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
- zeros[2] + 1);
- dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
- zeros[3] + 1);
- for (int m = 0; m < m_count; m++) {
- block_c[m][0] =
- dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
- block_c[m][1] =
- dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
- block_c[m][2] =
- dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
- block_c[m][3] =
- dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
- }
- a_ptr += 8;
- }
- k += 32;
- }
- for (int m = 0; m < m_count; m++) {
- half2* out = (half2*)c_.item_ptr(offset_m + m, n);
- half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
- half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
- atomicAdd(out, result01);
- atomicAdd(out + 1, result23);
- }
- }
- fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(
- bool first_block, const int m_count, const int bit) {
- #define SELECT_KERNEL(M_COUNT) \
- if (m_count == M_COUNT) { \
- if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel<true, M_COUNT>; \
- if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel<true, M_COUNT>; \
- if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel<true, M_COUNT>; \
- if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel<true, M_COUNT>; \
- }
- #if BLOCK_M_SIZE_MAX >= 1
- SELECT_KERNEL(1);
- #endif
- #if BLOCK_M_SIZE_MAX >= 2
- SELECT_KERNEL(2);
- #endif
- #if BLOCK_M_SIZE_MAX >= 3
- SELECT_KERNEL(3);
- #endif
- #if BLOCK_M_SIZE_MAX >= 4
- SELECT_KERNEL(4);
- #endif
- #if BLOCK_M_SIZE_MAX >= 5
- SELECT_KERNEL(5);
- #endif
- #if BLOCK_M_SIZE_MAX >= 6
- SELECT_KERNEL(6);
- #endif
- #if BLOCK_M_SIZE_MAX >= 7
- SELECT_KERNEL(7);
- #endif
- #if BLOCK_M_SIZE_MAX >= 8
- SELECT_KERNEL(8);
- #endif
- return NULL;
- }
- void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
- const uint32_t* b_gptq_qzeros,
- const half* b_gptq_scales, const int* b_q_perm,
- half* c, int size_m, int size_n, int size_k,
- int m_count, int groups, int bit) {
- dim3 blockDim, gridDim;
- blockDim.x = BLOCK_KN_SIZE;
- blockDim.y = 1;
- blockDim.z = 1;
- gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
- gridDim.y = DIVIDE(size_m, m_count);
- gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
- fp_gemm_half_q_half_gptq_kernel kernel =
- pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- kernel<<<gridDim, blockDim, 0, stream>>>(a, b_q_weight, b_gptq_qzeros,
- b_gptq_scales, c, size_m, size_n,
- size_k, groups, b_q_perm);
- }
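- // Launch geometry example (a worked case, not a requirement): for
- // size_m = 8, size_n = 4096, size_k = 4096 and m_count = 8 the grid is
- // DIVIDE(4096, 512) x DIVIDE(8, 8) x DIVIDE(4096, 128) = 8 x 1 x 32 blocks of
- // 128 threads, each block accumulating its K slice into c via atomicAdd.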
- __global__ void reconstruct_exllama_8bit_kernel(
- const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
- const uint32_t* __restrict__ b_gptq_qzeros,
- const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
- const int groups, half* __restrict__ b) {
- MatrixView_half_rw b_(b, size_k, size_n);
- MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
- MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
- int offset_k = BLOCK_KN_SIZE * blockIdx.y;
- int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
- int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
- // Preload remapping table
- __shared__ int perm[BLOCK_KN_SIZE];
- int t = threadIdx.x;
- if (b_q_perm) {
- if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
- }
- // Column
- int n = offset_n + t * 4;
- if (n >= size_n) return;
- // Find initial group
- int groupsize = size_k / groups;
- int group = offset_k / groupsize;
- int nextgroup = offset_k + groupsize;
- // b offset
- int qk = offset_k / (32 / 8);
- const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
- // Initial zeros/scale
- int zeros[4];
- half2 scales[4];
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_h2(scales, group, n);
- __syncthreads();
- int k = offset_k;
- int lk = 0;
- while (k < end_k) {
- if (k == nextgroup) {
- group++;
- nextgroup += groupsize;
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_h2(scales, group, n);
- }
- for (int p = 0; p < 4; p++) {
- int4 load_int4[2];
- load_int4[0] = *((int4*)b_ptr);
- b_ptr += size_n;
- load_int4[1] = *((int4*)b_ptr);
- b_ptr += size_n;
- half2 dq[4][4];
- dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
- zeros[0] + 1);
- dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
- zeros[1] + 1);
- dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
- zeros[2] + 1);
- dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
- zeros[3] + 1);
- // half* dqh = (half*)dq;
- if (b_q_perm) {
- for (int j = 0; j < 4; j++) {
- for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
- b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]),
- __low2half(dq[2][j]), __low2half(dq[3][j]));
- b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]),
- __high2half(dq[2][j]), __high2half(dq[3][j]));
- }
- } else {
- for (int j = 0; j < 4; j++) {
- for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
- b_.set4(offset_k + lk++, n, __low2half(dq[0][j]),
- __low2half(dq[1][j]), __low2half(dq[2][j]),
- __low2half(dq[3][j]));
- b_.set4(offset_k + lk++, n, __high2half(dq[0][j]),
- __high2half(dq[1][j]), __high2half(dq[2][j]),
- __high2half(dq[3][j]));
- }
- }
- }
- k += 32;
- }
- }
- __global__ void reconstruct_exllama_4bit_kernel(
- const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
- const uint32_t* __restrict__ b_gptq_qzeros,
- const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
- const int groups, half* __restrict__ b) {
- if (blockIdx.z > 0) {
- b_q_weight = b_q_weight + blockIdx.z * size_k * size_n / 8;
- b_gptq_scales = b_gptq_scales + blockIdx.z * groups * size_n;
- b_gptq_qzeros = b_gptq_qzeros + blockIdx.z * groups * size_n / 8;
- if (b_q_perm) b_q_perm = b_q_perm + blockIdx.z * size_k;
- b = b + blockIdx.z * size_k * size_n;
- }
- MatrixView_half_rw b_(b, size_k, size_n);
- MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
- MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
- int offset_k = BLOCK_KN_SIZE * blockIdx.y;
- int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
- int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
- // Preload remapping table
- __shared__ int perm[BLOCK_KN_SIZE];
- int t = threadIdx.x;
- if (b_q_perm) {
- if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
- }
- // Column
- int n = offset_n + t * 4;
- if (n >= size_n) return;
- // Find initial group
- int groupsize = size_k / groups;
- int group = offset_k / groupsize;
- int nextgroup = offset_k + groupsize;
- // b offset
- int qk = offset_k / (32 / 4);
- const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
- // Initial zeros/scale
- int zeros[4];
- half2 scales[4];
- half2 z1z16[4][2];
- half2 y1y16[4][2];
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_h2(scales, group, n);
- dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
- dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
- dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
- dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
- __syncthreads();
- int k = offset_k;
- int lk = 0;
- while (k < end_k) {
- if (k == nextgroup) {
- group++;
- nextgroup += groupsize;
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_h2(scales, group, n);
- dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
- dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
- dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
- dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
- }
- for (int p = 0; p < 4; p++) {
- half2 dq[4][4];
- const int4* b_ptr4 = (int4*)b_ptr;
- int4 load_int4 = *b_ptr4;
- dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n,
- false);
- dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n,
- false);
- dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n,
- false);
- dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n,
- false);
- b_ptr += size_n;
- // half* dqh = (half*)dq;
- if (b_q_perm) {
- for (int j = 0; j < 4; j++) {
- for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
- b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]),
- __low2half(dq[2][j]), __low2half(dq[3][j]));
- b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]),
- __high2half(dq[2][j]), __high2half(dq[3][j]));
- }
- } else {
- for (int j = 0; j < 4; j++) {
- for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
- b_.set4(offset_k + lk++, n, __low2half(dq[0][j]),
- __low2half(dq[1][j]), __low2half(dq[2][j]),
- __low2half(dq[3][j]));
- b_.set4(offset_k + lk++, n, __high2half(dq[0][j]),
- __high2half(dq[1][j]), __high2half(dq[2][j]),
- __high2half(dq[3][j]));
- }
- }
- }
- k += 32;
- }
- }
- __global__ void reconstruct_exllama_3bit_kernel(
- const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
- const uint32_t* __restrict__ b_gptq_qzeros,
- const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
- const int groups, half* __restrict__ b) {
- MatrixView_half_rw b_(b, size_k, size_n);
- MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
- MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
- int offset_k = BLOCK_KN_SIZE * blockIdx.y;
- int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
- int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
- // Preload remapping table
- __shared__ int perm[BLOCK_KN_SIZE];
- int t = threadIdx.x;
- if (b_q_perm) {
- if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
- }
- // Column
- int n = offset_n + t * 4;
- if (n >= size_n) return;
- // Find initial group
- int groupsize = size_k / groups;
- int group = offset_k / groupsize;
- int nextgroup = offset_k + groupsize;
- // b offset
- int qk = offset_k / 32 * 3;
- const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
- // Initial zeros/scale
- int zeros[4];
- half2 scales[4];
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_h2(scales, group, n);
- __syncthreads();
- int k = offset_k;
- int lk = 0;
- while (k < end_k) {
- if (k == nextgroup) {
- group++;
- nextgroup += groupsize;
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_h2(scales, group, n);
- }
- for (int p = 0; p < 1; p++) {
- int4 load_int4[3];
- load_int4[0] = *((int4*)b_ptr);
- b_ptr += size_n;
- load_int4[1] = *((int4*)b_ptr);
- b_ptr += size_n;
- load_int4[2] = *((int4*)b_ptr);
- b_ptr += size_n;
- half2 dq[4][16];
- dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
- size_n, zeros[0] + 1);
- dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
- size_n, zeros[1] + 1);
- dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
- size_n, zeros[2] + 1);
- dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
- size_n, zeros[3] + 1);
- if (b_q_perm) {
- for (int j = 0; j < 16; j++) {
- for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
- b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]),
- __low2half(dq[2][j]), __low2half(dq[3][j]));
- b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]),
- __high2half(dq[2][j]), __high2half(dq[3][j]));
- }
- } else {
- for (int j = 0; j < 16; j++) {
- for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
- b_.set4(offset_k + lk++, n, __low2half(dq[0][j]),
- __low2half(dq[1][j]), __low2half(dq[2][j]),
- __low2half(dq[3][j]));
- b_.set4(offset_k + lk++, n, __high2half(dq[0][j]),
- __high2half(dq[1][j]), __high2half(dq[2][j]),
- __high2half(dq[3][j]));
- }
- }
- }
- k += 32;
- }
- }
- __global__ void reconstruct_exllama_2bit_kernel(
- const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
- const uint32_t* __restrict__ b_gptq_qzeros,
- const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
- const int groups, half* __restrict__ b) {
- MatrixView_half_rw b_(b, size_k, size_n);
- MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
- MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
- int offset_k = BLOCK_KN_SIZE * blockIdx.y;
- int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
- int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
- // Preload remapping table
- __shared__ int perm[BLOCK_KN_SIZE];
- int t = threadIdx.x;
- if (b_q_perm) {
- if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
- }
- // Column
- int n = offset_n + t * 4;
- if (n >= size_n) return;
- // Find initial group
- int groupsize = size_k / groups;
- int group = offset_k / groupsize;
- int nextgroup = offset_k + groupsize;
- // b offset
- int qk = offset_k / (32 / 2);
- const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
- // Initial zeros/scale
- int zeros[4];
- half2 scales[4];
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_h2(scales, group, n);
- __syncthreads();
- int k = offset_k;
- int lk = 0;
- while (k < end_k) {
- if (k == nextgroup) {
- group++;
- nextgroup += groupsize;
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_h2(scales, group, n);
- }
- for (int p = 0; p < 2; p++) {
- const int4* b_ptr4 = (int4*)b_ptr;
- int4 load_int4 = *b_ptr4;
- half2 dq[4][8];
- dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
- dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
- dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
- dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
- b_ptr += size_n;
- // half* dqh = (half*)dq;
- if (b_q_perm) {
- for (int j = 0; j < 8; j++) {
- for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
- b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]),
- __low2half(dq[2][j]), __low2half(dq[3][j]));
- b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]),
- __high2half(dq[2][j]), __high2half(dq[3][j]));
- }
- } else {
- for (int j = 0; j < 8; j++) {
- for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
- b_.set4(offset_k + lk++, n, __low2half(dq[0][j]),
- __low2half(dq[1][j]), __low2half(dq[2][j]),
- __low2half(dq[3][j]));
- b_.set4(offset_k + lk++, n, __high2half(dq[0][j]),
- __high2half(dq[1][j]), __high2half(dq[2][j]),
- __high2half(dq[3][j]));
- }
- }
- }
- k += 32;
- }
- }
- void reconstruct_exllama(const uint32_t* b_q_weight,
- const uint32_t* b_gptq_qzeros,
- const half* b_gptq_scales, const int* b_q_perm,
- half* out, int height, int width, int groups,
- int num_experts, int bit) {
- dim3 blockDim, gridDim;
- blockDim.x = BLOCK_KN_SIZE;
- blockDim.y = 1;
- gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
- gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
- gridDim.z = num_experts;
- auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel;
- if (bit == 2) {
- reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel;
- } else if (bit == 3) {
- reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel;
- } else if (bit == 8) {
- reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel;
- }
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>(
- b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups,
- out);
- }
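- // The reconstruct_exllama_* kernels materialize the full FP16 weight matrix
- // from the packed representation, scattering rows through b_q_perm so the
- // result is back in the original K order and can be fed straight to cuBLAS.
- // gridDim.z batches over experts, but only the 4-bit kernel applies the
- // per-expert pointer offsets; the 2/3/8-bit variants assume num_experts == 1.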
- __global__ void gemm_half_q_half_alt_4bit_kernel(
- const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
- half* __restrict__ mul, const half* __restrict__ scales,
- const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
- int batch, int height, int width) {
- int zero_width = width / 8;
- int vec_height = height * 4;
- const int blockwidth2 = BLOCK_KN_SIZE / 2;
- int b = blockIdx.y * BLOCK_M_SIZE_MAX;
- int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
- int h = BLOCK_KN_SIZE * blockIdx.z / 8;
- int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
- int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
- __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
- if (threadIdx.x < h_end) {
- for (int m = 0; m < b_end; ++m) {
- blockvec[m][threadIdx.x] =
- vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
- threadIdx.x];
- }
- }
- __shared__ half2 deq2[256][8];
- int val = threadIdx.x / 8;
- int off = threadIdx.x % 8;
- for (; val < 256; val += BLOCK_KN_SIZE / 8) {
- deq2[val][off] =
- __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4));
- }
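- // deq2 is a 256 x 8 shared lookup table: entry [v][off] holds the two 4-bit
- // nibbles of byte v as a half2 (low nibble, high nibble), so each packed byte
- // of two weights below is dequantized with a single table read. The entry is
- // replicated across the 8 'off' slots, presumably to spread shared-memory
- // bank accesses.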
- if (blockIdx.z == 0) {
- for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0);
- }
- __syncthreads();
- int i = width * h + w;
- int g_h = h * 8;
- int k = 0;
- int z_w = w / 8;
- int z_mod = (w % 8) * 4;
- half2 res2;
- half res[BLOCK_M_SIZE_MAX] = {};
- unsigned int tmp;
- while (k < h_end) {
- tmp = mat[i];
- half2 scales_tmp[4];
- half2 zeros_tmp[4];
- for (int tmp_k = 0; tmp_k < 4; tmp_k++) {
- int g = g_idx[g_h + (k + tmp_k) * 2];
- int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
- half scale_f = scales[g * width + w];
- half scale_f2 = scales[g2 * width + w];
- half2 scale = __halves2half2(scale_f, scale_f2);
- half2 zero = __halves2half2(
- __hmul(scale_f,
- __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) -
- 1)),
- __hmul(scale_f2,
- __int2half_rn(
- -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)));
- scales_tmp[tmp_k] = scale;
- zeros_tmp[tmp_k] = zero;
- }
- for (int m = 0; m < b_end; m++) {
- #ifndef USE_ROCM
- res2 = {};
- #else
- res2.x = __half_as_ushort(__float2half(0));
- res2.y = __half_as_ushort(__float2half(0));
- #endif
- res2 = __hfma2(
- __hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]),
- blockvec[m][k + 0], res2);
- res2 = __hfma2(
- __hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]),
- blockvec[m][k + 1], res2);
- res2 = __hfma2(
- __hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]),
- blockvec[m][k + 2], res2);
- res2 = __hfma2(
- __hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]),
- blockvec[m][k + 3], res2);
- #ifndef USE_ROCM
- res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
- #else
- res[m] = __hadd(
- res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
- #endif
- }
- i += width;
- k += 4;
- }
- for (int m = 0; m < b_end; m++) {
- atomicAdd(&mul[(b + m) * width + w], res[m]);
- }
- }
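- // The _alt kernels implement the older GPTQ-for-LLaMa matmul path: one thread
- // per output column, with the activation block staged in shared memory as
- // half2 pairs. Unlike the exllama kernels they support an arbitrary g_idx
- // (act-order) by looking up the group of every K row, and they fold the
- // (zero + 1) offset into a precomputed per-group term so dequantization is a
- // single __hfma2 per pair of weights.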
- __global__ void gemm_half_q_half_alt_8bit_kernel(
- const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
- half* __restrict__ mul, const half* __restrict__ scales,
- const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
- int batch, int height, int width) {
- int zero_width = width / 4;
- int vec_height = height * 2;
- const int blockwidth2 = BLOCK_KN_SIZE / 2;
- int b = blockIdx.y * BLOCK_M_SIZE_MAX;
- int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
- int h = BLOCK_KN_SIZE * blockIdx.z / 4;
- int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
- int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
- __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
- if (threadIdx.x < h_end) {
- for (int m = 0; m < b_end; ++m) {
- blockvec[m][threadIdx.x] =
- vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
- threadIdx.x];
- }
- }
- if (blockIdx.z == 0) {
- for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0);
- }
- __syncthreads();
- int i = width * h + w;
- int g_h = h * 4;
- int k = 0;
- int z_w = w / 4;
- int z_mod = (w % 4) * 8;
- half2 res2;
- half res[BLOCK_M_SIZE_MAX] = {};
- unsigned int tmp;
- while (k < h_end) {
- tmp = mat[i];
- half2 scales_tmp[2];
- half2 zeros_tmp[2];
- for (int tmp_k = 0; tmp_k < 2; tmp_k++) {
- int g = g_idx[g_h + (k + tmp_k) * 2];
- int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
- half scale_f = scales[g * width + w];
- half scale_f2 = scales[g2 * width + w];
- half2 scale = __halves2half2(scale_f, scale_f2);
- half2 zero = __halves2half2(
- __hmul(scale_f,
- __int2half_rn(
- -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)),
- __hmul(scale_f2,
- __int2half_rn(
- -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1)));
- scales_tmp[tmp_k] = scale;
- zeros_tmp[tmp_k] = zero;
- }
- for (int m = 0; m < b_end; m++) {
- #ifndef USE_ROCM
- res2 = {};
- #else
- res2.x = __half_as_ushort(__float2half(0));
- res2.y = __half_as_ushort(__float2half(0));
- #endif
- half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF),
- __int2half_rn((tmp >> 8) & 0xFF));
- res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]),
- blockvec[m][k + 0], res2);
- half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF),
- __int2half_rn((tmp >> 24) & 0xFF));
- res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]),
- blockvec[m][k + 1], res2);
- #ifndef USE_ROCM
- res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
- #else
- res[m] = __hadd(
- res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
- #endif
- }
- i += width;
- k += 2;
- }
- for (int m = 0; m < b_end; m++) {
- atomicAdd(&mul[(b + m) * width + w], res[m]);
- }
- }
- void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
- const uint32_t* b_gptq_qzeros,
- const half* b_gptq_scales, const int* b_g_idx,
- half* c, int size_m, int size_n, int size_k,
- int bit) {
- dim3 blockDim, gridDim;
- blockDim.x = BLOCK_KN_SIZE;
- blockDim.y = 1;
- blockDim.z = 1;
- gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
- gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
- gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
- auto kernel = gemm_half_q_half_alt_4bit_kernel;
- if (bit == 8) {
- kernel = gemm_half_q_half_alt_8bit_kernel;
- }
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- kernel<<<gridDim, blockDim, 0, stream>>>(
- (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx,
- size_m, size_k / 32 * bit, size_n);
- }
- template <class T, int bit>
- __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
- const half* __restrict__ w_scales,
- const uint32_t* __restrict__ w_zeros,
- const int* __restrict__ g_idx,
- const int height, const int width,
- const int group,
- half* __restrict__ out) {
- if (blockIdx.z > 0) {
- w = w + blockIdx.z * height * width / 8;
- w_scales = w_scales + blockIdx.z * group * width;
- w_zeros = w_zeros + blockIdx.z * group * width / 8;
- g_idx = g_idx + blockIdx.z * height;
- out = out + blockIdx.z * height * width;
- }
- // Start of block
- int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
- int row = blockIdx.y * 32 / bit;
- if (column >= width) return;
- // Views
- MatrixView_half_rw out_(out, height, width);
- MatrixView_half w_scales_(w_scales, group, width);
- T w_zeros_(w_zeros, group, width);
- uint32_t w_read = w[blockIdx.y * width + column];
- half* out_ptr = out_.item_ptr(row, column);
- #pragma unroll
- for (int s = 0; s < 32; s += bit) {
- int group = g_idx[row + s / bit];
- half w_scale = w_scales_.item(group, column);
- uint32_t w_zero = w_zeros_.item(group, column) + 1;
- half w_item =
- __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero),
- w_scale);
- *out_ptr = w_item;
- out_ptr += out_.width;
- }
- }
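- // reconstruct_gptq_kernel covers 2/4/8-bit weights generically: each thread
- // owns one column, reads one packed uint32 (32 / bit consecutive K rows) and
- // writes the dequantized halves, resolving each row's group through g_idx.
- // For bit == 4, for example, the word holds 8 rows and the row at shift s
- // dequantizes as (((w_read >> s) & 0xF) - (zero + 1)) * scale. 3-bit weights
- // need the dedicated kernel below because 32 values straddle three words.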
- __global__ void reconstruct_gptq_3bit_kernel(
- const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
- const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
- const int height, const int width, const int group,
- half* __restrict__ out) {
- // Start of block
- int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
- int row = blockIdx.y * 32;
- if (column >= width) return;
- // Views
- MatrixView_half_rw out_(out, height, width);
- MatrixView_half w_scales_(w_scales, group, width);
- MatrixView_q3_row w_zeros_(w_zeros, group, width);
- uint32_t w1 = w[(blockIdx.y * 3) * width + column];
- uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
- uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
- half* out_ptr = out_.item_ptr(row, column);
- #pragma unroll
- for (int i = 0; i < 32; i += 1) {
- int group = g_idx[row + i];
- half w_scale = w_scales_.item(group, column);
- uint32_t w_zero = w_zeros_.item(group, column) + 1;
- int w_item;
- if (i == 10) {
- w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
- } else if (i == 21) {
- w_item = (w2 >> 31) | ((w3 << 1) & 0x6);
- } else if (i < 10) {
- w_item = ((w1 >> (i * 3)) & 0x7);
- } else if (i < 21) {
- w_item = ((w2 >> (i * 3 - 32)) & 0x7);
- } else {
- w_item = ((w3 >> (i * 3 - 64)) & 0x7);
- }
- *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale);
- out_ptr += out_.width;
- }
- }
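- // 3-bit packing detail: 32 weights occupy exactly three uint32 words
- // (96 bits). Values 0-9 sit in w1, 11-20 in w2 and 22-31 in w3, while value
- // 10 straddles w1/w2 (two low bits from w1, one high bit from w2) and value
- // 21 straddles w2/w3 (one low bit from w2, two high bits from w3); the
- // branches above and in make_sequential_3bit_kernel below reproduce this
- // layout.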
- void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
- const half* b_gptq_scales, const int* b_g_idx, half* out,
- int height, int width, int groups, int num_experts,
- int bit) {
- dim3 blockDim, gridDim;
- blockDim.x = BLOCK_KN_SIZE;
- blockDim.y = 1;
- gridDim.y = DIVIDE(height, 32 / bit);
- gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
- gridDim.z = num_experts;
- auto kernel = reconstruct_gptq_kernel<MatrixView_q4_row, 4>;
- if (bit == 2) {
- kernel = reconstruct_gptq_kernel<MatrixView_q2_row, 2>;
- } else if (bit == 8) {
- kernel = reconstruct_gptq_kernel<MatrixView_q8_row, 8>;
- } else if (bit == 3) {
- kernel = reconstruct_gptq_3bit_kernel;
- gridDim.y = DIVIDE(height, 32);
- }
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales,
- b_gptq_qzeros, b_g_idx, height,
- width, groups, out);
- }
- void dequant_gptq_cuda(const uint32_t* b_q_weight,
- const uint32_t* b_gptq_qzeros, const half* b_gptq_scales,
- const int* b_g_idx, half* temp_dq, int size_k,
- int size_n, int groups, int num_experts, int bits,
- bool use_exllama) {
- if (use_exllama) {
- reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
- temp_dq, size_k, size_n, groups, num_experts, bits);
- } else {
- reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
- size_k, size_n, groups, num_experts, bits);
- }
- }
- void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
- const uint32_t* b_q_weight,
- const uint32_t* b_gptq_qzeros,
- const half* b_gptq_scales, const int* b_g_idx,
- half* c, half* temp_dq, int size_m, int size_n,
- int size_k, int groups, bool use_exllama, int bit) {
- bool use_reconstruct;
- if (use_exllama) {
- use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) ||
- (bit != 8 && size_m > MAX_Q_GEMM_ROWS));
- } else {
- // The 2/3-bit kernels are currently slower than the dequant + GEMM baseline,
- // so they are disabled for now.
- use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS);
- }
- if (use_reconstruct) {
- // Reconstruct FP16 matrix, then cuBLAS
- dequant_gptq_cuda(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
- temp_dq, size_k, size_n, groups, 1, bit, use_exllama);
- const half alpha = __float2half(1.0f);
- const half beta = __float2half(0.0f);
- cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k,
- &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n);
- } else if (use_exllama) {
- // Quantized matmul
- int max_chunks = size_m / BLOCK_M_SIZE_MAX;
- int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
- int last_chunk_size = size_m - last_chunk;
- if (max_chunks) {
- gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
- b_g_idx, c, last_chunk, size_n, size_k,
- BLOCK_M_SIZE_MAX, groups, bit);
- }
- if (last_chunk_size) {
- gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight,
- b_gptq_qzeros, b_gptq_scales, b_g_idx,
- c + last_chunk * size_n, last_chunk_size,
- size_n, size_k, last_chunk_size, groups, bit);
- }
- } else {
- gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
- c, size_m, size_n, size_k, bit);
- }
- }
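- // Dispatch summary: for tall inputs (size_m above MAX_Q_GEMM_ROWS, or
- // MAX_Q_GEMM_ROWS_8BIT for 8-bit) the weight is reconstructed to FP16 in
- // temp_dq and multiplied with cuBLAS; for short inputs the exllama kernels
- // multiply directly from the packed weights in BLOCK_M_SIZE_MAX-row chunks,
- // and the _alt kernels handle the remaining non-exllama 4/8-bit cases.
- // A minimal illustrative call (device pointer names here are hypothetical,
- // and temp_dq must hold size_k * size_n halves):
- //   gemm_half_q_half_cuda(at::cuda::getCurrentCUDABlasHandle(), d_a,
- //                         d_qweight, d_qzeros, d_scales, d_g_idx, d_c,
- //                         d_temp_dq, size_m, size_n, size_k, groups,
- //                         /*use_exllama=*/true, /*bit=*/4);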
- __global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
- const int size_k, const int size_n) {
- int n = blockIdx.x * THREADS_X + threadIdx.x;
- if (n >= size_n) return;
- int k = 0;
- uint32_t* b_ptr = b_q_weight + n;
- while (k < size_k) {
- shuffle_4bit_8(b_ptr, size_n);
- b_ptr += 1 * size_n;
- k += 8;
- }
- }
- __global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight,
- const int size_k, const int size_n) {
- int n = blockIdx.x * THREADS_X + threadIdx.x;
- if (n >= size_n) return;
- int k = 0;
- uint32_t* b_ptr = b_q_weight + n;
- while (k < size_k) {
- shuffle_8bit_4(b_ptr, size_n);
- b_ptr += 1 * size_n;
- k += 4;
- }
- }
- __global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight,
- const int size_k, const int size_n) {
- int n = blockIdx.x * THREADS_X + threadIdx.x;
- if (n >= size_n) return;
- int k = 0;
- uint32_t* b_ptr = b_q_weight + n;
- while (k < size_k) {
- shuffle_2bit_16(b_ptr, size_n);
- b_ptr += 1 * size_n;
- k += 16;
- }
- }
- __global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight,
- const int size_k, const int size_n) {
- int n = blockIdx.x * THREADS_X + threadIdx.x;
- if (n >= size_n) return;
- int k = 0;
- uint32_t* b_ptr = b_q_weight + n;
- while (k < size_k) {
- shuffle_3bit_32(b_ptr, size_n);
- b_ptr += 3 * size_n;
- k += 32;
- }
- }
- __global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w,
- uint32_t* __restrict__ w_new,
- const int* __restrict__ q_perm,
- const int w_height,
- const int w_width) {
- if (blockIdx.z > 0) {
- w = w + blockIdx.z * w_height * w_width;
- w_new = w_new + blockIdx.z * w_height * w_width;
- q_perm = q_perm + blockIdx.z * w_height * 8;
- }
- const uint64_t* w2 = (uint64_t*)w;
- uint64_t* w_new2 = (uint64_t*)w_new;
- int w2_stride = w_width >> 1;
- int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
- if (w2_column >= w2_stride) return;
- int w_new2_row = blockIdx.y;
- int q_perm_idx = w_new2_row << 3;
- uint64_t dst = 0;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- int source_row = q_perm[q_perm_idx++];
- int w2_row = source_row >> 3;
- int w2_subrow = source_row & 0x07;
- int w2_row_shift = w2_subrow << 2;
- int wnew2_row_shift = i << 2;
- uint64_t src = w2[w2_row * w2_stride + w2_column];
- src >>= w2_row_shift;
- src &= 0x0000000f0000000f;
- src <<= wnew2_row_shift;
- dst |= src;
- }
- w_new2[w_new2_row * w2_stride + w2_column] = dst;
- }
- __global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w,
- uint32_t* __restrict__ w_new,
- const int* __restrict__ q_perm,
- const int w_height,
- const int w_width) {
- if (blockIdx.z > 0) {
- w = w + blockIdx.z * w_height * w_width;
- w_new = w_new + blockIdx.z * w_height * w_width;
- q_perm = q_perm + blockIdx.z * w_height * 16;
- }
- const uint64_t* w2 = (uint64_t*)w;
- uint64_t* w_new2 = (uint64_t*)w_new;
- int w2_stride = w_width >> 1;
- int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
- if (w2_column >= w2_stride) return;
- int w_new2_row = blockIdx.y;
- int q_perm_idx = w_new2_row << 4;
- uint64_t dst = 0;
- #pragma unroll
- for (int i = 0; i < 16; i++) {
- int source_row = q_perm[q_perm_idx++];
- int w2_row = source_row >> 4;
- int w2_subrow = source_row & 0x0f;
- int w2_row_shift = w2_subrow << 1;
- int wnew2_row_shift = i << 1;
- uint64_t src = w2[w2_row * w2_stride + w2_column];
- src >>= w2_row_shift;
- src &= 0x0000000300000003;
- src <<= wnew2_row_shift;
- dst |= src;
- }
- w_new2[w_new2_row * w2_stride + w2_column] = dst;
- }
- __global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w,
- uint32_t* __restrict__ w_new,
- const int* __restrict__ q_perm,
- const int w_height,
- const int w_width) {
- if (blockIdx.z > 0) {
- w = w + blockIdx.z * w_height * w_width;
- w_new = w_new + blockIdx.z * w_height * w_width;
- q_perm = q_perm + blockIdx.z * w_height * 32 / 3;
- }
- int w_column = THREADS_X * blockIdx.x + threadIdx.x;
- if (w_column >= w_width) return;
- int w_new_row = blockIdx.y * 3;
- int q_perm_idx = blockIdx.y << 5;
- uint32_t dst[3] = {0, 0, 0};
- #pragma unroll
- for (int i = 0; i < 32; i++) {
- int source_row = q_perm[q_perm_idx++];
- int z_w = (source_row / 32) * 3;
- int z_mod = source_row % 32;
- int z_bit;
- if (z_mod != 10) {
- if (z_mod != 21) {
- z_bit = z_mod;
- if (z_bit > 21) {
- z_bit *= 3;
- z_bit -= 64;
- z_w += 2;
- } else if (z_bit > 10) {
- z_bit *= 3;
- z_bit -= 32;
- z_w += 1;
- } else {
- z_bit *= 3;
- }
- } else {
- z_w += 1;
- }
- }
- uint64_t src;
- if (z_mod == 10) {
- src = (w[z_w * w_width + w_column] >> 30) |
- ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4);
- } else if (z_mod == 21) {
- src = (w[z_w * w_width + w_column] >> 31) |
- ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6);
- } else {
- src = w[z_w * w_width + w_column];
- src >>= z_bit;
- src &= 0x07;
- }
- z_w = 0;
- if (i != 10) {
- if (i != 21) {
- z_bit = i;
- if (z_bit > 21) {
- z_bit *= 3;
- z_bit -= 64;
- z_w += 2;
- } else if (z_bit > 10) {
- z_bit *= 3;
- z_bit -= 32;
- z_w += 1;
- } else {
- z_bit *= 3;
- }
- } else {
- z_w += 1;
- }
- }
- if (i == 10) {
- dst[z_w] |= (src & 0x03) << 30;
- dst[z_w + 1] |= ((src & 0x4) >> 2);
- } else if (i == 21) {
- dst[z_w] |= (src & 0x01) << 31;
- dst[z_w + 1] |= ((src & 0x6) >> 1);
- } else {
- dst[z_w] |= (src << z_bit);
- }
- }
- w_new[w_new_row * w_width + w_column] = dst[0];
- w_new[(w_new_row + 1) * w_width + w_column] = dst[1];
- w_new[(w_new_row + 2) * w_width + w_column] = dst[2];
- }
- __global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w,
- uint32_t* __restrict__ w_new,
- const int* __restrict__ q_perm,
- const int w_height,
- const int w_width) {
- if (blockIdx.z > 0) {
- w = w + blockIdx.z * w_height * w_width;
- w_new = w_new + blockIdx.z * w_height * w_width;
- q_perm = q_perm + blockIdx.z * w_height * 4;
- }
- const uint64_t* w2 = (uint64_t*)w;
- uint64_t* w_new2 = (uint64_t*)w_new;
- int w2_stride = w_width >> 1;
- int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
- if (w2_column >= w2_stride) return;
- int w_new2_row = blockIdx.y;
- int q_perm_idx = w_new2_row << 2;
- uint64_t dst = 0;
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- int source_row = q_perm[q_perm_idx++];
- int w2_row = source_row >> 2;
- int w2_subrow = source_row & 0x03;
- int w2_row_shift = w2_subrow << 3;
- int wnew2_row_shift = i << 3;
- uint64_t src = w2[w2_row * w2_stride + w2_column];
- src >>= w2_row_shift;
- src &= 0x000000ff000000ff;
- src <<= wnew2_row_shift;
- dst |= src;
- }
- w_new2[w_new2_row * w2_stride + w2_column] = dst;
- }
- void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
- int width, int num_experts, int bit) {
- if (q_perm) {
- uint32_t* new_qweight = NULL;
- cudaMalloc(&new_qweight,
- num_experts * height / 32 * bit * width * sizeof(uint32_t));
- dim3 blockDim, gridDim;
- blockDim.x = THREADS_X;
- blockDim.y = 1;
- gridDim.x = DIVIDE(width, THREADS_X);
- gridDim.y = height / 32 * bit;
- gridDim.z = num_experts;
- auto kernel = make_sequential_4bit_kernel;
- if (bit == 2) {
- kernel = make_sequential_2bit_kernel;
- } else if (bit == 3) {
- kernel = make_sequential_3bit_kernel;
- gridDim.y = height / 32;
- } else if (bit == 8) {
- kernel = make_sequential_8bit_kernel;
- }
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, new_qweight, q_perm,
- height / 32 * bit, width);
- // Replace qweights
- cudaMemcpyAsync(q_weight, new_qweight,
- num_experts * height / 32 * bit * width * sizeof(uint32_t),
- cudaMemcpyDeviceToDevice);
- // Cleanup
- cudaDeviceSynchronize();
- cudaFree(new_qweight);
- }
- dim3 blockDim, gridDim;
- blockDim.x = THREADS_X;
- blockDim.y = 1;
- gridDim.x = DIVIDE(width, THREADS_X);
- gridDim.y = 1;
- auto shuffle_kernel = shuffle_4bit_kernel;
- if (bit == 2) {
- shuffle_kernel = shuffle_2bit_kernel;
- } else if (bit == 3) {
- shuffle_kernel = shuffle_3bit_kernel;
- } else if (bit == 8) {
- shuffle_kernel = shuffle_8bit_kernel;
- }
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight,
- height * num_experts, width);
- }
- template <int m_count>
- __global__ void group_gemm_half_q_half_gptq_kernel(
- const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight,
- const uint32_t* __restrict__ b_gptq_qzeros,
- const half* __restrict__ b_gptq_scales, half* __restrict__ c,
- const int size_m, const int size_n, const int size_k, const int groups,
- const int* __restrict__ b_q_perm, const float* __restrict__ topk_weights,
- const int* __restrict__ sorted_token_ids_ptr,
- const int* __restrict__ expert_ids_ptr,
- const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens,
- const int top_k) {
- int num_tokens = *num_tokens_post_padded;
- int offset_m = blockIdx.y * m_count;
- if (offset_m >= num_tokens) return;
- int expert_id = expert_ids_ptr[blockIdx.y];
- b_q_weight = b_q_weight + size_k * size_n / 8 * expert_id;
- b_gptq_qzeros = b_gptq_qzeros + groups * size_n / 8 * expert_id;
- b_gptq_scales = b_gptq_scales + groups * size_n * expert_id;
- if (b_q_perm) b_q_perm = b_q_perm + size_k * expert_id;
- MatrixView_half a_(a, size_m, size_k);
- MatrixView_half_rw c_(c, size_m, size_n);
- MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
- MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
- int t = threadIdx.x;
- // Block
- int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
- int offset_k = blockIdx.z * BLOCK_KN_SIZE;
- int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
- int end_m = min(offset_m + m_count, size_m);
- int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
- int n = offset_n + t * 4;
- // Preload block_a
- __shared__ half block_a[m_count][BLOCK_KN_SIZE];
- int token_a[m_count];
- int valid_count = m_count;
- for (int m = 0; m < m_count; ++m) {
- int token_id = sorted_token_ids_ptr[offset_m + m];
- if (token_id >= num_valid_tokens) {
- valid_count = m;
- break;
- }
- token_a[m] = token_id;
- }
- if (offset_k + t < end_k) {
- for (int m = 0; m < valid_count; ++m) {
- const half* a_ptr = a_.item_ptr(token_a[m] / top_k, 0);
- half* block_a_ptr = block_a[m];
- half a0;
- if (b_q_perm)
- a0 = a_ptr[b_q_perm[offset_k + t]];
- else
- a0 = a_ptr[offset_k + t];
- block_a_ptr[t] = a0;
- }
- }
- // Nothing to write for out-of-range output columns
- if (n >= size_n) return;
- __syncthreads();
- // Find initial group
- int groupsize = size_k / groups;
- int group = offset_k / groupsize;
- int nextgroup = offset_k + groupsize;
- // a, b offset
- int qk = offset_k / (32 / 4);
- const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
- const half* a_ptr = &block_a[0][0];
- int a_stride = BLOCK_KN_SIZE;
- // Initial group
- int zeros[4];
- float scales[4];
- half2 z1z16[4][2];
- half2 y1y16[4][2];
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_f(scales, group, n);
- dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
- dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
- dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
- dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
- // Column result
- float block_c[m_count][4] = {};
- // Dequantize and multiply
- int k = offset_k;
- while (k < end_k) {
- if (k == nextgroup) {
- group++;
- nextgroup += groupsize;
- b_gptq_qzeros_.item4(zeros, group, n);
- b_gptq_scales_.item4_f(scales, group, n);
- dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
- dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
- dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
- dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
- }
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- const int4* b_ptr4 = (int4*)b_ptr;
- int4 load_int4 = *b_ptr4;
- half2 dq[4][4];
- dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n,
- false);
- dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n,
- false);
- dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n,
- false);
- dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n,
- false);
- for (int m = 0; m < valid_count; m++) {
- block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0],
- block_c[m][0]);
- block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1],
- block_c[m][1]);
- block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2],
- block_c[m][2]);
- block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3],
- block_c[m][3]);
- }
- b_ptr += size_n;
- a_ptr += 8;
- }
- k += 32;
- }
- for (int m = 0; m < valid_count; m++) {
- if (topk_weights) {
- #pragma unroll
- for (int j = 0; j < 4; ++j) {
- block_c[m][j] = block_c[m][j] * topk_weights[token_a[m]];
- }
- }
- half2* out = (half2*)c_.item_ptr(token_a[m], n);
- half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]),
- __float2half_rn(block_c[m][1]));
- half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]),
- __float2half_rn(block_c[m][3]));
- atomicAdd(out, result01);
- atomicAdd(out + 1, result23);
- }
- }
- void group_gemm_half_q_half(const half* a, const uint32_t* b_q_weight,
- const uint32_t* b_gptq_qzeros,
- const half* b_gptq_scales, const int* b_q_perm,
- half* c, const float* __restrict__ topk_weights,
- const int* __restrict__ sorted_token_ids_ptr,
- const int* __restrict__ expert_ids_ptr,
- const int* __restrict__ num_tokens_post_padded,
- const int num_valid_tokens, const int top_k,
- int size_m, int size_n, int size_k, int pad_size_m,
- int groups) {
- dim3 blockDim, gridDim;
- blockDim.x = BLOCK_KN_SIZE;
- blockDim.y = 1;
- blockDim.z = 1;
- gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
- gridDim.y = DIVIDE(pad_size_m, BLOCK_M_SIZE_MAX);
- gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- group_gemm_half_q_half_gptq_kernel<BLOCK_M_SIZE_MAX>
- <<<gridDim, blockDim, 0, stream>>>(
- a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n,
- size_k, groups, b_q_perm, topk_weights, sorted_token_ids_ptr,
- expert_ids_ptr, num_tokens_post_padded, num_valid_tokens, top_k);
- }
- __global__ void group_gemm_half_q_half_alt_kernel(
- const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
- half* __restrict__ mul, const half* __restrict__ scales,
- const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
- int batch, int height, int width, int groups,
- const float* __restrict__ topk_weights,
- const int* __restrict__ sorted_token_ids_ptr,
- const int* __restrict__ expert_ids_ptr,
- const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens,
- const int top_k) {
- int num_tokens = *num_tokens_post_padded;
- int b = blockIdx.y * BLOCK_M_SIZE_MAX;
- if (b >= num_tokens) return;
- int expert_id = expert_ids_ptr[blockIdx.y];
- mat = mat + height * width * expert_id;
- scales = scales + groups * width * expert_id;
- zeros = zeros + groups * width / 8 * expert_id;
- g_idx = g_idx + height * 8 * expert_id;
- int zero_width = width / 8;
- int vec_height = height * 4;
- const int blockwidth2 = BLOCK_KN_SIZE / 2;
- int b_end = BLOCK_M_SIZE_MAX;
- int h = BLOCK_KN_SIZE * blockIdx.z / 8;
- int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
- int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
- int token_a[BLOCK_M_SIZE_MAX];
- for (int m = 0; m < b_end; ++m) {
- int token_id = sorted_token_ids_ptr[b + m];
- if (token_id >= num_valid_tokens) {
- b_end = m;
- break;
- }
- token_a[m] = token_id;
- }
- __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
- if (threadIdx.x < h_end) {
- for (int m = 0; m < b_end; ++m) {
- blockvec[m][threadIdx.x] =
- vec[token_a[m] / top_k * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
- threadIdx.x];
- }
- }
- __shared__ half2 deq2[256][8];
- int val = threadIdx.x / 8;
- int off = threadIdx.x % 8;
- for (; val < 256; val += BLOCK_KN_SIZE / 8) {
- deq2[val][off] =
- __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4));
- }
- __syncthreads();
- int i = width * h + w;
- int g_h = h * 8;
- int k = 0;
- int z_w = w / 8;
- int z_mod = (w % 8) * 4;
- half2 res2;
- half res[BLOCK_M_SIZE_MAX] = {};
- unsigned int tmp;
- while (k < h_end) {
- tmp = mat[i];
- half2 scales_tmp[4];
- half2 zeros_tmp[4];
- for (int tmp_k = 0; tmp_k < 4; tmp_k++) {
- int g = g_idx[g_h + (k + tmp_k) * 2];
- int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
- half scale_f = scales[g * width + w];
- half scale_f2 = scales[g2 * width + w];
- half2 scale = __halves2half2(scale_f, scale_f2);
- half2 zero = __halves2half2(
- __hmul(scale_f,
- __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) -
- 1)),
- __hmul(scale_f2,
- __int2half_rn(
- -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)));
- scales_tmp[tmp_k] = scale;
- zeros_tmp[tmp_k] = zero;
- }
- for (int m = 0; m < b_end; m++) {
- #ifndef USE_ROCM
- res2 = {};
- #else
- res2.x = __half_as_ushort(__float2half(0));
- res2.y = __half_as_ushort(__float2half(0));
- #endif
- res2 = __hfma2(
- __hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]),
- blockvec[m][k + 0], res2);
- res2 = __hfma2(
- __hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]),
- blockvec[m][k + 1], res2);
- res2 = __hfma2(
- __hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]),
- blockvec[m][k + 2], res2);
- res2 = __hfma2(
- __hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]),
- blockvec[m][k + 3], res2);
- #ifndef USE_ROCM
- res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
- #else
- res[m] = __hadd(
- res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
- #endif
- }
- i += width;
- k += 4;
- }
- for (int m = 0; m < b_end; m++) {
- if (topk_weights) {
- res[m] = __float2half(__half2float(res[m]) * topk_weights[token_a[m]]);
- }
- atomicAdd(&mul[token_a[m] * width + w], res[m]);
- }
- }
- void group_gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
- const uint32_t* b_gptq_qzeros,
- const half* b_gptq_scales, const int* b_g_idx,
- half* c, const float* __restrict__ topk_weights,
- const int* __restrict__ sorted_token_ids_ptr,
- const int* __restrict__ expert_ids_ptr,
- const int* __restrict__ num_tokens_post_padded,
- const int num_valid_tokens, const int top_k,
- int size_m, int size_n, int size_k,
- int pad_size_m, int groups) {
- dim3 blockDim, gridDim;
- blockDim.x = BLOCK_KN_SIZE;
- blockDim.y = 1;
- blockDim.z = 1;
- gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
- gridDim.y = DIVIDE(pad_size_m, BLOCK_M_SIZE_MAX);
- gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- group_gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>(
- (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx,
- size_m, size_k / 8, size_n, groups, topk_weights, sorted_token_ids_ptr,
- expert_ids_ptr, num_tokens_post_padded, num_valid_tokens, top_k);
- }
- // Only 4-bit is supported so far
- void group_gemm_half_q_half_cuda(const half* a, const uint32_t* b_q_weight,
- const uint32_t* b_gptq_qzeros,
- const half* b_gptq_scales, const int* b_g_idx,
- half* c,
- const float* __restrict__ topk_weights,
- const int* __restrict__ sorted_token_ids_ptr,
- const int* __restrict__ expert_ids_ptr,
- const int* __restrict__ num_tokens_post_padded,
- const int num_valid_tokens, const int top_k,
- int size_m, int size_n, int size_k,
- int pad_size_m, int groups, bool use_exllama) {
- if (use_exllama) {
- group_gemm_half_q_half(
- a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, topk_weights,
- sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded,
- num_valid_tokens, top_k, size_m, size_n, size_k, pad_size_m, groups);
- } else {
- group_gemm_half_q_half_alt(
- a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, topk_weights,
- sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded,
- num_valid_tokens, top_k, size_m, size_n, size_k, pad_size_m, groups);
- }
- }
- } // namespace gptq
- } // namespace aphrodite
- torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
- torch::Tensor b_gptq_qzeros,
- torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
- bool use_exllama, int64_t bit) {
- const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
- auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
- at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
- at::Tensor temp_dq = torch::empty(
- {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
- aphrodite::gptq::gemm_half_q_half_cuda(
- at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(),
- (const uint32_t*)b_q_weight.data_ptr(),
- (const uint32_t*)b_gptq_qzeros.data_ptr(),
- (const half*)b_gptq_scales.data_ptr(),
- b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(),
- (half*)c.data_ptr(), (half*)temp_dq.data_ptr(),
- c.size(0), // m
- c.size(1), // n
- a.size(1), // k
- b_gptq_qzeros.size(0), // number of groups
- use_exllama, bit);
- return c;
- }
- void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
- const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
- int num_experts = q_weight.dim() == 3 ? q_weight.size(0) : 1;
- int size_k = q_weight.dim() == 3 ? q_weight.size(1) * 32 / bit
- : q_weight.size(0) * 32 / bit;
- int size_n = q_weight.dim() == 3 ? q_weight.size(2) : q_weight.size(1);
- aphrodite::gptq::shuffle_exllama_weight(
- (uint32_t*)q_weight.data_ptr(),
- q_perm.device().is_meta() || q_perm.numel() == 0
- ? NULL
- : (int*)q_perm.data_ptr(),
- size_k, size_n, num_experts, bit);
- }
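- // ---------------------------------------------------------------------------
- // Editor's note (not part of the original source): a minimal host-side sketch
- // of how the two bindings above are typically driven for a 4-bit weight,
- // relying only on headers this file already includes. Shapes follow the GPTQ
- // packing used here: 32 / bit values per int32 along the k dimension of
- // qweight and along the n dimension of qzeros. The group size of 128 and all
- // tensor names are illustrative assumptions, not part of this API. With
- // use_exllama, gptq_shuffle must be run once on the weights before the first
- // gptq_gemm call.
- static torch::Tensor example_gptq_gemm_4bit(int64_t m, int64_t k, int64_t n) {
-   const int64_t bit = 4;
-   const int64_t pack = 32 / bit;                // 8 nibbles per int32
-   const int64_t groups = k / 128;               // assumed group_size = 128
-   auto cuda = torch::TensorOptions().device(torch::kCUDA);
-   auto a = torch::randn({m, k}, cuda.dtype(torch::kHalf));
-   auto qweight = torch::zeros({k / pack, n}, cuda.dtype(torch::kInt));
-   auto qzeros = torch::zeros({groups, n / pack}, cuda.dtype(torch::kInt));
-   auto scales = torch::ones({groups, n}, cuda.dtype(torch::kHalf));
-   // A meta tensor makes both bindings pass NULL for the permutation / g_idx
-   // (i.e. no act-order reordering).
-   auto g_idx = torch::empty({0}, torch::TensorOptions()
-                                      .dtype(torch::kInt)
-                                      .device(torch::Device("meta")));
-   gptq_shuffle(qweight, g_idx, bit);            // repack for the exllama kernels
-   return gptq_gemm(a, qweight, qzeros, scales, g_idx,
-                    /*use_exllama=*/true, bit);  // -> [m, n] fp16
- }
- // ---------------------------------------------------------------------------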
- // Only 4-bit is supported so far
- // TODO: extend support to other bit widths
- torch::Tensor group_gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
- torch::Tensor b_gptq_qzeros,
- torch::Tensor b_gptq_scales,
- torch::Tensor b_g_idx, torch::Tensor topk_weights,
- torch::Tensor sorted_token_ids_ptr,
- torch::Tensor expert_ids_ptr,
- torch::Tensor num_tokens_post_padded,
- bool mul_weights, bool use_exllama) {
- const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
- auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
- at::Tensor c = torch::zeros(
- {a.size(0), topk_weights.size(1), b_q_weight.size(2)}, options);
- aphrodite::gptq::group_gemm_half_q_half_cuda(
- (const half*)a.data_ptr(), (const uint32_t*)b_q_weight.data_ptr(),
- (const uint32_t*)b_gptq_qzeros.data_ptr(),
- (const half*)b_gptq_scales.data_ptr(),
- b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(),
- (half*)c.data_ptr(),
- mul_weights ? (const float*)topk_weights.data_ptr() : NULL,
- (const int*)sorted_token_ids_ptr.data_ptr(),
- (const int*)expert_ids_ptr.data_ptr(),
- (const int*)num_tokens_post_padded.data_ptr(),
- topk_weights.numel(), // num_valid_tokens
- topk_weights.size(1) / a.size(1), // top_k
- a.size(0) * a.size(1), // m
- c.size(2), // n
- a.size(2), // k
- sorted_token_ids_ptr.size(0), // pad_size_m (padded token count)
- b_gptq_qzeros.size(1), // number of groups
- use_exllama);
- return c;
- }
- torch::Tensor dequant_gptq(torch::Tensor b_q_weight,
- torch::Tensor b_gptq_qzeros,
- torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
- int bits, bool use_exllama) {
- const at::cuda::OptionalCUDAGuard device_guard(device_of(b_gptq_scales));
- auto options = torch::TensorOptions()
- .dtype(b_gptq_scales.dtype())
- .device(b_gptq_scales.device());
- at::Tensor temp_dq;
- int num_experts;
- int size_k;
- int size_n;
- int groups;
- // MoE case: a leading expert dimension on the 3-D weight tensor
- if (b_q_weight.dim() == 3) {
- temp_dq = torch::empty({b_q_weight.size(0), b_q_weight.size(1) * 32 / bits,
- b_q_weight.size(2)},
- options);
- num_experts = b_q_weight.size(0);
- size_k = b_q_weight.size(1) * 32 / bits;
- size_n = b_q_weight.size(2);
- groups = b_gptq_scales.size(1);
- } else {
- temp_dq = torch::empty({b_q_weight.size(0) * 32 / bits, b_q_weight.size(1)},
- options);
- num_experts = 1;
- size_k = b_q_weight.size(0) * 32 / bits;
- size_n = b_q_weight.size(1);
- groups = b_gptq_scales.size(0);
- }
- aphrodite::gptq::dequant_gptq_cuda(
- (const uint32_t*)b_q_weight.data_ptr(),
- (const uint32_t*)b_gptq_qzeros.data_ptr(),
- (const half*)b_gptq_scales.data_ptr(),
- b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(),
- (half*)temp_dq.data_ptr(), size_k, size_n, groups, num_experts, bits,
- use_exllama);
- return temp_dq;
- }
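- // ---------------------------------------------------------------------------
- // Editor's note (not part of the original source): shape summary for the 2-D
- // (non-MoE) branch of dequant_gptq, using the same packing as above. For a
- // 4-bit checkpoint with qweight [k / 8, n] int32, qzeros [groups, n / 8]
- // int32, scales [groups, n] fp16 and g_idx [k] int32, a call of the form
- //
- //   torch::Tensor w = dequant_gptq(qweight, qzeros, scales, g_idx,
- //                                  /*bits=*/4, /*use_exllama=*/false);
- //
- // returns the reconstructed fp16 weight of shape [k, n]. The 3-D MoE branch
- // adds a leading num_experts dimension to every tensor and to the result.
- // Tensor names here are illustrative, not part of this API.
- // ---------------------------------------------------------------------------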