1
0

gemm_kernels.cu 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771
  1. /*
  2. Adapted from https://github.com/mit-han-lab/llm-awq
  3. @article{lin2023awq,
  4. title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  5. author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  6. journal={arXiv},
  7. year={2023}
  8. }
  9. */
  10. #include <torch/extension.h>
  11. #include <c10/cuda/CUDAGuard.h>
  12. #include "dequantize.cuh"
  13. #include <cuda_fp16.h>
  14. namespace aphrodite {
  15. namespace awq {
  16. // Pack two half values.
  17. static inline __device__ __host__ unsigned
  18. __pack_half2(const half x, const half y) {
  19. unsigned v0 = *((unsigned short *)&x);
  20. unsigned v1 = *((unsigned short *)&y);
  21. return (v1 << 16) | v0;
  22. }
  23. template<int N>
  24. __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
  25. int G,
  26. int split_k_iters,
  27. half* __restrict__ A,
  28. int* __restrict__ B,
  29. half* __restrict__ scaling_factors,
  30. int* __restrict__ zeros,
  31. int M,
  32. int IC,
  33. int OC,
  34. half* __restrict__ C)
  35. {
  36. // Only support matrix n = 64 or 128
  37. assert(N == 64 || N == 128);
  38. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  39. assert(false);
  40. #else
  41. static constexpr uint32_t ZERO = 0x0;
  42. float C_warp[32];
  43. __shared__ half A_shared[16 * (32 + 8)];
  44. __shared__ half B_shared[32 * (N + 8)];
  45. __shared__ half scaling_factors_shared[N];
  46. __shared__ half zeros_shared[N];
  47. int j_factors1 = ((OC + N - 1) / N);
  48. int blockIdx_x = 0;
  49. int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
  50. int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
  51. half A_shared_warp[8];
  52. half B_shared_warp[N / 4];
  53. for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
  54. for (int i = 0; i < 8; ++i) {
  55. C_warp[(j_0_4_init * 8) + i] = 0.0;
  56. }
  57. }
  58. static constexpr int row_stride_warp = 32 * 8 / 32;
  59. static constexpr int row_stride = 2 * 32 * 8 / N;
  60. bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
  61. // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  62. bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
  63. // bool wb_C_flag = (threadIdx.x / 4) < M;
  64. half* A_ptr = A
  65. + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
  66. + (((int)threadIdx.x) % (32 / 8)) * 8;
  67. int* B_ptr = B
  68. + ((int)threadIdx.y) * (OC / 8) * (256 / N)
  69. + (((int)threadIdx.x) / (N / 8)) * (OC / 8)
  70. + (((int)blockIdx_y) % j_factors1) * (N / 8)
  71. + (((int)threadIdx.x) % (N / 8)) * 1;
  72. // Why * 1 in the above line?
  73. half* A_shared_ptr = A_shared
  74. + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
  75. + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
  76. + (((int)threadIdx.x) % (32 / 8) ) * 8;
  77. half* B_shared_ptr = B_shared
  78. + ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
  79. + (((int)threadIdx.x) / (N / 8)) * (N + 8)
  80. + (((int)threadIdx.x) % (N / 8)) * 8;
  81. int* zeros_ptr = zeros
  82. + (((int)blockIdx_y) % j_factors1) * (N / 8)
  83. + ((int)threadIdx.x) % (N / 8);
  84. half* scaling_factors_ptr = scaling_factors
  85. + (((int)blockIdx_y) % j_factors1) * N
  86. + (((int)threadIdx.x) % (N / 8)) * 8;
  87. half* C_ptr = C
  88. + static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
  89. + (((int)blockIdx_y) % j_factors1) * N
  90. + ((int)threadIdx.y) * (N / 2)
  91. + (((int)threadIdx.x) % 4) * 2;
  92. // preload s.f. and zeros
  93. int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
  94. if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
  95. for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
  96. int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
  97. __syncthreads();
  98. // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  99. if (ld_A_flag)
  100. {
  101. *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
  102. }
  103. else
  104. {
  105. *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
  106. }
  107. // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
  108. uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
  109. uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
  110. uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
  111. /*
  112. if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
  113. printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
  114. }
  115. */
  116. // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
  117. int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
  118. for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
  119. // B: 32 x 136 (128+8) float16
  120. // each warp: 32 x 4
  121. // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
  122. // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
  123. // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
  124. uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
  125. uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  126. //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
  127. // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
  128. // - zero and * scale
  129. // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
  130. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  131. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  132. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  133. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  134. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  135. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  136. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  137. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
  138. /*
  139. if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
  140. printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
  141. }
  142. */
  143. // write back
  144. *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
  145. }
  146. __syncthreads();
  147. for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
  148. {
  149. unsigned int addr;
  150. __asm__ __volatile__(
  151. "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
  152. : "=r"(addr)
  153. : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
  154. );
  155. __asm__ __volatile__(
  156. "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
  157. "{%0, %1, %2, %3}, [%4];\n"
  158. : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
  159. : "r"(addr)
  160. );
  161. }
  162. for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
  163. {
  164. unsigned int addr;
  165. __asm__ __volatile__(
  166. "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
  167. : "=r"(addr)
  168. : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
  169. );
  170. __asm__ __volatile__(
  171. "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
  172. "{%0, %1, %2, %3}, [%4];\n"
  173. : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
  174. : "r"(addr)
  175. );
  176. }
  177. }
  178. for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
  179. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
  180. {
  181. __asm__ __volatile__(
  182. "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  183. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  184. : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
  185. : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
  186. }
  187. {
  188. __asm__ __volatile__(
  189. "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  190. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  191. : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
  192. : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  193. }
  194. {
  195. __asm__ __volatile__(
  196. "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  197. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  198. : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
  199. : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
  200. }
  201. {
  202. __asm__ __volatile__(
  203. "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  204. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  205. : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
  206. : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  207. }
  208. #else
  209. {
  210. __asm__ __volatile__(
  211. "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
  212. "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
  213. : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
  214. : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
  215. }
  216. {
  217. __asm__ __volatile__(
  218. "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
  219. "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
  220. : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
  221. : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  222. }
  223. #endif
  224. }
  225. }
  226. }
  227. // TODO: Shang: Hoist loop invariance.
  228. for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
  229. for (int local_id = 0; local_id < 8; ++local_id) {
  230. int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
  231. if (row_offset < M)
  232. {
  233. *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
  234. }
  235. }
  236. }
  237. #endif
  238. }
  239. __global__ void __launch_bounds__(64) dequantize_weights(
  240. int* __restrict__ B,
  241. half* __restrict__ scaling_factors,
  242. int* __restrict__ zeros,
  243. half* __restrict__ C,
  244. int G,
  245. int in_c,
  246. int out_c
  247. )
  248. {
  249. if (blockIdx.z > 0) {
  250. B = B + blockIdx.z * in_c * out_c / 8;
  251. scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G;
  252. zeros = zeros + blockIdx.z * in_c * out_c / G / 8;
  253. C = C + blockIdx.z * in_c * out_c;
  254. }
  255. int j_factors1 = 4;
  256. int row_stride2 = 4;
  257. int split_k_iters = 1;
  258. static constexpr uint32_t ZERO = 0x0;
  259. half B_shared[32 * (128 + 8)];
  260. half* B_shared_ptr2 = B_shared;
  261. half B_shared_warp[32];
  262. int OC = 512;
  263. int N = blockDim.x * gridDim.x; // 2
  264. int col = (blockIdx.x * blockDim.x + threadIdx.x);
  265. int row = blockIdx.y * blockDim.y + threadIdx.y;
  266. int index1 = 8 * col + 8 * row * N;
  267. half* C_ptr2 = C + index1;
  268. int index2 = col + row * N;
  269. int* B_ptr2 = B + index2;
  270. int index3 = col + (int)(row / G) * N;
  271. int* zeros_ptr2 = zeros + index3;
  272. int index4 = 8 * col + (int)(row / G) * N * 8;
  273. half* scaling_factors_ptr2 = scaling_factors + index4;
  274. uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
  275. uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
  276. uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
  277. uint32_t B_loaded = *(uint32_t*)B_ptr2;
  278. uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  279. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  280. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  281. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  282. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  283. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  284. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  285. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  286. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
  287. *(uint4*)B_shared_ptr2 = B_loaded_fp16;
  288. for (int i = 0; i < 8; ++i) {
  289. *(C_ptr2 + i) = B_shared[i];
  290. }
  291. }
  292. template<int N>
  293. __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
  294. int G,
  295. int split_k_iters,
  296. half* __restrict__ A,
  297. int* __restrict__ B,
  298. half* __restrict__ scaling_factors,
  299. int* __restrict__ zeros,
  300. const float* __restrict__ topk_weights,
  301. const int* __restrict__ sorted_token_ids_ptr,
  302. const int* __restrict__ expert_ids_ptr,
  303. const int* __restrict__ num_tokens_post_padded,
  304. const int num_valid_tokens,
  305. const int top_k,
  306. const int expert_num,
  307. int pad_M,
  308. int M,
  309. int IC,
  310. int OC,
  311. half* __restrict__ C)
  312. {
  313. // Only support matrix n = 64 or 128
  314. assert(N == 64 || N == 128);
  315. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  316. assert(false);
  317. #else
  318. int num_tokens = *num_tokens_post_padded;
  319. int j_factors1 = ((OC + N - 1) / N);
  320. int blockIdx_x = 0;
  321. int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1);
  322. int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1);
  323. int block = blockIdx_y / j_factors1;
  324. if (block * 16 >= num_tokens) return;
  325. static constexpr uint32_t ZERO = 0x0;
  326. float C_warp[32];
  327. __shared__ half A_shared[16 * (32 + 8)];
  328. __shared__ half B_shared[32 * (N + 8)];
  329. __shared__ half scaling_factors_shared[N];
  330. __shared__ half zeros_shared[N];
  331. half A_shared_warp[8];
  332. half B_shared_warp[N / 4];
  333. for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
  334. for (int i = 0; i < 8; ++i) {
  335. C_warp[(j_0_4_init * 8) + i] = 0.0;
  336. }
  337. }
  338. static constexpr int row_stride_warp = 32 * 8 / 32;
  339. static constexpr int row_stride = 2 * 32 * 8 / N;
  340. bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
  341. // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  342. int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32);
  343. int token_id = sorted_token_ids_ptr[row];
  344. bool ld_A_flag = (token_id < num_valid_tokens);
  345. half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8;
  346. int expert_id = expert_ids_ptr[block];
  347. B = B + OC * IC / 8 * expert_id;
  348. scaling_factors = scaling_factors + OC * IC / G * expert_id;
  349. zeros = zeros + OC * IC / G / 8 * expert_id;
  350. int* B_ptr = B
  351. + ((int)threadIdx.y) * (OC / 8) * (256 / N)
  352. + (((int)threadIdx.x) / (N / 8)) * (OC / 8)
  353. + (((int)blockIdx_y) % j_factors1) * (N / 8)
  354. + (((int)threadIdx.x) % (N / 8)) * 1;
  355. // Why * 1 in the above line?
  356. half* A_shared_ptr = A_shared
  357. + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
  358. + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
  359. + (((int)threadIdx.x) % (32 / 8) ) * 8;
  360. half* B_shared_ptr = B_shared
  361. + ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
  362. + (((int)threadIdx.x) / (N / 8)) * (N + 8)
  363. + (((int)threadIdx.x) % (N / 8)) * 8;
  364. int* zeros_ptr = zeros
  365. + (((int)blockIdx_y) % j_factors1) * (N / 8)
  366. + ((int)threadIdx.x) % (N / 8);
  367. half* scaling_factors_ptr = scaling_factors
  368. + (((int)blockIdx_y) % j_factors1) * N
  369. + (((int)threadIdx.x) % (N / 8)) * 8;
  370. half* C_ptr = C
  371. + static_cast<long long>(blockIdx_z) * M * OC * expert_num // blockIdz.x -> split_k dim
  372. + (((int)blockIdx_y) % j_factors1) * N
  373. + ((int)threadIdx.y) * (N / 2)
  374. + (((int)threadIdx.x) % 4) * 2;
  375. // preload s.f. and zeros
  376. int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
  377. if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
  378. for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
  379. int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
  380. __syncthreads();
  381. // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  382. if (ld_A_flag)
  383. {
  384. *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
  385. }
  386. else
  387. {
  388. *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
  389. }
  390. uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
  391. uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
  392. uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
  393. int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
  394. for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
  395. uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
  396. uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  397. // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
  398. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  399. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  400. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  401. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  402. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  403. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  404. asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  405. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
  406. // write back
  407. *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
  408. }
  409. __syncthreads();
  410. for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
  411. {
  412. unsigned int addr;
  413. __asm__ __volatile__(
  414. "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
  415. : "=r"(addr)
  416. : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
  417. );
  418. __asm__ __volatile__(
  419. "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
  420. "{%0, %1, %2, %3}, [%4];\n"
  421. : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
  422. : "r"(addr)
  423. );
  424. }
  425. for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
  426. {
  427. unsigned int addr;
  428. __asm__ __volatile__(
  429. "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
  430. : "=r"(addr)
  431. : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
  432. );
  433. __asm__ __volatile__(
  434. "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
  435. "{%0, %1, %2, %3}, [%4];\n"
  436. : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
  437. : "r"(addr)
  438. );
  439. }
  440. }
  441. for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
  442. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
  443. {
  444. __asm__ __volatile__(
  445. "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  446. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  447. : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
  448. : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
  449. }
  450. {
  451. __asm__ __volatile__(
  452. "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  453. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  454. : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
  455. : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  456. }
  457. {
  458. __asm__ __volatile__(
  459. "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  460. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  461. : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
  462. : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
  463. }
  464. {
  465. __asm__ __volatile__(
  466. "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  467. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  468. : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
  469. : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  470. }
  471. #else
  472. {
  473. __asm__ __volatile__(
  474. "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
  475. "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
  476. : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
  477. : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
  478. }
  479. {
  480. __asm__ __volatile__(
  481. "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
  482. "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
  483. : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
  484. : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  485. }
  486. #endif
  487. }
  488. }
  489. }
  490. // TODO: Shang: Hoist loop invariance.
  491. for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) {
  492. for (int local_id = 0; local_id < 8; ++local_id) {
  493. int row_offset = block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
  494. int token_id = sorted_token_ids_ptr[row_offset];
  495. if (token_id < num_valid_tokens)
  496. {
  497. float value = C_warp[(ax1_0_1 * 8) + local_id];
  498. if (topk_weights) {
  499. value = value * topk_weights[token_id];
  500. }
  501. *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(value);
  502. }
  503. }
  504. }
  505. #endif
  506. }
  507. } // namespace awq
  508. } // namespace aphrodite
  // Expands AWQ-packed 4-bit weights into a dense half-precision tensor on the GPU.
  //
  //   _kernel          int32, either 2-D [IC, OC / 8] or 3-D [E, IC, OC / 8]; each
  //                    int32 packs eight 4-bit values, so the output width is size * 8.
  //   _scaling_factors half; its group dimension (dim 0 for a 2-D kernel, dim 1 for a
  //                    3-D kernel) determines the group size G = IC / groups.
  //   _zeros           int32 packed zero points, forwarded unchanged to the kernel.
  //   split_k_iters    not used by this wrapper.
  //   thx, thy         block dimensions; 0 selects a default (see below).
  //
  // Returns [IC, OC] when there is a single expert, otherwise [E, IC, OC], using the
  // dtype and device of _scaling_factors. The kernel is enqueued on the current CUDA
  // stream; no launch-error check is performed here.
  torch::Tensor awq_dequantize(
      torch::Tensor _kernel,
      torch::Tensor _scaling_factors,
      torch::Tensor _zeros,
      int split_k_iters,
      int thx,
      int thy)
  {
    const bool is_2d = _kernel.dim() == 2;
    const int in_c = is_2d ? _kernel.size(0) : _kernel.size(1);
    const int qout_c = is_2d ? _kernel.size(1) : _kernel.size(2);
    const int num_experts = is_2d ? 1 : _kernel.size(0);
    const int out_c = qout_c * 8;
    const auto num_groups = is_2d ? _scaling_factors.size(0) : _scaling_factors.size(1);
    const int G = in_c / num_groups;

    // Launch shape: a non-zero thx/thy is used as-is, 0 on one axis spans that whole
    // dimension in a single block, and 0 on both axes tiles the matrix with 8x8 blocks.
    int x_thread;
    int y_thread;
    int x_blocks = 1;
    int y_blocks = 1;
    if (thx == 0 && thy == 0) {
      x_thread = 8;
      y_thread = 8;
      // NOTE(review): integer division truncates; whether the kernel covers a
      // remainder (or copes with an empty grid when a dimension is < 8) is not visible
      // here — confirm callers guarantee multiples of 8.
      x_blocks = qout_c / 8;
      y_blocks = in_c / 8;
    } else {
      x_thread = (thx == 0) ? qout_c : thx;
      y_thread = (thy == 0) ? in_c : thy;
    }

    const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
    const auto options = torch::TensorOptions()
                             .dtype(_scaling_factors.dtype())
                             .device(_scaling_factors.device());
    // A single expert (including a 3-D kernel whose leading dim is 1) yields a 2-D result.
    at::Tensor _de_kernel = (num_experts == 1)
                                ? torch::empty({in_c, out_c}, options)
                                : torch::empty({num_experts, in_c, out_c}, options);

    int* kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
    half* de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
    half* scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    int* zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // grid.z indexes the expert; grid.x/y and block.x/y follow the launch shape above.
    aphrodite::awq::dequantize_weights<<<dim3(x_blocks, y_blocks, num_experts),
                                         dim3(x_thread, y_thread), 0, stream>>>(
        kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c);
    return _de_kernel;
  }
  // in_feats: M, IC [float16]
  // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
  // scaling_factors: IC // G, OC [float16]
  // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
  //
  // Fused AWQ 4-bit GEMM: multiplies in_feats by the dequantized packed weights.
  //
  // The grid is sized as ceil(M / 16) row tiles * (OC / tile_N) * split_k_iters, so M is
  // not restricted to < 16 by this wrapper. NOTE(review): confirm the kernel masks rows
  // of the final partial 16-row tile.
  //
  // Each of the split_k_iters partial products is written to its own slice of a
  // [split_k_iters, M, OC] buffer. The slices are summed on return, giving an [M, OC]
  // tensor with in_feats' dtype and device.
  //
  // Throws std::invalid_argument when:
  //   - scaling_factors has no groups;
  //   - OC is not a multiple of 64;
  //   - the group size is zero or not a multiple of 32;
  //   - OC is not a multiple of the group size.
  torch::Tensor awq_gemm(
      torch::Tensor _in_feats,
      torch::Tensor _kernel,
      torch::Tensor _scaling_factors,
      torch::Tensor _zeros,
      int split_k_iters)
  {
    int num_in_feats = _in_feats.size(0);
    int num_in_channels = _in_feats.size(1);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
    int num_out_feats = _out_feats.size(-2);
    int num_out_channels = _out_feats.size(-1);
    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

    // Guard the host-side integer division; dividing by zero would kill the process.
    if (_scaling_factors.size(0) == 0)
      throw std::invalid_argument("scaling_factors must have at least one group");
    int group_size = num_in_channels / _scaling_factors.size(0);

    // A multiple of cta_N = 64 is also a multiple of pack_num = 8, so no separate check.
    if (num_out_channels % 64 != 0)
      throw std::invalid_argument("OC is not multiple of cta_N = 64");
    // group_size == 0 (IC smaller than the group count) would pass the % 32 test below
    // and then make the OC % group_size check a modulo by zero.
    if (group_size == 0)
      throw std::invalid_argument("Group size must be positive");
    if (group_size % 32 != 0)
      throw std::invalid_argument("Group size should be a multiple of 32");
    if (num_out_channels % group_size != 0)
      throw std::invalid_argument("OC is not multiple of Group size");

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // Two warps per block: threadIdx.x spans a warp (32), threadIdx.y selects the warp.
    dim3 threads_per_block(32, 2);
    if (num_out_channels % 128 == 0)
    {
      int j_factors1 = num_out_channels / 128;
      dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
      aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
          group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
          num_out_channels, out_feats);
    }
    else
    {
      // OC % 64 == 0 is guaranteed by the validation above.
      int j_factors1 = num_out_channels / 64;
      dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
      aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
          group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
          num_out_channels, out_feats);
    }
    return _out_feats.sum(0);
  }
  // Grouped (mixture-of-experts) AWQ 4-bit GEMM over token rows ordered by
  // _sorted_token_ids_ptr.
  //
  // Shapes as read by this wrapper (NOTE(review): confirm against callers):
  //   _in_feats             3-D float16; dim 0 = tokens, dim 2 = IC.
  //   _kernel               int32 packed 4-bit weights; dim 2 holds OC / 8 columns.
  //   _scaling_factors      float16; dim 1 is the number of quantization groups along
  //                         IC (group_size = IC / size(1)).
  //   _topk_weights         dim 1 sizes the per-token output slots. When mul_weights is
  //                         true it is reinterpreted as float32 (NOTE(review): dtype is
  //                         not validated); otherwise nullptr is passed and the
  //                         kernel's epilogue skips the weighting.
  //   _sorted_token_ids_ptr its length is the padded row count, tiled by 16 rows.
  //
  // Returns [tokens, _topk_weights.size(1), OC]: the sum of split_k_iters partial
  // products.
  //
  // Throws std::invalid_argument when:
  //   - a shape divisor is zero;
  //   - OC is not a multiple of 64. Previously no kernel ran in this case and
  //     uninitialized memory was returned.
  torch::Tensor awq_group_gemm(
      torch::Tensor _in_feats,
      torch::Tensor _kernel,
      torch::Tensor _scaling_factors,
      torch::Tensor _zeros,
      torch::Tensor _topk_weights,
      torch::Tensor _sorted_token_ids_ptr,
      torch::Tensor _expert_ids_ptr,
      torch::Tensor _num_tokens_post_padded,
      bool mul_weights,
      int split_k_iters)
  {
    int num_in_feats = _in_feats.size(0);
    int pad_num_in_feats = _sorted_token_ids_ptr.size(0);
    int num_in_channels = _in_feats.size(2);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());

    // Both divisors below come straight from tensor shapes; reject zero rather than
    // crash the host process with an integer divide-by-zero.
    if (_in_feats.size(1) == 0)
      throw std::invalid_argument("in_feats.size(1) must be non-zero");
    if (_scaling_factors.size(1) == 0)
      throw std::invalid_argument("scaling_factors must have at least one group");

    // NOTE(review): despite its name this holds _topk_weights.size(1) (per-token output
    // slots), not a count of experts; it is forwarded to the kernel's num_experts argument.
    int num_experts = _topk_weights.size(1);
    int top_k = num_experts / _in_feats.size(1);
    int group_size = num_in_channels / _scaling_factors.size(1);

    // Only 128- and 64-wide column tiles are dispatched below. Without this check an
    // unsupported OC launches nothing and the uninitialized output buffer is returned.
    int num_out_channels = _kernel.size(2) * 8;
    if (num_out_channels % 64 != 0)
      throw std::invalid_argument("OC is not multiple of cta_N = 64");

    at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8}, options);
    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
    auto topk_weights = mul_weights ? reinterpret_cast<float*>(_topk_weights.data_ptr()) : nullptr;
    auto sorted_token_ids_ptr = reinterpret_cast<int*>(_sorted_token_ids_ptr.data_ptr());
    auto expert_ids_ptr = reinterpret_cast<int*>(_expert_ids_ptr.data_ptr());
    auto num_tokens_post_padded = reinterpret_cast<int*>(_num_tokens_post_padded.data_ptr());

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // Two warps per block: threadIdx.x spans a warp (32), threadIdx.y selects the warp.
    dim3 threads_per_block(32, 2);
    if (num_out_channels % 128 == 0)
    {
      int j_factors1 = num_out_channels / 128;
      dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
      aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
          group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
          topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded,
          _topk_weights.numel(), top_k, num_experts, pad_num_in_feats,
          num_in_feats, num_in_channels, num_out_channels, out_feats);
    }
    else
    {
      // OC % 64 == 0 is guaranteed by the validation above.
      int j_factors1 = num_out_channels / 64;
      dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
      aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
          group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
          topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded,
          _topk_weights.numel(), top_k, num_experts, pad_num_in_feats,
          num_in_feats, num_in_channels, num_out_channels, out_feats);
    }
    return _out_feats.sum(0);
  }