/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "dequantize.cuh"
#include <cuda_fp16.h>

// Pack two half values into one 32-bit word: x occupies the low 16 bits, y the high 16 bits.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}
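// Tiling overview (as read from the launch configuration and loop bounds below, not an
// authoritative spec): each thread block uses 2 warps (blockDim = 32 x 2), covers a
// 16 x cta_N tile of the output (cta_N = 128 for this kernel, 64 for the next one),
// steps through the reduction dimension IC in chunks of 32, and distributes those
// chunks across split_k_iters blocks indexed by blockIdx_z.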
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
  static constexpr uint32_t ZERO = 0x0;
  float C_warp[32];
  __shared__ half A_shared[16 * (32 + 8)];
  __shared__ half B_shared[32 * (128 + 8)];
  __shared__ half scaling_factors_shared[128];
  __shared__ half zeros_shared[128];
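  // The +8 halves of padding per row in A_shared / B_shared most likely stagger rows across
  // shared-memory banks to avoid bank conflicts during the ldmatrix loads below; this is an
  // inference from the layout, not a comment from the original authors.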
  int j_factors1 = ((OC + 128 - 1) / 128);
  int blockIdx_x = 0;
  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
  half A_shared_warp[8];
  half B_shared_warp[32];
  for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
    for (int i = 0; i < 8; ++i) {
      C_warp[(j_0_4_init * 8) + i] = 0.0;
    }
  }
  static constexpr int row_stride_warp = 32 * 8 / 32;
  static constexpr int row_stride = 2 * 32 * 8 / 128;
  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;  // threadIdx.y is warp_id
  // bool wb_C_flag = (threadIdx.x / 4) < M;
  half* A_ptr = A
                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
                + (((int)threadIdx.x) % (32 / 8)) * 8;
  int* B_ptr = B
               + ((int)threadIdx.y) * (OC / 8) * 2
               + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
               + (((int)blockIdx_y) % j_factors1) * (128 / 8)
               + (((int)threadIdx.x) % (128 / 8)) * 1;
  // Why * 1 in the above line?
  half* A_shared_ptr = A_shared
                       + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
                       + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
                       + (((int)threadIdx.x) % (32 / 8)) * 8;
  half* B_shared_ptr = B_shared
                       + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
                       + (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
                       + (((int)threadIdx.x) % (128 / 8)) * 8;
  int* zeros_ptr = zeros
                   + (((int)blockIdx_y) % j_factors1) * (128 / 8)
                   + ((int)threadIdx.x) % (128 / 8);
  half* scaling_factors_ptr = scaling_factors
                              + (((int)blockIdx_y) % j_factors1) * (128)
                              + (((int)threadIdx.x) % (128 / 8)) * 8;
  half* C_ptr = C
                + blockIdx_z * M * OC  // blockIdx_z -> split_k dim
                + (((int)blockIdx_y) % j_factors1) * 128
                + ((int)threadIdx.y) * 64
                + (((int)threadIdx.x) % 4) * 2;
  // preload scaling factors and zeros
  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
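  // Split-K sketch (inferred from the index math, not an original comment): block z handles
  // K-chunks k_0_0 = blockIdx_z, blockIdx_z + split_k_iters, ..., so the split_k_iters blocks
  // interleave over the IC / 32 chunks. The k_bound adjustment above drops the final iteration
  // for blocks whose last chunk would start past IC. Each split accumulates into its own
  // M x OC slice of C, and the host reduces the slices with _out_feats.sum(0).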
  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
    __syncthreads();
    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
    if (ld_A_flag)
    {
      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
    }
    else
    {
      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
    }

    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
    /*
    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
    }
    */
    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
      // B: 32 x 136 (128+8) float16
      // each warp: 32 x 4
      // each thread: read 32 bit -> convert to 8xFP16 (a UINT4) -> subtract zero, multiply by scale -> write back as UINT4
      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
      // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
      // - zero and * scale
      // TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
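      // Each sub.f16x2 / fma.rn.f16x2 pair above operates on two packed fp16 lanes, so the
      // four pairs dequantize all 8 weights of the loaded 32-bit word as w = (q - z) * s
      // (the fma multiplies by the scale and adds the constant ZERO).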
      /*
      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
      }
      */
      // write back
      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
    }
    __syncthreads();
    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
      {
        unsigned int addr;
        __asm__ __volatile__(
          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
          : "=r"(addr)
          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
        );
        __asm__ __volatile__(
          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
          "{%0, %1, %2, %3}, [%4];\n"
          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
          : "r"(addr)
        );
      }
      for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
        {
          unsigned int addr;
          __asm__ __volatile__(
            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
            : "=r"(addr)
            : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
          );
          __asm__ __volatile__(
            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
            "{%0, %1, %2, %3}, [%4];\n"
            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
            : "r"(addr)
          );
        }
      }
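      // Per k_0_1 step each warp holds one 16x16 A fragment and four 16-wide B fragments; the
      // loop below issues two m16n8k16 MMAs per j_0_4 (offsets +0 and +4 into B_shared_warp),
      // so a warp accumulates a 16x64 strip and the two warps together cover the 16x128 tile.
      // This reading is inferred from the fragment indexing, not stated in the original code.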
      for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
        {
          __asm__ __volatile__(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
        }
        {
          __asm__ __volatile__(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
        }
      }
    }
  }

  // TODO: Shang: Hoist loop invariants.
  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
    for (int local_id = 0; local_id < 8; ++local_id) {
      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
      if (row_offset < M)
      {
        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
      }
    }
  }
}
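// Same kernel with a 64-wide CTA tile along OC (cta_N = 64): one warp covers a 16x32 strip,
// B_shared rows are 64 + 8 halves, and the per-warp fragment counts shrink accordingly.
// The dequantization and MMA structure is otherwise identical to the 128-wide kernel above.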
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
  static constexpr uint32_t ZERO = 0x0;
  float C_warp[32];
  __shared__ half A_shared[16 * (32 + 8)];
  __shared__ half B_shared[32 * (64 + 8)];
  __shared__ half scaling_factors_shared[64];
  __shared__ half zeros_shared[64];
  int j_factors1 = ((OC + 64 - 1) / 64);
  int blockIdx_x = 0;
  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
  half A_shared_warp[8];
  half B_shared_warp[16];
  for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
    for (int i = 0; i < 8; ++i) {
      C_warp[(j_0_4_init * 8) + i] = 0.0;
    }
  }
  static constexpr int row_stride_warp = 32 * 8 / 32;
  static constexpr int row_stride = 2 * 32 * 8 / 64;
  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;  // threadIdx.y is warp_id
  // bool wb_C_flag = (threadIdx.x / 4) < M;
  half* A_ptr = A
                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
                + (((int)threadIdx.x) % (32 / 8)) * 8;
  int* B_ptr = B
               + ((int)threadIdx.y) * (OC / 8) * 4
               + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
               + (((int)blockIdx_y) % j_factors1) * (64 / 8)
               + (((int)threadIdx.x) % (64 / 8)) * 1;
  // Why * 1 in the above line?
  half* A_shared_ptr = A_shared
                       + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
                       + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
                       + (((int)threadIdx.x) % (32 / 8)) * 8;
  half* B_shared_ptr = B_shared
                       + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
                       + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
                       + (((int)threadIdx.x) % (64 / 8)) * 8;
  int* zeros_ptr = zeros
                   + (((int)blockIdx_y) % j_factors1) * (64 / 8)
                   + ((int)threadIdx.x) % (64 / 8);
  half* scaling_factors_ptr = scaling_factors
                              + (((int)blockIdx_y) % j_factors1) * (64)
                              + (((int)threadIdx.x) % (64 / 8)) * 8;
  half* C_ptr = C
                + blockIdx_z * M * OC  // blockIdx_z -> split_k dim
                + (((int)blockIdx_y) % j_factors1) * 64
                + ((int)threadIdx.y) * 32
                + (((int)threadIdx.x) % 4) * 2;
  // preload scaling factors and zeros
  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
    __syncthreads();
    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
    if (ld_A_flag)
    {
      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
    }
    else
    {
      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
    }

    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
    /*
    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
    }
    */
    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
      // B: 32 x 72 (64+8) float16
      // each warp: 32 x 4
      // each thread: read 32 bit -> convert to 8xFP16 (a UINT4) -> subtract zero, multiply by scale -> write back as UINT4
      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
      // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
      // - zero and * scale
      // TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
      /*
      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
      }
      */
      // write back
      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
    }
    __syncthreads();
    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
    {
      {
        unsigned int addr;
        __asm__ __volatile__(
          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
          : "=r"(addr)
          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
        );
        __asm__ __volatile__(
          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
          "{%0, %1, %2, %3}, [%4];\n"
          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
          : "r"(addr)
        );
      }
      for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
      {
        {
          unsigned int addr;
          __asm__ __volatile__(
            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
            : "=r"(addr)
            : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
          );
          __asm__ __volatile__(
            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
            "{%0, %1, %2, %3}, [%4];\n"
            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
            : "r"(addr)
          );
        }
      }
      for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
      {
        {
          __asm__ __volatile__(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
        }
        {
          __asm__ __volatile__(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
        }
      }
    }
  }

  // TODO: Shang: Hoist loop invariants.
  for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
    for (int local_id = 0; local_id < 8; ++local_id) {
      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
      if (row_offset < M)
      {
        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
      }
    }
  }
}
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now
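// Shape sketch with illustrative numbers (not taken from the original source): for M = 8,
// IC = 4096, OC = 4096, and group size G = 128, the inputs are _in_feats [8, 4096] fp16,
// _kernel [4096, 512] int32, _scaling_factors [32, 4096] fp16, _zeros [32, 512] int32,
// and awq_gemm returns an [8, 4096] fp16 tensor after summing over the split-K slices.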
torch::Tensor awq_gemm(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int split_k_iters)
{
  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

  auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
  at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);

  auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
  int group_size = num_in_channels / _scaling_factors.size(0);

  if (num_out_channels % 64 != 0)
    throw std::invalid_argument("OC is not multiple of cta_N = 64");
  if (num_out_channels % 8 != 0)
    throw std::invalid_argument("OC is not multiple of pack_num = 8");
  if (group_size % 32 != 0)
    throw std::invalid_argument("Group size should be a multiple of 32");
  if (num_out_channels % group_size != 0)
    throw std::invalid_argument("OC is not multiple of Group size");
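  // Kernel selection (a reading of the dispatch below, not an authoritative note): when OC is a
  // multiple of 128 the wider m16n128k32 kernel is used; otherwise the m16n64k32 variant runs,
  // which is always reachable because OC was checked to be a multiple of 64 above. Both launch
  // blockDim = (32, 2) and one block per 16-row, cta_N-column output tile per split-K slice.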
  if (num_out_channels % 128 == 0)
  {
    int j_factors1 = num_out_channels / 128 / 1;
    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
      group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
  }
  else if (num_out_channels % 64 == 0)
  {
    int j_factors1 = num_out_channels / 64 / 1;
    dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
      group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
  }
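  // Reduce over the split-K dimension: each split wrote its partial product into its own
  // [num_in_feats, num_out_channels] slice of _out_feats, so summing dim 0 gives the final GEMM result.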
  return _out_feats.sum(0);
}