1
0

gemm_kernels.cu 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914
  1. /*
  2. Adapted from https://github.com/mit-han-lab/llm-awq
  3. @article{lin2023awq,
  4. title={AWQ: Activation-aware Weight Quantization for LLM Compression and
  5. Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
  6. Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
  7. }
  8. */
  9. #include <torch/all.h>
  10. #include <c10/cuda/CUDAGuard.h>
  11. #include "dequantize.cuh"
  12. #include <cuda_fp16.h>
  13. namespace aphrodite {
  14. namespace awq {
  15. template <int N>
  16. __global__ void __launch_bounds__(64)
  17. gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
  18. half* __restrict__ A, int* __restrict__ B,
  19. half* __restrict__ scaling_factors,
  20. int* __restrict__ zeros, int M, int IC,
  21. int OC, half* __restrict__ C) {
  22. // Only support matrix n = 64 or 128
  23. assert(N == 64 || N == 128);
  24. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  25. assert(false);
  26. #else
  27. static constexpr uint32_t ZERO = 0x0;
  28. float C_warp[32];
  29. __shared__ half A_shared[16 * (32 + 8)];
  30. __shared__ half B_shared[32 * (N + 8)];
  31. int j_factors1 = ((OC + N - 1) / N);
  32. int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
  33. int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
  34. half A_shared_warp[8];
  35. half B_shared_warp[N / 4];
  36. for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
  37. for (int i = 0; i < 8; ++i) {
  38. C_warp[(j_0_4_init * 8) + i] = 0.0;
  39. }
  40. }
  41. static constexpr int row_stride_warp = 32 * 8 / 32;
  42. static constexpr int row_stride = 2 * 32 * 8 / N;
  43. // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  44. bool ld_A_flag =
  45. (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
  46. threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
  47. // bool wb_C_flag = (threadIdx.x / 4) < M;
  48. half* A_ptr =
  49. A +
  50. (((int)blockIdx_y) / j_factors1 * 16 +
  51. (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
  52. IC +
  53. (((int)threadIdx.x) % (32 / 8)) * 8;
  54. int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
  55. (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
  56. (((int)blockIdx_y) % j_factors1) * (N / 8) +
  57. (((int)threadIdx.x) % (N / 8)) * 1;
  58. // Why * 1 in the above line?
  59. half* A_shared_ptr = A_shared +
  60. ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
  61. (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
  62. (((int)threadIdx.x) % (32 / 8)) * 8;
  63. half* B_shared_ptr = B_shared +
  64. ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
  65. (((int)threadIdx.x) / (N / 8)) * (N + 8) +
  66. (((int)threadIdx.x) % (N / 8)) * 8;
  67. int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
  68. ((int)threadIdx.x) % (N / 8);
  69. half* scaling_factors_ptr = scaling_factors +
  70. (((int)blockIdx_y) % j_factors1) * N +
  71. (((int)threadIdx.x) % (N / 8)) * 8;
  72. half* C_ptr =
  73. C +
  74. static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
  75. + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
  76. (((int)threadIdx.x) % 4) * 2;
  77. // preload s.f. and zeros
  78. int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
  79. if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
  80. for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
  81. int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
  82. __syncthreads();
  83. // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  84. if (ld_A_flag) {
  85. *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
  86. } else {
  87. *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
  88. }
  89. // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
  90. uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
  91. uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
  92. uint4 B_loaded_scale =
  93. *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
  94. /*
  95. if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
  96. threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
  97. B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
  98. B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
  99. }
  100. */
  101. // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
  102. int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
  103. for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
  104. // B: 32 x 136 (128+8) float16
  105. // each warp: 32 x 4
  106. // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
  107. // zero -> WB UINT4
  108. // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
  109. // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
  110. // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
  111. // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
  112. // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
  113. // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
  114. uint32_t B_loaded =
  115. *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
  116. uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  117. // - zero and * scale
  118. // TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
  119. // q * scale - zero * scale.
  120. asm volatile("sub.f16x2 %0, %1, %2;\n"
  121. : "=r"(B_loaded_fp16.x)
  122. : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  123. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  124. : "=r"(B_loaded_fp16.x)
  125. : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  126. asm volatile("sub.f16x2 %0, %1, %2;\n"
  127. : "=r"(B_loaded_fp16.y)
  128. : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  129. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  130. : "=r"(B_loaded_fp16.y)
  131. : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  132. asm volatile("sub.f16x2 %0, %1, %2;\n"
  133. : "=r"(B_loaded_fp16.z)
  134. : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  135. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  136. : "=r"(B_loaded_fp16.z)
  137. : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  138. asm volatile("sub.f16x2 %0, %1, %2;\n"
  139. : "=r"(B_loaded_fp16.w)
  140. : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  141. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  142. : "=r"(B_loaded_fp16.w)
  143. : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
  144. /*
  145. if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
  146. 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
  147. B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
  148. }
  149. */
  150. // write back
  151. *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
  152. B_loaded_fp16;
  153. }
  154. __syncthreads();
  155. for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
  156. {
  157. unsigned int addr;
  158. asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
  159. "addr; }\n"
  160. : "=r"(addr)
  161. : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
  162. (((((int)threadIdx.x) & 15) * 40) +
  163. ((((int)threadIdx.x) >> 4) * 8)))));
  164. asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
  165. "{%0, %1, %2, %3}, [%4];\n"
  166. : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
  167. "=r"(((unsigned*)(A_shared_warp + 0))[1]),
  168. "=r"(((unsigned*)(A_shared_warp + 0))[2]),
  169. "=r"(((unsigned*)(A_shared_warp + 0))[3])
  170. : "r"(addr));
  171. }
  172. for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
  173. {
  174. unsigned int addr;
  175. asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
  176. "addr; }\n"
  177. : "=r"(addr)
  178. : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
  179. (((int)threadIdx.y) * (N / 2))) +
  180. (ax1_0 * 16))])) +
  181. (((((int)threadIdx.x) & 15) * (N + 8)) +
  182. ((((int)threadIdx.x) >> 4) * 8)))));
  183. asm("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
  184. "{%0, %1, %2, %3}, [%4];\n"
  185. : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
  186. "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
  187. "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
  188. "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
  189. : "r"(addr));
  190. }
  191. }
  192. for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
  193. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
  194. {
  195. asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  196. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  197. : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  198. "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  199. "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  200. "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
  201. : "r"(((unsigned*)(A_shared_warp + 0))[0]),
  202. "r"(((unsigned*)(A_shared_warp + 0))[1]),
  203. "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
  204. "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  205. "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  206. "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  207. "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
  208. }
  209. {
  210. asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  211. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  212. : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  213. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  214. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  215. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
  216. : "r"(((unsigned*)(A_shared_warp + 0))[0]),
  217. "r"(((unsigned*)(A_shared_warp + 0))[1]),
  218. "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
  219. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  220. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  221. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  222. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  223. }
  224. {
  225. asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  226. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  227. : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  228. "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  229. "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  230. "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
  231. : "r"(((unsigned*)(A_shared_warp + 0))[2]),
  232. "r"(((unsigned*)(A_shared_warp + 0))[3]),
  233. "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
  234. "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  235. "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  236. "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  237. "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
  238. }
  239. {
  240. asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  241. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  242. : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  243. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  244. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  245. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
  246. : "r"(((unsigned*)(A_shared_warp + 0))[2]),
  247. "r"(((unsigned*)(A_shared_warp + 0))[3]),
  248. "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
  249. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  250. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  251. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  252. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  253. }
  254. #else
  255. {
  256. asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
  257. "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
  258. "%13};\n"
  259. : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  260. "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  261. "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  262. "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
  263. : "r"(((unsigned*)(A_shared_warp + 0))[0]),
  264. "r"(((unsigned*)(A_shared_warp + 0))[1]),
  265. "r"(((unsigned*)(A_shared_warp + 0))[2]),
  266. "r"(((unsigned*)(A_shared_warp + 0))[3]),
  267. "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
  268. "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
  269. "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  270. "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  271. "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  272. "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
  273. }
  274. {
  275. asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
  276. "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
  277. "%13};\n"
  278. : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  279. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  280. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  281. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
  282. : "r"(((unsigned*)(A_shared_warp + 0))[0]),
  283. "r"(((unsigned*)(A_shared_warp + 0))[1]),
  284. "r"(((unsigned*)(A_shared_warp + 0))[2]),
  285. "r"(((unsigned*)(A_shared_warp + 0))[3]),
  286. "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
  287. "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
  288. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  289. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  290. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  291. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  292. }
  293. #endif
  294. }
  295. }
  296. }
  297. // TODO: Shang: Hoist loop invariance.
  298. for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
  299. for (int local_id = 0; local_id < 8; ++local_id) {
  300. int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
  301. ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
  302. if (row_offset < M) {
  303. *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
  304. local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
  305. }
  306. }
  307. }
  308. #endif
  309. }
  310. __global__ void __launch_bounds__(64)
  311. dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
  312. int* __restrict__ zeros, half* __restrict__ C, int G,
  313. int in_c, int out_c) {
  314. if (blockIdx.z > 0) {
  315. B = B + blockIdx.z * in_c * out_c / 8;
  316. scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G;
  317. zeros = zeros + blockIdx.z * in_c * out_c / G / 8;
  318. C = C + blockIdx.z * in_c * out_c;
  319. }
  320. static constexpr uint32_t ZERO = 0x0;
  321. half B_shared[32 * (128 + 8)];
  322. half* B_shared_ptr2 = B_shared;
  323. int N = blockDim.x * gridDim.x; // 2
  324. int col = (blockIdx.x * blockDim.x + threadIdx.x);
  325. int row = blockIdx.y * blockDim.y + threadIdx.y;
  326. int index1 = 8 * col + 8 * row * N;
  327. half* C_ptr2 = C + index1;
  328. int index2 = col + row * N;
  329. int* B_ptr2 = B + index2;
  330. int index3 = col + (int)(row / G) * N;
  331. int* zeros_ptr2 = zeros + index3;
  332. int index4 = 8 * col + (int)(row / G) * N * 8;
  333. half* scaling_factors_ptr2 = scaling_factors + index4;
  334. uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
  335. uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
  336. uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
  337. uint32_t B_loaded = *(uint32_t*)B_ptr2;
  338. uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  339. asm volatile("sub.f16x2 %0, %1, %2;\n"
  340. : "=r"(B_loaded_fp16.x)
  341. : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  342. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  343. : "=r"(B_loaded_fp16.x)
  344. : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  345. asm volatile("sub.f16x2 %0, %1, %2;\n"
  346. : "=r"(B_loaded_fp16.y)
  347. : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  348. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  349. : "=r"(B_loaded_fp16.y)
  350. : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  351. asm volatile("sub.f16x2 %0, %1, %2;\n"
  352. : "=r"(B_loaded_fp16.z)
  353. : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  354. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  355. : "=r"(B_loaded_fp16.z)
  356. : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  357. asm volatile("sub.f16x2 %0, %1, %2;\n"
  358. : "=r"(B_loaded_fp16.w)
  359. : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  360. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  361. : "=r"(B_loaded_fp16.w)
  362. : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
  363. *(uint4*)B_shared_ptr2 = B_loaded_fp16;
  364. for (int i = 0; i < 8; ++i) {
  365. *(C_ptr2 + i) = B_shared[i];
  366. }
  367. }
  368. template <int N>
  369. __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
  370. int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B,
  371. half* __restrict__ scaling_factors, int* __restrict__ zeros,
  372. const float* __restrict__ topk_weights,
  373. const int* __restrict__ sorted_token_ids_ptr,
  374. const int* __restrict__ expert_ids_ptr,
  375. const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens,
  376. const int top_k, const int expert_num, int pad_M, int M, int IC, int OC,
  377. half* __restrict__ C) {
  378. // Only support matrix n = 64 or 128
  379. assert(N == 64 || N == 128);
  380. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  381. assert(false);
  382. #else
  383. int num_tokens = *num_tokens_post_padded;
  384. int j_factors1 = ((OC + N - 1) / N);
  385. [[maybe_unused]] int blockIdx_x = 0;
  386. int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1);
  387. int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1);
  388. int block = blockIdx_y / j_factors1;
  389. if (block * 16 >= num_tokens) return;
  390. static constexpr uint32_t ZERO = 0x0;
  391. float C_warp[32];
  392. __shared__ half A_shared[16 * (32 + 8)];
  393. __shared__ half B_shared[32 * (N + 8)];
  394. [[maybe_unused]] __shared__ half scaling_factors_shared[N];
  395. [[maybe_unused]] __shared__ half zeros_shared[N];
  396. half A_shared_warp[8];
  397. half B_shared_warp[N / 4];
  398. for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
  399. for (int i = 0; i < 8; ++i) {
  400. C_warp[(j_0_4_init * 8) + i] = 0.0;
  401. }
  402. }
  403. static constexpr int row_stride_warp = 32 * 8 / 32;
  404. static constexpr int row_stride = 2 * 32 * 8 / N;
  405. [[maybe_unused]] bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
  406. // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  407. int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32);
  408. int token_id = sorted_token_ids_ptr[row];
  409. bool ld_A_flag = (token_id < num_valid_tokens);
  410. half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8;
  411. int expert_id = expert_ids_ptr[block];
  412. B = B + OC * IC / 8 * expert_id;
  413. scaling_factors = scaling_factors + OC * IC / G * expert_id;
  414. zeros = zeros + OC * IC / G / 8 * expert_id;
  415. int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
  416. (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
  417. (((int)blockIdx_y) % j_factors1) * (N / 8) +
  418. (((int)threadIdx.x) % (N / 8)) * 1;
  419. // Why * 1 in the above line?
  420. half* A_shared_ptr = A_shared +
  421. ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
  422. (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
  423. (((int)threadIdx.x) % (32 / 8)) * 8;
  424. half* B_shared_ptr = B_shared +
  425. ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
  426. (((int)threadIdx.x) / (N / 8)) * (N + 8) +
  427. (((int)threadIdx.x) % (N / 8)) * 8;
  428. int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
  429. ((int)threadIdx.x) % (N / 8);
  430. half* scaling_factors_ptr = scaling_factors +
  431. (((int)blockIdx_y) % j_factors1) * N +
  432. (((int)threadIdx.x) % (N / 8)) * 8;
  433. half* C_ptr = C +
  434. static_cast<long long>(blockIdx_z) * M * OC *
  435. expert_num // blockIdz.x -> split_k dim
  436. + (((int)blockIdx_y) % j_factors1) * N +
  437. ((int)threadIdx.y) * (N / 2) + (((int)threadIdx.x) % 4) * 2;
  438. // preload s.f. and zeros
  439. int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
  440. if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
  441. for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
  442. int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
  443. __syncthreads();
  444. // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  445. if (ld_A_flag) {
  446. *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
  447. } else {
  448. *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
  449. }
  450. uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
  451. uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
  452. uint4 B_loaded_scale =
  453. *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
  454. int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
  455. for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
  456. uint32_t B_loaded =
  457. *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
  458. uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  459. // TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
  460. // q * scale - zero * scale.
  461. asm volatile("sub.f16x2 %0, %1, %2;\n"
  462. : "=r"(B_loaded_fp16.x)
  463. : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  464. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  465. : "=r"(B_loaded_fp16.x)
  466. : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  467. asm volatile("sub.f16x2 %0, %1, %2;\n"
  468. : "=r"(B_loaded_fp16.y)
  469. : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  470. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  471. : "=r"(B_loaded_fp16.y)
  472. : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  473. asm volatile("sub.f16x2 %0, %1, %2;\n"
  474. : "=r"(B_loaded_fp16.z)
  475. : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  476. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  477. : "=r"(B_loaded_fp16.z)
  478. : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  479. asm volatile("sub.f16x2 %0, %1, %2;\n"
  480. : "=r"(B_loaded_fp16.w)
  481. : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  482. asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
  483. : "=r"(B_loaded_fp16.w)
  484. : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
  485. // write back
  486. *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
  487. B_loaded_fp16;
  488. }
  489. __syncthreads();
  490. for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
  491. {
  492. unsigned int addr;
  493. asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
  494. "addr; }\n"
  495. : "=r"(addr)
  496. : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
  497. (((((int)threadIdx.x) & 15) * 40) +
  498. ((((int)threadIdx.x) >> 4) * 8)))));
  499. asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
  500. "{%0, %1, %2, %3}, [%4];\n"
  501. : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
  502. "=r"(((unsigned*)(A_shared_warp + 0))[1]),
  503. "=r"(((unsigned*)(A_shared_warp + 0))[2]),
  504. "=r"(((unsigned*)(A_shared_warp + 0))[3])
  505. : "r"(addr));
  506. }
  507. for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
  508. {
  509. unsigned int addr;
  510. asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
  511. "addr; }\n"
  512. : "=r"(addr)
  513. : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
  514. (((int)threadIdx.y) * (N / 2))) +
  515. (ax1_0 * 16))])) +
  516. (((((int)threadIdx.x) & 15) * (N + 8)) +
  517. ((((int)threadIdx.x) >> 4) * 8)))));
  518. asm("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
  519. "{%0, %1, %2, %3}, [%4];\n"
  520. : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
  521. "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
  522. "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
  523. "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
  524. : "r"(addr));
  525. }
  526. }
  527. for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
  528. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
  529. {
  530. asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  531. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  532. : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  533. "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  534. "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  535. "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
  536. : "r"(((unsigned*)(A_shared_warp + 0))[0]),
  537. "r"(((unsigned*)(A_shared_warp + 0))[1]),
  538. "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
  539. "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  540. "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  541. "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  542. "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
  543. }
  544. {
  545. asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  546. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  547. : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  548. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  549. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  550. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
  551. : "r"(((unsigned*)(A_shared_warp + 0))[0]),
  552. "r"(((unsigned*)(A_shared_warp + 0))[1]),
  553. "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
  554. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  555. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  556. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  557. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  558. }
  559. {
  560. asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  561. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  562. : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  563. "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  564. "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  565. "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
  566. : "r"(((unsigned*)(A_shared_warp + 0))[2]),
  567. "r"(((unsigned*)(A_shared_warp + 0))[3]),
  568. "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
  569. "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  570. "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  571. "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  572. "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
  573. }
  574. {
  575. asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
  576. "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
  577. : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  578. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  579. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  580. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
  581. : "r"(((unsigned*)(A_shared_warp + 0))[2]),
  582. "r"(((unsigned*)(A_shared_warp + 0))[3]),
  583. "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
  584. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  585. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  586. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  587. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  588. }
  589. #else
  590. {
  591. asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
  592. "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
  593. "%13};\n"
  594. : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  595. "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  596. "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  597. "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
  598. : "r"(((unsigned*)(A_shared_warp + 0))[0]),
  599. "r"(((unsigned*)(A_shared_warp + 0))[1]),
  600. "r"(((unsigned*)(A_shared_warp + 0))[2]),
  601. "r"(((unsigned*)(A_shared_warp + 0))[3]),
  602. "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
  603. "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
  604. "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
  605. "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
  606. "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
  607. "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
  608. }
  609. {
  610. asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
  611. "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
  612. "%13};\n"
  613. : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  614. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  615. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  616. "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
  617. : "r"(((unsigned*)(A_shared_warp + 0))[0]),
  618. "r"(((unsigned*)(A_shared_warp + 0))[1]),
  619. "r"(((unsigned*)(A_shared_warp + 0))[2]),
  620. "r"(((unsigned*)(A_shared_warp + 0))[3]),
  621. "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
  622. "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
  623. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
  624. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
  625. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
  626. "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
  627. }
  628. #endif
  629. }
  630. }
  631. }
  632. // TODO: Shang: Hoist loop invariance.
  633. for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) {
  634. for (int local_id = 0; local_id < 8; ++local_id) {
  635. int row_offset =
  636. block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
  637. int token_id = sorted_token_ids_ptr[row_offset];
  638. if (token_id < num_valid_tokens) {
  639. float value = C_warp[(ax1_0_1 * 8) + local_id];
  640. if (topk_weights) {
  641. value = value * topk_weights[token_id];
  642. }
  643. *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 +
  644. local_id % 2) = __float2half(value);
  645. }
  646. }
  647. }
  648. #endif
  649. }
  650. } // namespace awq
  651. } // namespace aphrodite
  // Dequantizes packed AWQ int4 weights into a dense tensor allocated on the
  // device of `_scaling_factors`.
  //
  // `_kernel` is rank 2, [in_c, qout_c] (single expert), or rank 3,
  // [E, in_c, qout_c]. Each int32 column expands to 8 output columns, so the
  // result is [in_c, qout_c * 8], or [E, in_c, qout_c * 8] when E != 1. The
  // result dtype is taken from `_scaling_factors`; since its data is read
  // through at::Half below, float16 is presumably required — TODO confirm.
  //
  // Launch shape:
  //   * thx == 0 && thy == 0: 8x8 thread blocks, (qout_c / 8) x (in_c / 8)
  //     grid (integer division), E blocks along z.
  //   * otherwise: one block per expert of thx x thy threads, where a zero
  //     thx (resp. thy) is replaced by qout_c (resp. in_c).
  // `split_k_iters` is accepted for interface compatibility but not used.
  torch::Tensor awq_dequantize(torch::Tensor _kernel,
                               torch::Tensor _scaling_factors,
                               torch::Tensor _zeros, int64_t split_k_iters,
                               int64_t thx, int64_t thy) {
    // Rank 2 => no expert axis; otherwise axis 0 indexes experts and the
    // (in, out) axes shift right by one.
    const bool has_expert_axis = _kernel.dim() != 2;
    const int64_t base_axis = has_expert_axis ? 1 : 0;
    const int in_c = _kernel.size(base_axis);
    const int qout_c = _kernel.size(base_axis + 1);
    const int num_experts = has_expert_axis ? _kernel.size(0) : 1;
    const int out_c = qout_c * 8;
    const int G = in_c / _scaling_factors.size(base_axis);

    int x_blocks = 1;
    int y_blocks = 1;
    int x_thread;
    int y_thread;
    if (thx == 0 && thy == 0) {
      x_thread = 8;
      y_thread = 8;
      x_blocks = qout_c / 8;
      y_blocks = in_c / 8;
    } else {
      x_thread = (thx == 0) ? qout_c : static_cast<int>(thx);
      y_thread = (thy == 0) ? in_c : static_cast<int>(thy);
    }

    const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
    auto options = torch::TensorOptions()
                       .dtype(_scaling_factors.dtype())
                       .device(_scaling_factors.device());
    // Keyed on the expert count, so a rank-3 kernel with E == 1 also yields a
    // rank-2 result.
    at::Tensor _de_kernel =
        (num_experts == 1) ? torch::empty({in_c, out_c}, options)
                           : torch::empty({num_experts, in_c, out_c}, options);

    int* kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
    half* de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
    half* scaling_factors =
        reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    int* zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

    const dim3 num_blocks(x_blocks, y_blocks, num_experts);
    const dim3 threads_per_block(x_thread, y_thread);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    aphrodite::awq::
        dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
            kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c);
    return _de_kernel;
  }
  // Dense AWQ 4-bit GEMM wrapper: launches the m16nXk32 kernel and reduces
  // the split-K partial results.
  //
  // Layouts (per the original header comment):
  //   in_feats:        M, IC            [float16]
  //   kernel:          IC, OC // 8      [int32] -> cast to IC, OC [uint4b]
  //   scaling_factors: IC // G, OC      [float16]
  //   zeros:           IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
  //
  // Returns the [split_k_iters, M, OC] kernel output summed over dim 0, i.e.
  // an [M, OC] tensor with the dtype/device of `_in_feats`.
  //
  // Throws std::invalid_argument if scaling_factors has no rows, if OC is not
  // a multiple of 64, if the group size is not a positive multiple of 32, or if
  // OC is not a multiple of the group size.
  //
  // NOTE(review): an earlier comment said "assume that batch_size < 16 for
  // now", but the grid below is sized for ceil(M / 16) 16-row tiles — confirm
  // the kernel's actual limit on M.
  torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                         torch::Tensor _scaling_factors, torch::Tensor _zeros,
                         int64_t split_k_iters) {
    int num_in_feats = _in_feats.size(0);
    int num_in_channels = _in_feats.size(1);
    int num_out_channels = _kernel.size(1) * 8;

    // Validate before allocating anything. A zero-row scaling_factors would
    // divide by zero here, and group_size == 0 would pass the "% 32" test and
    // then divide by zero in "num_out_channels % group_size".
    if (_scaling_factors.size(0) <= 0)
      throw std::invalid_argument("scaling_factors must have at least one row");
    int group_size = num_in_channels / _scaling_factors.size(0);
    if (num_out_channels % 64 != 0)
      throw std::invalid_argument("OC is not multiple of cta_N = 64");
    // OC % 64 == 0 already implies OC % 8 == 0 (the pack_num requirement).
    if (group_size <= 0)
      throw std::invalid_argument("Group size must be positive");
    if (group_size % 32 != 0)
      throw std::invalid_argument("Group size should be a multiple of 32");
    if (num_out_channels % group_size != 0)
      throw std::invalid_argument("OC is not multiple of Group size");

    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
    auto options = torch::TensorOptions()
                       .dtype(_in_feats.dtype())
                       .device(_in_feats.device());
    // One [M, OC] slice per split-K partition; reduced with sum(0) below.
    at::Tensor _out_feats = torch::empty(
        {split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);

    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    auto scaling_factors =
        reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // 16-row tiles along M (ceil-div), times column tiles, times split-K.
    const int m_tiles = (num_in_feats + 16 - 1) / 16;
    // threadIdx.x: 32; threadIdx.y: i_factors[2] * j_factors[2].
    dim3 threads_per_block(32, 2);
    if (num_out_channels % 128 == 0) {
      dim3 num_blocks(m_tiles * (num_out_channels / 128) * split_k_iters);
      aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<128>
          <<<num_blocks, threads_per_block, 0, stream>>>(
              group_size, split_k_iters, in_feats, kernel, scaling_factors,
              zeros, num_in_feats, num_in_channels, num_out_channels,
              out_feats);
    } else {
      // Validation above guarantees num_out_channels % 64 == 0 here.
      dim3 num_blocks(m_tiles * (num_out_channels / 64) * split_k_iters);
      aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<64>
          <<<num_blocks, threads_per_block, 0, stream>>>(
              group_size, split_k_iters, in_feats, kernel, scaling_factors,
              zeros, num_in_feats, num_in_channels, num_out_channels,
              out_feats);
    }
    return _out_feats.sum(0);
  }
  // Grouped (per-expert) AWQ 4-bit GEMM wrapper driven by sorted token ids.
  //
  // `_in_feats` is read as rank 3 (size(0) = M, size(2) = IC). `_kernel`'s
  // size(2) * 8 gives OC, and `_scaling_factors`' size(1) gives the number of
  // groups along IC. `_sorted_token_ids_ptr`, `_expert_ids_ptr` and
  // `_num_tokens_post_padded` are passed to the kernel as int32 arrays. When
  // `mul_weights` is true, `_topk_weights` is passed as float32 for per-token
  // scaling; otherwise the kernel receives nullptr.
  //
  // Returns the [split_k_iters, M, topk_weights.size(1), OC] kernel output
  // summed over dim 0.
  //
  // Throws std::invalid_argument if OC is not a multiple of 64 (the only
  // tilings launched below are 128 and 64 columns). Throws c10::Error (from
  // data_ptr<T>) if an index/weight tensor has an unexpected dtype.
  //
  // NOTE(review): `num_experts` is read from _topk_weights.size(1), which
  // looks like top-k rather than an expert count — confirm how the kernel
  // uses it.
  torch::Tensor awq_group_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                               torch::Tensor _scaling_factors,
                               torch::Tensor _zeros, torch::Tensor _topk_weights,
                               torch::Tensor _sorted_token_ids_ptr,
                               torch::Tensor _expert_ids_ptr,
                               torch::Tensor _num_tokens_post_padded,
                               bool mul_weights, int split_k_iters) {
    int num_in_feats = _in_feats.size(0);
    int pad_num_in_feats = _sorted_token_ids_ptr.size(0);
    int num_in_channels = _in_feats.size(2);
    int num_experts = _topk_weights.size(1);
    int top_k = num_experts / _in_feats.size(1);
    int group_size = num_in_channels / _scaling_factors.size(1);
    int num_out_channels = _kernel.size(2) * 8;

    // Without this check an unsupported OC launched nothing and the
    // uninitialized torch::empty buffer was returned as the result.
    if (num_out_channels % 64 != 0)
      throw std::invalid_argument("OC is not multiple of cta_N = 64");

    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
    auto options = torch::TensorOptions()
                       .dtype(_in_feats.dtype())
                       .device(_in_feats.device());
    at::Tensor _out_feats = torch::empty(
        {split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8},
        options);

    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    auto scaling_factors =
        reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
    // Dtype-checked accessors: a mismatched dtype now raises instead of being
    // silently reinterpreted by the kernel.
    float* topk_weights =
        mul_weights ? _topk_weights.data_ptr<float>() : nullptr;
    int* sorted_token_ids_ptr = _sorted_token_ids_ptr.data_ptr<int>();
    int* expert_ids_ptr = _expert_ids_ptr.data_ptr<int>();
    int* num_tokens_post_padded = _num_tokens_post_padded.data_ptr<int>();

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // 16-row tiles over the padded token list (ceil-div), times column tiles,
    // times split-K.
    const int m_tiles = (pad_num_in_feats + 16 - 1) / 16;
    // threadIdx.x: 32; threadIdx.y: i_factors[2] * j_factors[2].
    dim3 threads_per_block(32, 2);
    if (num_out_channels % 128 == 0) {
      dim3 num_blocks(m_tiles * (num_out_channels / 128) * split_k_iters);
      aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<128>
          <<<num_blocks, threads_per_block, 0, stream>>>(
              group_size, split_k_iters, in_feats, kernel, scaling_factors,
              zeros, topk_weights, sorted_token_ids_ptr, expert_ids_ptr,
              num_tokens_post_padded, _topk_weights.numel(), top_k,
              num_experts, pad_num_in_feats, num_in_feats, num_in_channels,
              num_out_channels, out_feats);
    } else {
      // Validation above guarantees num_out_channels % 64 == 0 here.
      dim3 num_blocks(m_tiles * (num_out_channels / 64) * split_k_iters);
      aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<64>
          <<<num_blocks, threads_per_block, 0, stream>>>(
              group_size, split_k_iters, in_feats, kernel, scaling_factors,
              zeros, topk_weights, sorted_token_ids_ptr, expert_ids_ptr,
              num_tokens_post_padded, _topk_weights.numel(), top_k,
              num_experts, pad_num_in_feats, num_in_feats, num_in_channels,
              num_out_channels, out_feats);
    }
    return _out_feats.sum(0);
  }