gemm_kernels.cu 40 KB


  1. /*
  2. Adapted from https://github.com/mit-han-lab/llm-awq
  3. @article{lin2023awq,
  4. title={AWQ: Activation-aware Weight Quantization for LLM Compression and
  5. Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
  6. Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
  7. }
  8. */
  9. #include <torch/all.h>
  10. #include <c10/cuda/CUDAGuard.h>
  11. #include "dequantize.cuh"
  12. #include <cuda_fp16.h>
  13. namespace aphrodite {
  14. namespace awq {
  // AWQ 4-bit weight-only GEMM (tensor-core path, requires SM75 or newer).
  //
  // Each thread block produces one 16 x N tile of the output for one split-K
  // slice; every main-loop step consumes 32 values along K (A tile 16 x 32,
  // dequantized B tile 32 x N).
  //
  // Template parameter:
  //   N  output-tile width per block; only 64 and 128 are supported.
  //
  // Launch layout, as consumed by the index math below:
  //   * 1-D grid. blockIdx.x is split into `blockIdx_y` (row-tile and
  //     column-tile index) and `blockIdx_z` (split-K index).
  //   * threadIdx.x is the lane id (0..31) and threadIdx.y the warp id; each
  //     warp owns N / 2 output columns. __launch_bounds__(64) caps the block
  //     at 64 threads -- presumably launched as dim3(32, 2); the launch site
  //     is not part of this block, confirm there.
  //   * Static shared memory: 16 x (32 + 8) + 32 x (N + 8) halfs.
  //
  // Parameters:
  //   G                group size along K shared by one scale / zero row.
  //   split_k_iters    number of split-K slices; slice z handles K-steps
  //                    z, z + split_k_iters, z + 2 * split_k_iters, ...
  //   A                activations, row-major, row stride IC.
  //   B                packed weights, 8 x 4-bit per int32, row stride OC / 8.
  //   scaling_factors  per-group scales, row stride OC.
  //   zeros            packed per-group zero points, row stride OC / 8.
  //   M, IC, OC        GEMM sizes (rows of A, reduction length, output width).
  //   C                output; slice z is written at offset z * M * OC, the
  //                    slices are NOT reduced here.
  //
  // NOTE(review): the uint4 loads assume 16-byte aligned rows (IC and OC
  // multiples of 8 and aligned base pointers) -- confirm against the caller.
  template <int N>
  __global__ void __launch_bounds__(64)
      gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
                                      half* __restrict__ A, int* __restrict__ B,
                                      half* __restrict__ scaling_factors,
                                      int* __restrict__ zeros, int M, int IC,
                                      int OC, half* __restrict__ C) {
    // Only support matrix n = 64 or 128
    assert(N == 64 || N == 128);
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
    assert(false);
  #else
    static constexpr uint32_t ZERO = 0x0;
    // Per-thread accumulators: 8 floats per 16-column group, N / 32 groups
    // are live (16 entries for N == 64, 32 entries for N == 128).
    float C_warp[32];
    // +8 halfs of padding per row in both tiles.
    __shared__ half A_shared[16 * (32 + 8)];
    __shared__ half B_shared[32 * (N + 8)];

    int j_factors1 = ((OC + N - 1) / N);  // number of column tiles
    int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
    int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);

    half A_shared_warp[8];
    half B_shared_warp[N / 4];
    for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
      for (int i = 0; i < 8; ++i) {
        C_warp[(j_0_4_init * 8) + i] = 0.0f;
      }
    }

    static constexpr int row_stride_warp = 32 * 8 / 32;
    static constexpr int row_stride = 2 * 32 * 8 / N;
    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
    // Rows past M are zero-filled in shared memory instead of being loaded.
    bool ld_A_flag =
        (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
         threadIdx.x * 8 / 32) < M;  // threadIdx.y is warp_id

    // Each thread copies 8 halfs (one uint4) of A per K-step.
    half* A_ptr =
        A +
        (((int)blockIdx_y) / j_factors1 * 16 +
         (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
            IC +
        (((int)threadIdx.x) % (32 / 8)) * 8;

    // Each thread reads one packed int32 (8 weights) of B per inner step.
    int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
                 (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
                 (((int)blockIdx_y) % j_factors1) * (N / 8) +
                 (((int)threadIdx.x) % (N / 8)) * 1;

    half* A_shared_ptr = A_shared +
                         ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
                         (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
                         (((int)threadIdx.x) % (32 / 8)) * 8;

    half* B_shared_ptr = B_shared +
                         ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
                         (((int)threadIdx.x) / (N / 8)) * (N + 8) +
                         (((int)threadIdx.x) % (N / 8)) * 8;

    int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
                     ((int)threadIdx.x) % (N / 8);

    half* scaling_factors_ptr = scaling_factors +
                                (((int)blockIdx_y) % j_factors1) * N +
                                (((int)threadIdx.x) % (N / 8)) * 8;

    half* C_ptr =
        C +
        static_cast<long long>(blockIdx_z) * M * OC  // blockIdx_z -> split_k dim
        + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
        (((int)threadIdx.x) % 4) * 2;

    // Number of K-steps owned by this split; the last round may not exist for
    // the higher split indices.
    int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
    if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;

    for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
      int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
      // Previous step's fragments must be fully consumed before the tiles are
      // overwritten.
      __syncthreads();
      // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
      if (ld_A_flag) {
        *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
      } else {
        *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
      }

      // Scale / zero point of the quantization group this K-step falls into.
      uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
      uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
      uint4 B_loaded_scale =
          *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));

      int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

      for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
        // B tile in shared memory: 32 x (N + 8) float16.
        // Each thread: read 32 bit -> convert to 8 x FP16 (a uint4) -> subtract
        // zero and scale -> write back the uint4.
        // Row stride in shared memory: (NWARPS * 32 * 8 / cta_N).
        uint32_t B_loaded =
            *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
        uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);

        // - zero and * scale
        // TODO (Haotian): can save 4 assembly instructions if formulated as
        // deq = q * scale - zero * scale.
        asm volatile("sub.f16x2 %0, %1, %2;\n"
                     : "=r"(B_loaded_fp16.x)
                     : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                     : "=r"(B_loaded_fp16.x)
                     : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
        asm volatile("sub.f16x2 %0, %1, %2;\n"
                     : "=r"(B_loaded_fp16.y)
                     : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                     : "=r"(B_loaded_fp16.y)
                     : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
        asm volatile("sub.f16x2 %0, %1, %2;\n"
                     : "=r"(B_loaded_fp16.z)
                     : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                     : "=r"(B_loaded_fp16.z)
                     : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
        asm volatile("sub.f16x2 %0, %1, %2;\n"
                     : "=r"(B_loaded_fp16.w)
                     : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                     : "=r"(B_loaded_fp16.w)
                     : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

        // write back
        *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
            B_loaded_fp16;
      }
      // Both tiles are complete; the warps may now read them.
      __syncthreads();

      // The 32-deep K-step is consumed as two 16-deep halves.
      for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
        unsigned* a_frag = (unsigned*)(A_shared_warp + 0);
        {
          unsigned int addr;
          half* a_src = (&(A_shared[(k_0_1 * 16)])) +
                        (((((int)threadIdx.x) & 15) * 40) +
                         ((((int)threadIdx.x) >> 4) * 8));
          __asm__ __volatile__(
              "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
              "addr; }\n"
              : "=r"(addr)
              : "l"((void*)a_src));
          __asm__ __volatile__(
              "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
              "{%0, %1, %2, %3}, [%4];\n"
              : "=r"(a_frag[0]), "=r"(a_frag[1]), "=r"(a_frag[2]),
                "=r"(a_frag[3])
              : "r"(addr));
        }

        for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
          unsigned* b_dst = (unsigned*)(B_shared_warp + (ax1_0 * 8));
          unsigned int addr;
          half* b_src = (&(B_shared[(((k_0_1 * (N * 16 + 128)) +
                                      (((int)threadIdx.y) * (N / 2))) +
                                     (ax1_0 * 16))])) +
                        (((((int)threadIdx.x) & 15) * (N + 8)) +
                         ((((int)threadIdx.x) >> 4) * 8));
          __asm__ __volatile__(
              "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
              "addr; }\n"
              : "=r"(addr)
              : "l"((void*)b_src));
          __asm__ __volatile__(
              "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
              "{%0, %1, %2, %3}, [%4];\n"
              : "=r"(b_dst[0]), "=r"(b_dst[1]), "=r"(b_dst[2]), "=r"(b_dst[3])
              : "r"(addr));
        }

        for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
          // Accumulators / B fragments for the low and high 8 columns of this
          // 16-column group.
          float* acc_lo = (float*)(C_warp + (j_0_4 * 8));
          float* acc_hi = (float*)(C_warp + ((j_0_4 * 8) + 4));
          unsigned* b_lo = (unsigned*)(B_shared_warp + (j_0_4 * 8));
          unsigned* b_hi = (unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4));
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
          // SM75 path: four m16n8k8 MMAs per 16-column group.
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(acc_lo[0]), "=f"(acc_lo[1]), "=f"(acc_lo[2]),
                "=f"(acc_lo[3])
              : "r"(a_frag[0]), "r"(a_frag[1]), "r"(b_lo[0]), "f"(acc_lo[0]),
                "f"(acc_lo[1]), "f"(acc_lo[2]), "f"(acc_lo[3]));
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(acc_hi[0]), "=f"(acc_hi[1]), "=f"(acc_hi[2]),
                "=f"(acc_hi[3])
              : "r"(a_frag[0]), "r"(a_frag[1]), "r"(b_hi[0]), "f"(acc_hi[0]),
                "f"(acc_hi[1]), "f"(acc_hi[2]), "f"(acc_hi[3]));
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(acc_lo[0]), "=f"(acc_lo[1]), "=f"(acc_lo[2]),
                "=f"(acc_lo[3])
              : "r"(a_frag[2]), "r"(a_frag[3]), "r"(b_lo[1]), "f"(acc_lo[0]),
                "f"(acc_lo[1]), "f"(acc_lo[2]), "f"(acc_lo[3]));
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(acc_hi[0]), "=f"(acc_hi[1]), "=f"(acc_hi[2]),
                "=f"(acc_hi[3])
              : "r"(a_frag[2]), "r"(a_frag[3]), "r"(b_hi[1]), "f"(acc_hi[0]),
                "f"(acc_hi[1]), "f"(acc_hi[2]), "f"(acc_hi[3]));
  #else
          // Other supported architectures: two m16n8k16 MMAs per group.
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
              "%13};\n"
              : "=f"(acc_lo[0]), "=f"(acc_lo[1]), "=f"(acc_lo[2]),
                "=f"(acc_lo[3])
              : "r"(a_frag[0]), "r"(a_frag[1]), "r"(a_frag[2]), "r"(a_frag[3]),
                "r"(b_lo[0]), "r"(b_lo[1]), "f"(acc_lo[0]), "f"(acc_lo[1]),
                "f"(acc_lo[2]), "f"(acc_lo[3]));
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
              "%13};\n"
              : "=f"(acc_hi[0]), "=f"(acc_hi[1]), "=f"(acc_hi[2]),
                "=f"(acc_hi[3])
              : "r"(a_frag[0]), "r"(a_frag[1]), "r"(a_frag[2]), "r"(a_frag[3]),
                "r"(b_hi[0]), "r"(b_hi[1]), "f"(acc_hi[0]), "f"(acc_hi[1]),
                "f"(acc_hi[2]), "f"(acc_hi[3]));
  #endif
        }
      }
    }

    // Write the accumulators back. Only N / 32 column groups per warp were
    // accumulated; iterating further would store uninitialized values into
    // columns owned by the other warp or the neighbouring tile.
    // TODO: Shang: Hoist loop invariance.
    for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) {
      for (int local_id = 0; local_id < 8; ++local_id) {
        int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
                         ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
        if (row_offset < M) {
          // 64-bit row offset so that large M * OC cannot wrap.
          *(C_ptr + ax1_0_1 * 16 + static_cast<long long>(row_offset) * OC +
            (local_id / 4) * 8 + local_id % 2) =
              __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
        }
      }
    }
  #endif
  }
  // Expands AWQ-packed 4-bit weights to half precision:
  //   C = (unpack(B) - unpack(zeros)) * scaling_factors.
  //
  // Launch layout, as consumed by the index math below:
  //   * One thread per packed int32 of B, i.e. per 8 output values.
  //   * x covers packed columns, y covers rows, blockIdx.z selects the expert
  //     slab. The row stride is taken from the launch itself
  //     (blockDim.x * gridDim.x), and there is no bounds check, so the grid
  //     must cover the packed matrix exactly.
  //   * __launch_bounds__(64): at most 64 threads per block.
  //
  // Parameters:
  //   B                packed weights, 8 x 4-bit per int32.
  //   scaling_factors  per-group scales; one row per G input rows.
  //   zeros            packed per-group zero points; one row per G input rows.
  //   C                dequantized output, 8 halfs per packed int32.
  //   G                quantization group size along the row dimension.
  //   in_c, out_c      rows / unpacked columns of one expert slab; used here
  //                    only to step to the slab selected by blockIdx.z.
  __global__ void __launch_bounds__(64)
      dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
                         int* __restrict__ zeros, half* __restrict__ C, int G,
                         int in_c, int out_c) {
    if (blockIdx.z > 0) {
      // 64-bit element count: blockIdx.z * in_c * out_c wraps in 32 bits for
      // large expert stacks.
      const long long expert_elems =
          static_cast<long long>(blockIdx.z) * in_c * out_c;
      B = B + expert_elems / 8;
      scaling_factors = scaling_factors + expert_elems / G;
      zeros = zeros + expert_elems / G / 8;
      C = C + expert_elems;
    }
    static constexpr uint32_t ZERO = 0x0;

    int N = blockDim.x * gridDim.x;  // packed columns covered by the launch
    int col = (blockIdx.x * blockDim.x + threadIdx.x);
    int row = blockIdx.y * blockDim.y + threadIdx.y;

    int index1 = 8 * col + 8 * row * N;  // first of the 8 outputs in C
    half* C_ptr2 = C + index1;

    int index2 = col + row * N;  // packed weight word
    int* B_ptr2 = B + index2;

    int index3 = col + (int)(row / G) * N;  // packed zero points of the group
    int* zeros_ptr2 = zeros + index3;
    int index4 = 8 * col + (int)(row / G) * N * 8;  // 8 scales of the group
    half* scaling_factors_ptr2 = scaling_factors + index4;

    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);

    uint32_t B_loaded = *(uint32_t*)B_ptr2;
    uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);

    // (q - zero) * scale on four packed half2 lanes.
    asm volatile("sub.f16x2 %0, %1, %2;\n"
                 : "=r"(B_loaded_fp16.x)
                 : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                 : "=r"(B_loaded_fp16.x)
                 : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
    asm volatile("sub.f16x2 %0, %1, %2;\n"
                 : "=r"(B_loaded_fp16.y)
                 : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                 : "=r"(B_loaded_fp16.y)
                 : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
    asm volatile("sub.f16x2 %0, %1, %2;\n"
                 : "=r"(B_loaded_fp16.z)
                 : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                 : "=r"(B_loaded_fp16.z)
                 : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
    asm volatile("sub.f16x2 %0, %1, %2;\n"
                 : "=r"(B_loaded_fp16.w)
                 : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                 : "=r"(B_loaded_fp16.w)
                 : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

    // Stage the 8 results in an 8-element buffer (the former 32 x 136
    // per-thread array was never used beyond its first 8 entries) and store
    // them element-wise, which keeps the original alignment requirements on C.
    half dequantized[8];
    *(uint4*)dequantized = B_loaded_fp16;
    for (int i = 0; i < 8; ++i) {
      *(C_ptr2 + i) = dequantized[i];
    }
  }
  // Grouped (per-expert) AWQ 4-bit GEMM (tensor-core path, requires SM75 or
  // newer).
  //
  // Rows are processed in blocks of 16 entries of `sorted_token_ids_ptr`; all
  // rows of one block use the expert weights selected by
  // `expert_ids_ptr[block]`. Each thread block produces one 16 x N output tile
  // for one split-K slice; every main-loop step consumes 32 values along K.
  //
  // Template parameter:
  //   N  output-tile width per block; only 64 and 128 are supported.
  //
  // Launch layout, as consumed by the index math below:
  //   * 1-D grid. blockIdx.x is split into `blockIdx_y` (row-block and
  //     column-tile index, derived from pad_M) and `blockIdx_z` (split-K).
  //   * threadIdx.x is the lane id (0..31) and threadIdx.y the warp id; each
  //     warp owns N / 2 output columns. __launch_bounds__(64) caps the block
  //     at 64 threads -- presumably launched as dim3(32, 2); confirm at the
  //     launch site, which is not part of this block.
  //   * Static shared memory: 16 x (32 + 8) + 32 x (N + 8) halfs.
  //
  // Parameters:
  //   G, split_k_iters        group size along K / number of split-K slices.
  //   A                       activations, row stride IC; the row read for a
  //                           sorted entry is token_id / top_k.
  //   B, scaling_factors,     per-expert packed weights, scales and packed
  //   zeros                   zero points; expert e starts at e * (OC * IC / 8),
  //                           e * (OC * IC / G) and e * (OC * IC / G / 8).
  //   topk_weights            optional (may be null) per-token_id factor
  //                           multiplied into the result before the store.
  //   sorted_token_ids_ptr    token id per padded row; ids >= num_valid_tokens
  //                           are treated as padding (zero A row, no store).
  //   expert_ids_ptr          expert index per block of 16 rows.
  //   num_tokens_post_padded  device scalar; row blocks starting at or beyond
  //                           it exit immediately.
  //   expert_num, M           only used for the split-K stride of C
  //                           (blockIdx_z * M * OC * expert_num).
  //   pad_M                   padded row count used to decode blockIdx.x.
  //   IC, OC                  reduction length / output width.
  //   C                       output, one row of OC halfs per token_id.
  template <int N>
  __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
      int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B,
      half* __restrict__ scaling_factors, int* __restrict__ zeros,
      const float* __restrict__ topk_weights,
      const int* __restrict__ sorted_token_ids_ptr,
      const int* __restrict__ expert_ids_ptr,
      const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens,
      const int top_k, const int expert_num, int pad_M, int M, int IC, int OC,
      half* __restrict__ C) {
    // Only support matrix n = 64 or 128
    assert(N == 64 || N == 128);
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
    assert(false);
  #else
    int num_tokens = *num_tokens_post_padded;
    int j_factors1 = ((OC + N - 1) / N);  // number of column tiles
    int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1);
    int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1);
    int block = blockIdx_y / j_factors1;  // index of the 16-row block
    // Uniform for the whole thread block, so returning before the barriers
    // below is safe.
    if (block * 16 >= num_tokens) return;

    static constexpr uint32_t ZERO = 0x0;
    // Per-thread accumulators: 8 floats per 16-column group, N / 32 groups.
    float C_warp[32];
    // +8 halfs of padding per row in both tiles.
    __shared__ half A_shared[16 * (32 + 8)];
    __shared__ half B_shared[32 * (N + 8)];

    half A_shared_warp[8];
    half B_shared_warp[N / 4];
    for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
      for (int i = 0; i < 8; ++i) {
        C_warp[(j_0_4_init * 8) + i] = 0.0f;
      }
    }

    static constexpr int row_stride_warp = 32 * 8 / 32;
    static constexpr int row_stride = 2 * 32 * 8 / N;

    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
    int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32);
    int token_id = sorted_token_ids_ptr[row];
    // Padding rows are zero-filled in shared memory instead of being loaded.
    bool ld_A_flag = (token_id < num_valid_tokens);
    half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8;

    // Step to the weights of the expert that owns this row block. The
    // products are formed in 64 bits so that large expert stacks cannot wrap.
    int expert_id = expert_ids_ptr[block];
    B = B + static_cast<long long>(OC) * IC / 8 * expert_id;
    scaling_factors =
        scaling_factors + static_cast<long long>(OC) * IC / G * expert_id;
    zeros = zeros + static_cast<long long>(OC) * IC / G / 8 * expert_id;

    // Each thread reads one packed int32 (8 weights) of B per inner step.
    int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
                 (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
                 (((int)blockIdx_y) % j_factors1) * (N / 8) +
                 (((int)threadIdx.x) % (N / 8)) * 1;

    half* A_shared_ptr = A_shared +
                         ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
                         (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
                         (((int)threadIdx.x) % (32 / 8)) * 8;

    half* B_shared_ptr = B_shared +
                         ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
                         (((int)threadIdx.x) / (N / 8)) * (N + 8) +
                         (((int)threadIdx.x) % (N / 8)) * 8;

    int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
                     ((int)threadIdx.x) % (N / 8);

    half* scaling_factors_ptr = scaling_factors +
                                (((int)blockIdx_y) % j_factors1) * N +
                                (((int)threadIdx.x) % (N / 8)) * 8;

    half* C_ptr = C +
                  static_cast<long long>(blockIdx_z) * M * OC *
                      expert_num  // blockIdx_z -> split_k dim
                  + (((int)blockIdx_y) % j_factors1) * N +
                  ((int)threadIdx.y) * (N / 2) + (((int)threadIdx.x) % 4) * 2;

    // Number of K-steps owned by this split; the last round may not exist for
    // the higher split indices.
    int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
    if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;

    for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
      int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
      // Previous step's fragments must be fully consumed before the tiles are
      // overwritten.
      __syncthreads();
      // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
      if (ld_A_flag) {
        *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
      } else {
        *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
      }

      // Scale / zero point of the quantization group this K-step falls into.
      uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
      uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
      uint4 B_loaded_scale =
          *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));

      int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

      for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
        uint32_t B_loaded =
            *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
        uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);

        // TODO (Haotian): can save 4 assembly instructions if formulated as
        // deq = q * scale - zero * scale.
        asm volatile("sub.f16x2 %0, %1, %2;\n"
                     : "=r"(B_loaded_fp16.x)
                     : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                     : "=r"(B_loaded_fp16.x)
                     : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
        asm volatile("sub.f16x2 %0, %1, %2;\n"
                     : "=r"(B_loaded_fp16.y)
                     : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                     : "=r"(B_loaded_fp16.y)
                     : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
        asm volatile("sub.f16x2 %0, %1, %2;\n"
                     : "=r"(B_loaded_fp16.z)
                     : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                     : "=r"(B_loaded_fp16.z)
                     : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
        asm volatile("sub.f16x2 %0, %1, %2;\n"
                     : "=r"(B_loaded_fp16.w)
                     : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                     : "=r"(B_loaded_fp16.w)
                     : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

        // write back
        *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
            B_loaded_fp16;
      }
      // Both tiles are complete; the warps may now read them.
      __syncthreads();

      // The 32-deep K-step is consumed as two 16-deep halves.
      for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
        unsigned* a_frag = (unsigned*)(A_shared_warp + 0);
        {
          unsigned int addr;
          half* a_src = (&(A_shared[(k_0_1 * 16)])) +
                        (((((int)threadIdx.x) & 15) * 40) +
                         ((((int)threadIdx.x) >> 4) * 8));
          __asm__ __volatile__(
              "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
              "addr; }\n"
              : "=r"(addr)
              : "l"((void*)a_src));
          __asm__ __volatile__(
              "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
              "{%0, %1, %2, %3}, [%4];\n"
              : "=r"(a_frag[0]), "=r"(a_frag[1]), "=r"(a_frag[2]),
                "=r"(a_frag[3])
              : "r"(addr));
        }

        for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
          unsigned* b_dst = (unsigned*)(B_shared_warp + (ax1_0 * 8));
          unsigned int addr;
          half* b_src = (&(B_shared[(((k_0_1 * (N * 16 + 128)) +
                                      (((int)threadIdx.y) * (N / 2))) +
                                     (ax1_0 * 16))])) +
                        (((((int)threadIdx.x) & 15) * (N + 8)) +
                         ((((int)threadIdx.x) >> 4) * 8));
          __asm__ __volatile__(
              "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
              "addr; }\n"
              : "=r"(addr)
              : "l"((void*)b_src));
          __asm__ __volatile__(
              "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
              "{%0, %1, %2, %3}, [%4];\n"
              : "=r"(b_dst[0]), "=r"(b_dst[1]), "=r"(b_dst[2]), "=r"(b_dst[3])
              : "r"(addr));
        }

        for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
          // Accumulators / B fragments for the low and high 8 columns of this
          // 16-column group.
          float* acc_lo = (float*)(C_warp + (j_0_4 * 8));
          float* acc_hi = (float*)(C_warp + ((j_0_4 * 8) + 4));
          unsigned* b_lo = (unsigned*)(B_shared_warp + (j_0_4 * 8));
          unsigned* b_hi = (unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4));
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
          // SM75 path: four m16n8k8 MMAs per 16-column group.
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(acc_lo[0]), "=f"(acc_lo[1]), "=f"(acc_lo[2]),
                "=f"(acc_lo[3])
              : "r"(a_frag[0]), "r"(a_frag[1]), "r"(b_lo[0]), "f"(acc_lo[0]),
                "f"(acc_lo[1]), "f"(acc_lo[2]), "f"(acc_lo[3]));
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(acc_hi[0]), "=f"(acc_hi[1]), "=f"(acc_hi[2]),
                "=f"(acc_hi[3])
              : "r"(a_frag[0]), "r"(a_frag[1]), "r"(b_hi[0]), "f"(acc_hi[0]),
                "f"(acc_hi[1]), "f"(acc_hi[2]), "f"(acc_hi[3]));
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(acc_lo[0]), "=f"(acc_lo[1]), "=f"(acc_lo[2]),
                "=f"(acc_lo[3])
              : "r"(a_frag[2]), "r"(a_frag[3]), "r"(b_lo[1]), "f"(acc_lo[0]),
                "f"(acc_lo[1]), "f"(acc_lo[2]), "f"(acc_lo[3]));
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(acc_hi[0]), "=f"(acc_hi[1]), "=f"(acc_hi[2]),
                "=f"(acc_hi[3])
              : "r"(a_frag[2]), "r"(a_frag[3]), "r"(b_hi[1]), "f"(acc_hi[0]),
                "f"(acc_hi[1]), "f"(acc_hi[2]), "f"(acc_hi[3]));
  #else
          // Other supported architectures: two m16n8k16 MMAs per group.
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
              "%13};\n"
              : "=f"(acc_lo[0]), "=f"(acc_lo[1]), "=f"(acc_lo[2]),
                "=f"(acc_lo[3])
              : "r"(a_frag[0]), "r"(a_frag[1]), "r"(a_frag[2]), "r"(a_frag[3]),
                "r"(b_lo[0]), "r"(b_lo[1]), "f"(acc_lo[0]), "f"(acc_lo[1]),
                "f"(acc_lo[2]), "f"(acc_lo[3]));
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
              "%13};\n"
              : "=f"(acc_hi[0]), "=f"(acc_hi[1]), "=f"(acc_hi[2]),
                "=f"(acc_hi[3])
              : "r"(a_frag[0]), "r"(a_frag[1]), "r"(a_frag[2]), "r"(a_frag[3]),
                "r"(b_hi[0]), "r"(b_hi[1]), "f"(acc_hi[0]), "f"(acc_hi[1]),
                "f"(acc_hi[2]), "f"(acc_hi[3]));
  #endif
        }
      }
    }

    // Write the accumulators back. A thread only ever touches two rows of the
    // tile (tid / 4 and tid / 4 + 8), so their token ids are looked up once
    // here instead of once per stored element.
    const int first_out_row = block * 16 + ((int)threadIdx.x) / 4;
    const int out_tokens[2] = {sorted_token_ids_ptr[first_out_row],
                               sorted_token_ids_ptr[first_out_row + 8]};
    for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) {
      for (int local_id = 0; local_id < 8; ++local_id) {
        const int dst_token = out_tokens[(local_id % 4) / 2];
        if (dst_token < num_valid_tokens) {  // skip padding rows
          float value = C_warp[(ax1_0_1 * 8) + local_id];
          if (topk_weights) {
            value = value * topk_weights[dst_token];
          }
          // 64-bit row offset so that large token_id * OC cannot wrap.
          *(C_ptr + ax1_0_1 * 16 + static_cast<long long>(dst_token) * OC +
            (local_id / 4) * 8 + local_id % 2) = __float2half(value);
        }
      }
    }
  #endif
  }
  670. } // namespace awq
  671. } // namespace aphrodite
  // Dequantizes AWQ-packed 4-bit weights into a half-precision tensor.
  //
  // Arguments:
  //   _kernel           packed weights, int32: [in_c, out_c / 8], or
  //                     [num_experts, in_c, out_c / 8] for stacked experts.
  //   _scaling_factors  half scales; the number of groups is read from dim 0
  //                     (2-D weights) or dim 1 (3-D weights). Output dtype and
  //                     device are taken from this tensor.
  //   _zeros            packed int32 zero points.
  //   split_k_iters     unused by this function.
  //   thx, thy          thread-block shape override. 0 / 0 selects the default
  //                     8 x 8 tiling over a (out_c / 64) x (in_c / 8) grid; a
  //                     single zero substitutes the full packed width / height
  //                     for that axis, launched as one block.
  //
  // Returns:
  //   [in_c, out_c] when there is a single expert, otherwise
  //   [num_experts, in_c, out_c].
  //
  // Throws:
  //   std::invalid_argument for unsupported ranks, an empty / inconsistent
  //   scales tensor, or sizes the default tiling cannot cover exactly;
  //   std::runtime_error if the kernel launch is rejected.
  torch::Tensor awq_dequantize(torch::Tensor _kernel,
                               torch::Tensor _scaling_factors,
                               torch::Tensor _zeros, int64_t split_k_iters,
                               int64_t thx, int64_t thy) {
    const bool is_2d = _kernel.dim() == 2;
    if (!is_2d && _kernel.dim() != 3)
      throw std::invalid_argument(
          "awq_dequantize expects a 2-D or 3-D packed weight tensor");

    int in_c = is_2d ? _kernel.size(0) : _kernel.size(1);
    int qout_c = is_2d ? _kernel.size(1) : _kernel.size(2);
    int num_experts = is_2d ? 1 : _kernel.size(0);
    int out_c = qout_c * 8;  // 8 x 4-bit values per packed int32

    const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));

    auto options = torch::TensorOptions()
                       .dtype(_scaling_factors.dtype())
                       .device(_scaling_factors.device());
    at::Tensor _de_kernel;
    if (num_experts == 1) {
      _de_kernel = torch::empty({in_c, out_c}, options);
    } else {
      _de_kernel = torch::empty({num_experts, in_c, out_c}, options);
    }
    // Nothing to compute; a zero-sized grid would be an invalid launch.
    if (_de_kernel.numel() == 0) {
      return _de_kernel;
    }

    // Group size along the input-channel dimension.
    const int64_t num_groups = _scaling_factors.size(is_2d ? 0 : 1);
    if (num_groups <= 0 || in_c / num_groups <= 0)
      throw std::invalid_argument(
          "scaling_factors does not define a valid group size");
    int G = in_c / num_groups;

    int x_thread = thx;
    int y_thread = thy;

    int x_blocks = 1;
    int y_blocks = 1;
    if (thx == 0) {
      x_thread = qout_c;
    }
    if (thy == 0) {
      y_thread = in_c;
    }
    if (thx == 0 && thy == 0) {
      // The kernel derives its row stride from the launch shape and has no
      // bounds check, so the 8 x 8 tiles must cover the matrix exactly.
      if (qout_c % 8 != 0)
        throw std::invalid_argument("OC is not multiple of 64");
      if (in_c % 8 != 0)
        throw std::invalid_argument("IC is not multiple of 8");
      x_thread = 8;
      y_thread = 8;
      x_blocks = (int)(qout_c / 8);
      y_blocks = (int)(in_c / 8);
    }

    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
    auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
    auto scaling_factors =
        reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

    dim3 num_blocks(x_blocks, y_blocks, num_experts);
    dim3 threads_per_block(x_thread, y_thread);

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    aphrodite::awq::
        dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
            kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c);

    // A rejected launch (e.g. a block shape above the kernel's launch bounds)
    // would otherwise return uninitialized memory without any diagnostic.
    const cudaError_t launch_status = cudaGetLastError();
    if (launch_status != cudaSuccess)
      throw std::runtime_error(
          std::string("awq_dequantize kernel launch failed: ") +
          cudaGetErrorString(launch_status));

    return _de_kernel;
  }
  // AWQ 4-bit GEMM host entry point: validates shapes, launches
  // aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32 on the current CUDA
  // stream and reduces the split-k partial results.
  //
  // Tensor layouts as documented by the original authors:
  //   in_feats:        M, IC                  [float16]
  //   kernel:          IC, OC // 8            [int32] -> cast to IC, OC [uint4b]
  //   scaling_factors: IC // G, OC            [float16]
  //   zeros:           IC // G, OC // 8       [int32] -> cast to IC // G, OC
  //                                           [uint4b]
  // assume that batch_size < 16 for now
  //
  // The kernel writes into a scratch tensor of shape
  // [split_k_iters, M, OC]; the returned tensor is its sum over dim 0.
  //
  // Throws std::invalid_argument when OC is not a multiple of 64 (or of 8),
  // when the group size is not a multiple of 32, or when OC is not a multiple
  // of the group size.
  torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                         torch::Tensor _scaling_factors, torch::Tensor _zeros,
                         int64_t split_k_iters) {
    const int num_in_feats = _in_feats.size(0);
    const int num_in_channels = _in_feats.size(1);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

    const auto options = torch::TensorOptions()
                             .dtype(_in_feats.dtype())
                             .device(_in_feats.device());

    // One output slab per split-k iteration; they are summed on return.
    at::Tensor _out_feats = torch::empty(
        {split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
    const int num_out_feats = _out_feats.size(-2);
    const int num_out_channels = _out_feats.size(-1);

    half* in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
    int* kernel = _kernel.data_ptr<int>();
    half* out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    half* scaling_factors =
        reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    int* zeros = _zeros.data_ptr<int>();

    const int group_size = num_in_channels / _scaling_factors.size(0);

    // Shape preconditions, checked in the same order as before so the first
    // failing condition reports the same message.
    if (num_out_channels % 64 != 0) {
      throw std::invalid_argument("OC is not multiple of cta_N = 64");
    }
    if (num_out_channels % 8 != 0) {
      throw std::invalid_argument("OC is not multiple of pack_num = 8");
    }
    if (group_size % 32 != 0) {
      throw std::invalid_argument("Group size should be a multiple of 32");
    }
    if (num_out_channels % group_size != 0) {
      throw std::invalid_argument("OC is not multiple of Group size");
    }

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Both tile widths use 16-row tiles along M and the same block shape:
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    const int m_tiles = (num_out_feats + 16 - 1) / 16;
    const dim3 threads_per_block(32, 2);

    if (num_out_channels % 128 == 0) {
      const int n_tiles = num_out_channels / 128;
      const dim3 num_blocks(m_tiles * n_tiles * split_k_iters);
      aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<128>
          <<<num_blocks, threads_per_block, 0, stream>>>(
              group_size, split_k_iters, in_feats, kernel, scaling_factors,
              zeros, num_in_feats, num_in_channels, num_out_channels,
              out_feats);
    } else if (num_out_channels % 64 == 0) {
      const int n_tiles = num_out_channels / 64;
      const dim3 num_blocks(m_tiles * n_tiles * split_k_iters);
      aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<64>
          <<<num_blocks, threads_per_block, 0, stream>>>(
              group_size, split_k_iters, in_feats, kernel, scaling_factors,
              zeros, num_in_feats, num_in_channels, num_out_channels,
              out_feats);
    }

    return _out_feats.sum(0);
  }
  // Grouped AWQ 4-bit GEMM host entry point: launches
  // aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32 on the current CUDA
  // stream and reduces the split-k partial results.
  //
  // _in_feats:               read through data_ptr<at::Half>(); dim 0 is the
  //                          row count and dim 2 the input-channel count (IC).
  // _kernel:                 read through data_ptr<int>(); OC = size(2) * 8.
  // _scaling_factors:        read through data_ptr<at::Half>(); IC divided by
  //                          its dim 1 gives the group size.
  // _zeros:                  read through data_ptr<int>().
  // _topk_weights:           its dim 1 sizes the third output dimension. When
  //                          mul_weights is true its storage is handed to the
  //                          kernel as float*; otherwise nullptr is passed.
  // _sorted_token_ids_ptr,
  // _expert_ids_ptr,
  // _num_tokens_post_padded: storage is handed to the kernel as int*.
  // mul_weights:             whether the kernel receives _topk_weights.
  // split_k_iters:           number of split-k slabs in the scratch output.
  //
  // NOTE(review): the float* / int* reinterpretations above are not
  // dtype-checked here; presumably callers pass float32 weights and int32
  // index tensors - confirm against the call sites.
  //
  // Returns the scratch tensor
  // [split_k_iters, _in_feats.size(0), _topk_weights.size(1), OC] summed over
  // dim 0.
  //
  // Throws std::invalid_argument when OC is not a multiple of 64. Without this
  // check no kernel would be launched for such shapes and the function would
  // return the sum of an uninitialized buffer.
  torch::Tensor awq_group_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                               torch::Tensor _scaling_factors,
                               torch::Tensor _zeros, torch::Tensor _topk_weights,
                               torch::Tensor _sorted_token_ids_ptr,
                               torch::Tensor _expert_ids_ptr,
                               torch::Tensor _num_tokens_post_padded,
                               bool mul_weights, int split_k_iters) {
    int num_in_feats = _in_feats.size(0);
    int pad_num_in_feats = _sorted_token_ids_ptr.size(0);
    int num_in_channels = _in_feats.size(2);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
    auto options = torch::TensorOptions()
                       .dtype(_in_feats.dtype())
                       .device(_in_feats.device());

    // NOTE(review): the names below mirror the arguments the kernel is called
    // with; "num_experts" is taken from _topk_weights.size(1) - confirm the
    // intended meaning against the kernel signature.
    int num_experts = _topk_weights.size(1);
    int top_k = num_experts / _in_feats.size(1);
    int group_size = num_in_channels / _scaling_factors.size(1);

    // Only the N = 64 and N = 128 tile widths are launched below, so reject
    // any other OC up front instead of returning uninitialized memory. The
    // message matches the equivalent check in awq_gemm.
    int num_out_channels = _kernel.size(2) * 8;
    if (num_out_channels % 64 != 0) {
      throw std::invalid_argument("OC is not multiple of cta_N = 64");
    }

    // One output slab per split-k iteration; they are summed on return.
    at::Tensor _out_feats =
        torch::empty({split_k_iters, num_in_feats, _topk_weights.size(1),
                      _kernel.size(2) * 8},
                     options);

    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    auto scaling_factors =
        reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
    auto topk_weights = mul_weights
                            ? reinterpret_cast<float*>(_topk_weights.data_ptr())
                            : nullptr;
    auto sorted_token_ids_ptr =
        reinterpret_cast<int*>(_sorted_token_ids_ptr.data_ptr());
    auto expert_ids_ptr = reinterpret_cast<int*>(_expert_ids_ptr.data_ptr());
    auto num_tokens_post_padded =
        reinterpret_cast<int*>(_num_tokens_post_padded.data_ptr());

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);

    if (num_out_channels % 128 == 0) {
      int j_factors1 = num_out_channels / 128 / 1;
      dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 *
                      split_k_iters);
      aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<128>
          <<<num_blocks, threads_per_block, 0, stream>>>(
              group_size, split_k_iters, in_feats, kernel, scaling_factors,
              zeros, topk_weights, sorted_token_ids_ptr, expert_ids_ptr,
              num_tokens_post_padded, _topk_weights.numel(), top_k,
              num_experts, pad_num_in_feats, num_in_feats, num_in_channels,
              num_out_channels, out_feats);
    } else {
      // OC is a multiple of 64 but not of 128 (guaranteed by the check above).
      int j_factors1 = num_out_channels / 64 / 1;
      dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 *
                      split_k_iters);
      aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<64>
          <<<num_blocks, threads_per_block, 0, stream>>>(
              group_size, split_k_iters, in_feats, kernel, scaling_factors,
              zeros, topk_weights, sorted_token_ids_ptr, expert_ids_ptr,
              num_tokens_post_padded, _topk_weights.numel(), top_k,
              num_experts, pad_num_in_feats, num_in_feats, num_in_channels,
              num_out_channels, out_feats);
    }

    return _out_feats.sum(0);
  }