import torch

import triton
import triton.language as tl

AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


@triton.jit
def awq_dequantize_kernel(
        qweight_ptr,  # quantized matrix
        scales_ptr,  # scales, per group
        zeros_ptr,  # zeros, per group
        group_size,  # Should always be one of the supported group sizes
        result_ptr,  # Output matrix
        num_cols,  # input num cols in qweight
        num_rows,  # input num rows in qweight
        BLOCK_SIZE_X: tl.constexpr,
        BLOCK_SIZE_Y: tl.constexpr):
    """Unpack and dequantize one tile of packed 4-bit AWQ weights.

    Each int32 element of ``qweight`` holds eight 4-bit values. A program
    handles BLOCK_SIZE_Y rows by BLOCK_SIZE_X packed columns and writes
    BLOCK_SIZE_Y x (8 * BLOCK_SIZE_X) values of ``(w - zero) * scale`` to
    ``result_ptr``. Row ``r`` uses zero/scale row ``r // group_size``.
    """
    # Setup the pids.
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)

    # Compute offsets and masks for qweight_ptr. Every packed column is
    # repeated 8 times so each output value has its own source element.
    offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
    offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]

    masks_y = offsets_y < num_rows
    masks_x = offsets_x < num_cols
    masks = masks_y[:, None] & masks_x[None, :]

    # Compute offsets and masks for result output ptr (8x as many columns).
    result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(
        0, BLOCK_SIZE_X * 8)
    result_offsets = (8 * num_cols * result_offsets_y[:, None] +
                      result_offsets_x[None, :])

    result_masks_y = result_offsets_y < num_rows
    result_masks_x = result_offsets_x < num_cols * 8
    result_masks = result_masks_y[:, None] & result_masks_x[None, :]

    # Load the weights.
    iweights = tl.load(qweight_ptr + offsets, masks)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
                                tl.arange(0, 4)[:, None]).reshape(8)

    # Use this to compute a set of shifts that can be used to unpack and
    # reorder the values in iweights and zeros (4 bits per value).
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    iweights = (iweights >> shifts) & 0xF

    # Group row of every weight row. Dividing the full row index (rather
    # than summing base // group_size and i // group_size) is correct for
    # any relation between BLOCK_SIZE_Y and group_size.
    group_offsets_y = (pid_y * BLOCK_SIZE_Y +
                       tl.arange(0, BLOCK_SIZE_Y)) // group_size
    group_masks_y = group_offsets_y < num_rows // group_size

    # Compute zero offsets and masks.
    zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
    zero_offsets = num_cols * group_offsets_y[:, None] + zero_offsets_x[None, :]
    zero_masks_x = zero_offsets_x < num_cols
    zero_masks = group_masks_y[:, None] & zero_masks_x[None, :]

    # Load the zeros and unpack/reorder them like the weights.
    zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
    zeros = (zeros >> shifts) & 0xF

    # Compute scale offsets and masks (scales are not packed).
    scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
                       tl.arange(0, BLOCK_SIZE_X * 8))
    scale_offsets = (num_cols * 8 * group_offsets_y[:, None] +
                     scale_offsets_x[None, :])
    scale_masks_x = scale_offsets_x < num_cols * 8
    scale_masks = group_masks_y[:, None] & scale_masks_x[None, :]

    # Load the scales.
    scales = tl.load(scales_ptr + scale_offsets, scale_masks)

    # Dequantize.
    iweights = (iweights - zeros) * scales
    iweights = iweights.to(result_ptr.type.element_ty)

    # Finally, store.
    tl.store(result_ptr + result_offsets, iweights, result_masks)


@triton.jit
def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
                    group_size, BLOCK_SIZE_M: tl.constexpr,
                    BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
                    SPLIT_K: tl.constexpr):
    """Compute one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C = A @ dequant(B).

    ``b_ptr``/``zeros_ptr`` hold eight 4-bit values per int32 element
    (row stride N // 8); ``scales_ptr`` has row stride N. Program axis 0
    selects the output tile; axis 1 (``pid_z``) selects which of the
    SPLIT_K interleaved K slices this program reduces. With SPLIT_K > 1
    the partial tiles are combined into C with atomic adds.
    """
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode.  Use below instead.
    # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Accumulate in float32 regardless of C's dtype: tl.dot does not accept
    # out_dtype=bfloat16, and half-precision accumulation over a long K
    # loses accuracy. The result is cast to C's dtype once, before storing.
    # NOTE: tl.zeros doesn't work in TRITON_INTERPRET=1 mode; there, build
    # the accumulator from tl.arange(0, BLOCK_SIZE_N) broadcast to
    # (BLOCK_SIZE_M, BLOCK_SIZE_N), '& 0x0', then .to(tl.float32).
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
                                tl.arange(0, 4)[:, None]).reshape(8)

    # Create the necessary shifts to use to unpack (4 bits per value).
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :],
                             (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))

    # Offsets and masks. Packed operands (B, zeros) repeat each column 8x.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    masks_am = offsets_am < M

    offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) +
                  tl.arange(0, BLOCK_SIZE_N) // 8)
    masks_bn = offsets_bn < N // 8

    offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) +
                  tl.arange(0, BLOCK_SIZE_N) // 8)
    masks_zn = offsets_zn < N // 8

    offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    masks_sn = offsets_sn < N

    offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
    offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
    # block_offset = BLOCK_SIZE_K * SPLIT_K
    # for k in range(0, (K + block_offset - 1) // (block_offset)):
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
        masks_k = offsets_k < K

        # Masked lanes are filled with 0 so that out-of-range K positions
        # (when K is not a multiple of BLOCK_SIZE_K) add nothing to the dot
        # product; without `other` their contents would be undefined.
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a, other=0.0)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b, other=0)

        # Dequantize b. Group row of each K index in this slice; dividing
        # the full index is correct for any BLOCK_SIZE_K / group_size pair.
        offsets_szk = offsets_k // group_size
        masks_szk = offsets_szk < K // group_size

        offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
        masks_z = masks_szk[:, None] & masks_zn[None, :]
        zeros = tl.load(zeros_ptr + offsets_z, mask=masks_z, other=0)

        offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
        masks_s = masks_szk[:, None] & masks_sn[None, :]
        scales = tl.load(scales_ptr + offsets_s, mask=masks_s, other=0.0)

        b = (b >> shifts) & 0xF
        zeros = (zeros >> shifts) & 0xF
        b = (b - zeros) * scales
        b = b.to(c_ptr.type.element_ty)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32)

        # Advance to this program's next K slice.
        offsets_k += BLOCK_SIZE_K * SPLIT_K
        a_ptrs += BLOCK_SIZE_K * SPLIT_K
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)

    c = accumulator.to(c_ptr.type.element_ty)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    if SPLIT_K == 1:
        tl.store(c_ptrs, c, mask=c_mask)
    else:
        # NOTE(review): atomic_add on fp16/bf16 depends on GPU/Triton
        # support for that dtype — confirm for bf16 deployments.
        tl.atomic_add(c_ptrs, c, mask=c_mask)


# qweights - [K     , M // 8], int32
# scales   - [K // G, M     ], float16
# zeros    - [K // G, M // 8], int32
def awq_dequantize_triton(qweight: torch.Tensor,
                          scales: torch.Tensor,
                          zeros: torch.Tensor,
                          block_size_x: int = 32,
                          block_size_y: int = 32) -> torch.Tensor:
    """Dequantize packed 4-bit AWQ weights into a dense matrix.

    Args:
        qweight: packed weights, shape [K, M // 8], eight 4-bit values per
            int32 element.
        scales: per-group scales, shape [K // G, M]; its dtype is used for
            the result.
        zeros: packed per-group zero points, shape [K // G, M // 8].
        block_size_x: tile width, in packed (int32) columns.
        block_size_y: tile height, in rows.

    Returns:
        Tensor of shape [K, qweight.shape[1] * 8] with dtype ``scales.dtype``.

    Raises:
        AssertionError: if the shapes or the inferred group size are invalid.
    """
    K = qweight.shape[0]
    M = scales.shape[1]
    group_size = qweight.shape[0] // scales.shape[0]

    assert K > 0 and M > 0
    assert scales.shape[0] == K // group_size and scales.shape[1] == M
    assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
    assert group_size <= K
    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

    # Result tensor:
    # number of rows = same as input tensor
    # number of cols = 8 x input tensor num cols
    result = torch.empty(qweight.shape[0],
                         qweight.shape[1] * 8,
                         device=qweight.device,
                         dtype=scales.dtype)

    Y = qweight.shape[0]  # num rows
    X = qweight.shape[1]  # num cols

    grid = lambda META: (
        triton.cdiv(X, META['BLOCK_SIZE_X']),
        triton.cdiv(Y, META['BLOCK_SIZE_Y']),
    )
    awq_dequantize_kernel[grid](qweight,
                                scales,
                                zeros,
                                group_size,
                                result,
                                X,
                                Y,
                                BLOCK_SIZE_X=block_size_x,
                                BLOCK_SIZE_Y=block_size_y)

    return result


# input - [M, K]
# qweight - [K, N // 8]
# qzeros - [K // G, N // 8]
# scales - [K // G, N]
# split_k_iters - parallelism along K-dimension, int, power of 2.
def awq_gemm_triton(input: torch.Tensor,
                    qweight: torch.Tensor,
                    scales: torch.Tensor,
                    qzeros: torch.Tensor,
                    split_k_iters: int,
                    block_size_m: int = 32,
                    block_size_n: int = 32,
                    block_size_k: int = 32) -> torch.Tensor:
    """Multiply ``input`` by packed 4-bit AWQ weights: input @ dequant(qweight).

    Args:
        input: activations, shape [M, K].
        qweight: packed weights, shape [K, N // 8].
        scales: per-group scales, shape [K // G, N]; its dtype is used for
            the result.
        qzeros: packed per-group zero points, shape [K // G, N // 8].
        split_k_iters: number of programs sharing the K reduction; a power
            of two no larger than 32.
        block_size_m, block_size_n, block_size_k: kernel tile sizes.

    Returns:
        Tensor of shape [M, N] with dtype ``scales.dtype``.

    Raises:
        AssertionError: if shapes, ``split_k_iters`` or the inferred group
            size are invalid.
    """
    M, K = input.shape
    N = qweight.shape[1] * 8
    group_size = qweight.shape[0] // qzeros.shape[0]

    assert N > 0 and K > 0 and M > 0
    assert qweight.shape[0] == K and qweight.shape[1] == N // 8
    assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
    assert scales.shape[0] == K // group_size and scales.shape[1] == N
    assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
    assert split_k_iters <= 32
    assert group_size <= K
    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
            N, META['BLOCK_SIZE_N']),
        split_k_iters,
    )

    # With split-K the kernel atomically adds partial tiles into `result`,
    # so it must start at zero. With a single split, the grid covers every
    # (M, N) element with exactly one masked store, so no fill is needed.
    if split_k_iters == 1:
        result = torch.empty((M, N), dtype=scales.dtype, device=input.device)
    else:
        result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)

    # A = input, B = qweight, C = result
    # A = M x K, B = K x N, C = M x N
    awq_gemm_kernel[grid](input,
                          qweight,
                          result,
                          qzeros,
                          scales,
                          M,
                          N,
                          K,
                          group_size,
                          BLOCK_SIZE_M=block_size_m,
                          BLOCK_SIZE_N=block_size_n,
                          BLOCK_SIZE_K=block_size_k,
                          SPLIT_K=split_k_iters)

    return result