/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and
         Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and
          Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include "dequantize.cuh"

#include <cuda_fp16.h>

namespace aphrodite {
namespace awq {

template <int N>
__global__ void __launch_bounds__(64)
    gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
                                    half* __restrict__ A, int* __restrict__ B,
                                    half* __restrict__ scaling_factors,
                                    int* __restrict__ zeros, int M, int IC,
                                    int OC, half* __restrict__ C) {
  // Only support matrix n = 64 or 128
  assert(N == 64 || N == 128);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  assert(false);
#else
  static constexpr uint32_t ZERO = 0x0;
  float C_warp[32];
  __shared__ half A_shared[16 * (32 + 8)];
  __shared__ half B_shared[32 * (N + 8)];

  int j_factors1 = ((OC + N - 1) / N);
  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);

  half A_shared_warp[8];
  half B_shared_warp[N / 4];
  for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
    for (int i = 0; i < 8; ++i) {
      C_warp[(j_0_4_init * 8) + i] = 0.0;
    }
  }

  static constexpr int row_stride_warp = 32 * 8 / 32;
  static constexpr int row_stride = 2 * 32 * 8 / N;
  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  bool ld_A_flag =
      (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
       threadIdx.x * 8 / 32) < M;  // threadIdx.y is warp_id
  // bool wb_C_flag = (threadIdx.x / 4) < M;

  half* A_ptr =
      A +
      (((int)blockIdx_y) / j_factors1 * 16 +
       (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
          IC +
      (((int)threadIdx.x) % (32 / 8)) * 8;

  int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
               (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
               (((int)blockIdx_y) % j_factors1) * (N / 8) +
               (((int)threadIdx.x) % (N / 8)) * 1;
  // Why * 1 in the above line?

  half* A_shared_ptr = A_shared +
                       ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
                       (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
                       (((int)threadIdx.x) % (32 / 8)) * 8;

  half* B_shared_ptr = B_shared +
                       ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
                       (((int)threadIdx.x) / (N / 8)) * (N + 8) +
                       (((int)threadIdx.x) % (N / 8)) * 8;

  int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
                   ((int)threadIdx.x) % (N / 8);

  half* scaling_factors_ptr = scaling_factors +
                              (((int)blockIdx_y) % j_factors1) * N +
                              (((int)threadIdx.x) % (N / 8)) * 8;

  half* C_ptr =
      C +
      static_cast<long long>(blockIdx_z) * M * OC  // blockIdx_z -> split_k dim
      + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
      (((int)threadIdx.x) % 4) * 2;

  // preload s.f. and zeros
  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
    __syncthreads();
    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
    if (ld_A_flag) {
      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
    } else {
      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
    }

    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
    uint4 B_loaded_scale =
        *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
    /*
    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
    threadIdx.y == 0){
      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y,
             B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
             B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
    }
    */
    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
      // B: 32 x 136 (128+8) float16
      // each warp: 32 x 4
      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and
      // minus zero -> WB UINT4
      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
      // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
      // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
      // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
      // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
      // 8)));
      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
      uint32_t B_loaded =
          *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
      // - zero and * scale
      // TODO (Haotian): can save 4 assembly instructions if formulated as
      // deq = q * scale - zero * scale.
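      // Dequantize in place: each 32-bit lane of B_loaded_fp16 packs two fp16
      // values, so each sub.f16x2 / fma.rn.f16x2 pair below computes
      // (q - zero) * scale + 0 for two weights at a time (ZERO is only the
      // fma addend).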
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); /* if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); } */ // write back *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; } __syncthreads(); for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { { unsigned int addr; asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " "addr; }\n" : "=r"(addr) : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))); asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];\n" : "=r"(((unsigned*)(A_shared_warp + 0))[0]), "=r"(((unsigned*)(A_shared_warp + 0))[1]), "=r"(((unsigned*)(A_shared_warp + 0))[2]), "=r"(((unsigned*)(A_shared_warp + 0))[3]) : "r"(addr)); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " "addr; }\n" : "=r"(addr) : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))); asm("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];\n" : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) : "r"(addr)); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[0]), 
"r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } #else { asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " "%13};\n" : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " "%13};\n" : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } #endif } } } // TODO: Shang: Hoist loop invariance. 
  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
    for (int local_id = 0; local_id < 8; ++local_id) {
      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
                       ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
      if (row_offset < M) {
        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
          local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
      }
    }
  }
#endif
}

__global__ void __launch_bounds__(64)
    dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
                       int* __restrict__ zeros, half* __restrict__ C, int G,
                       int in_c, int out_c) {
  if (blockIdx.z > 0) {
    B = B + blockIdx.z * in_c * out_c / 8;
    scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G;
    zeros = zeros + blockIdx.z * in_c * out_c / G / 8;
    C = C + blockIdx.z * in_c * out_c;
  }
  static constexpr uint32_t ZERO = 0x0;
  half B_shared[32 * (128 + 8)];

  half* B_shared_ptr2 = B_shared;

  int N = blockDim.x * gridDim.x;  // 2
  int col = (blockIdx.x * blockDim.x + threadIdx.x);
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int index1 = 8 * col + 8 * row * N;
  half* C_ptr2 = C + index1;

  int index2 = col + row * N;
  int* B_ptr2 = B + index2;

  int index3 = col + (int)(row / G) * N;
  int* zeros_ptr2 = zeros + index3;
  int index4 = 8 * col + (int)(row / G) * N * 8;
  half* scaling_factors_ptr2 = scaling_factors + index4;

  uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
  uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
  uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);

  uint32_t B_loaded = *(uint32_t*)B_ptr2;
  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.x)
               : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.x)
               : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.y)
               : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.y)
               : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.z)
               : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.z)
               : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.w)
               : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.w)
               : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

  *(uint4*)B_shared_ptr2 = B_loaded_fp16;

  for (int i = 0; i < 8; ++i) {
    *(C_ptr2 + i) = B_shared[i];
  }
}

template <int N>
__global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
    int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B,
    half* __restrict__ scaling_factors, int* __restrict__ zeros,
    const float* __restrict__ topk_weights,
    const int* __restrict__ sorted_token_ids_ptr,
    const int* __restrict__ expert_ids_ptr,
    const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens,
    const int top_k, const int expert_num, int pad_M, int M, int IC, int OC,
    half* __restrict__ C) {
  // Only support matrix n = 64 or 128
  assert(N == 64 || N == 128);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  assert(false);
#else
  int num_tokens = *num_tokens_post_padded;
  int j_factors1 = ((OC + N - 1) / N);
  [[maybe_unused]] int blockIdx_x = 0;
  int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1);
  int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1);
  int block = blockIdx_y / j_factors1;
  if (block * 16 >= num_tokens) return;

  static constexpr uint32_t ZERO = 0x0;
  float C_warp[32];
  __shared__ half A_shared[16 * (32 + 8)];
  __shared__ half B_shared[32 * (N + 8)];

  [[maybe_unused]] __shared__ half scaling_factors_shared[N];
  [[maybe_unused]] __shared__ half zeros_shared[N];

  half A_shared_warp[8];
  half B_shared_warp[N / 4];
  for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
    for (int i = 0; i < 8; ++i) {
      C_warp[(j_0_4_init * 8) + i] = 0.0;
    }
  }

  static constexpr int row_stride_warp = 32 * 8 / 32;
  static constexpr int row_stride = 2 * 32 * 8 / N;
  [[maybe_unused]] bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32);
  int token_id = sorted_token_ids_ptr[row];
  bool ld_A_flag = (token_id < num_valid_tokens);
  half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8;

  int expert_id = expert_ids_ptr[block];
  B = B + OC * IC / 8 * expert_id;
  scaling_factors = scaling_factors + OC * IC / G * expert_id;
  zeros = zeros + OC * IC / G / 8 * expert_id;

  int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
               (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
               (((int)blockIdx_y) % j_factors1) * (N / 8) +
               (((int)threadIdx.x) % (N / 8)) * 1;
  // Why * 1 in the above line?

  half* A_shared_ptr = A_shared +
                       ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
                       (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
                       (((int)threadIdx.x) % (32 / 8)) * 8;

  half* B_shared_ptr = B_shared +
                       ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
                       (((int)threadIdx.x) / (N / 8)) * (N + 8) +
                       (((int)threadIdx.x) % (N / 8)) * 8;

  int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
                   ((int)threadIdx.x) % (N / 8);

  half* scaling_factors_ptr = scaling_factors +
                              (((int)blockIdx_y) % j_factors1) * N +
                              (((int)threadIdx.x) % (N / 8)) * 8;

  half* C_ptr = C +
                static_cast<long long>(blockIdx_z) * M * OC *
                    expert_num  // blockIdx_z -> split_k dim
                + (((int)blockIdx_y) % j_factors1) * N +
                ((int)threadIdx.y) * (N / 2) + (((int)threadIdx.x) % 4) * 2;

  // preload s.f. and zeros
  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
    __syncthreads();
    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
    if (ld_A_flag) {
      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
    } else {
      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
    }

    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
    uint4 B_loaded_scale =
        *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));

    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
      uint32_t B_loaded =
          *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
      // TODO (Haotian): can save 4 assembly instructions if formulated as
      // deq = q * scale - zero * scale.
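      // Same dequantization as in the dense kernel above. B, scaling_factors
      // and zeros were already offset to the selected expert's weights earlier
      // in this kernel, so the per-group indexing below is identical to the
      // non-MoE path.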
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); // write back *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; } __syncthreads(); for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { { unsigned int addr; asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " "addr; }\n" : "=r"(addr) : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))); asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];\n" : "=r"(((unsigned*)(A_shared_warp + 0))[0]), "=r"(((unsigned*)(A_shared_warp + 0))[1]), "=r"(((unsigned*)(A_shared_warp + 0))[2]), "=r"(((unsigned*)(A_shared_warp + 0))[3]) : "r"(addr)); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " "addr; }\n" : "=r"(addr) : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))); asm("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];\n" : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) : "r"(addr)); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 
4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } #else { asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " "%13};\n" : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " "%13};\n" : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } #endif } } } // TODO: Shang: Hoist loop invariance. 
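  // Write-back: map each row of the tile back to its (padded, sorted) token
  // id, skip padding rows (token_id >= num_valid_tokens), optionally scale by
  // the router weight topk_weights[token_id], and store the fp16 result for
  // this split-k slice.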
  for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) {
    for (int local_id = 0; local_id < 8; ++local_id) {
      int row_offset =
          block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
      int token_id = sorted_token_ids_ptr[row_offset];
      if (token_id < num_valid_tokens) {
        float value = C_warp[(ax1_0_1 * 8) + local_id];
        if (topk_weights) {
          value = value * topk_weights[token_id];
        }
        *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 +
          local_id % 2) = __float2half(value);
      }
    }
  }
#endif
}

}  // namespace awq
}  // namespace aphrodite

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy) {
  int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1);
  int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2);
  int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0);
  int out_c = qout_c * 8;
  int G = in_c / (_kernel.dim() == 2 ? _scaling_factors.size(0)
                                     : _scaling_factors.size(1));

  int x_thread = thx;
  int y_thread = thy;

  int x_blocks = 1;
  int y_blocks = 1;
  if (thx == 0) {
    x_thread = qout_c;
  }
  if (thy == 0) {
    y_thread = in_c;
  }
  if (thx == 0 && thy == 0) {
    x_thread = 8;
    y_thread = 8;
    x_blocks = (int)(qout_c / 8);
    y_blocks = (int)(in_c / 8);
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));

  auto options = torch::TensorOptions()
                     .dtype(_scaling_factors.dtype())
                     .device(_scaling_factors.device());
  at::Tensor _de_kernel;
  if (num_experts == 1) {
    _de_kernel = torch::empty({in_c, out_c}, options);
  } else {
    _de_kernel = torch::empty({num_experts, in_c, out_c}, options);
  }

  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr());
  auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr());
  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr());

  dim3 num_blocks(x_blocks, y_blocks, num_experts);
  dim3 threads_per_block(x_thread, y_thread);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  aphrodite::awq::
      dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
          kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c);

  return _de_kernel;
}

// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int64_t split_k_iters) {
  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

  auto options = torch::TensorOptions()
                     .dtype(_in_feats.dtype())
                     .device(_in_feats.device());
  at::Tensor _out_feats =
      torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);

  auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr());
  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr());
  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr());
  int group_size = num_in_channels / _scaling_factors.size(0);

  if (num_out_channels % 64 != 0)
    throw std::invalid_argument("OC is not multiple of cta_N = 64");
  if (num_out_channels % 8 != 0)
    throw std::invalid_argument("OC is not multiple of pack_num = 8");
  if (group_size % 32 != 0)
    throw std::invalid_argument("Group size should be a multiple of 32");
  if (num_out_channels % group_size != 0)
    throw std::invalid_argument("OC is not multiple of Group size");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (num_out_channels % 128 == 0) {
    int j_factors1 = num_out_channels / 128 / 1;
    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<128>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
            num_in_feats, num_in_channels, num_out_channels, out_feats);
  } else if (num_out_channels % 64 == 0) {
    int j_factors1 = num_out_channels / 64 / 1;
    dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 *
                    split_k_iters);
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<64>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
            num_in_feats, num_in_channels, num_out_channels, out_feats);
  }
  return _out_feats.sum(0);
}

torch::Tensor awq_group_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, torch::Tensor _topk_weights,
                             torch::Tensor _sorted_token_ids_ptr,
                             torch::Tensor _expert_ids_ptr,
                             torch::Tensor _num_tokens_post_padded,
                             bool mul_weights, int split_k_iters) {
  int num_in_feats = _in_feats.size(0);
  int pad_num_in_feats = _sorted_token_ids_ptr.size(0);
  int num_in_channels = _in_feats.size(2);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

  auto options = torch::TensorOptions()
                     .dtype(_in_feats.dtype())
                     .device(_in_feats.device());
  int num_experts = _topk_weights.size(1);
  int top_k = num_experts / _in_feats.size(1);
  int group_size = num_in_channels / _scaling_factors.size(1);

  at::Tensor _out_feats = torch::empty(
      {split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8},
      options);
  int num_out_channels = _out_feats.size(-1);

  auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr());
  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr());
  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr());
  auto topk_weights = mul_weights
                          ? reinterpret_cast<float*>(_topk_weights.data_ptr())
                          : nullptr;
  auto sorted_token_ids_ptr =
      reinterpret_cast<int*>(_sorted_token_ids_ptr.data_ptr());
  auto expert_ids_ptr = reinterpret_cast<int*>(_expert_ids_ptr.data_ptr());
  auto num_tokens_post_padded =
      reinterpret_cast<int*>(_num_tokens_post_padded.data_ptr());

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (num_out_channels % 128 == 0) {
    int j_factors1 = num_out_channels / 128 / 1;
    dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 *
                    split_k_iters);
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<128>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
            topk_weights, sorted_token_ids_ptr, expert_ids_ptr,
            num_tokens_post_padded, _topk_weights.numel(), top_k, num_experts,
            pad_num_in_feats, num_in_feats, num_in_channels, num_out_channels,
            out_feats);
  } else if (num_out_channels % 64 == 0) {
    int j_factors1 = num_out_channels / 64 / 1;
    dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 *
                    split_k_iters);
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<64>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
            topk_weights, sorted_token_ids_ptr, expert_ids_ptr,
            num_tokens_post_padded, _topk_weights.numel(), top_k, num_experts,
            pad_num_in_feats, num_in_feats, num_in_channels, num_out_channels,
            out_feats);
  }
  return _out_feats.sum(0);
}
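
// Illustrative host-side usage (a sketch only; the tensor names and the
// split_k_iters value below are assumptions drawn from the layout comments
// above, not additional API guarantees):
//
//   // x:       [M, IC]         float16 activations
//   // qweight: [IC, OC / 8]    int32, AWQ-packed 4-bit weights
//   // scales:  [IC / G, OC]    float16 scaling factors
//   // qzeros:  [IC / G, OC / 8] int32, packed 4-bit zero points
//   torch::Tensor y = awq_gemm(x, qweight, scales, qzeros,
//                              /*split_k_iters=*/8);
//
//   // Full dequantization back to an [IC, OC] float16 tensor:
//   torch::Tensor w = awq_dequantize(qweight, scales, qzeros, 0, 0, 0);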