"""This file is used for /tests and /benchmarks""" from typing import List import numpy import torch from aphrodite.quantization.qqq import MARLIN_QQQ_SUPPORTED_NUM_BITS from aphrodite.scalar_type import ScalarType, scalar_types SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] # NOTE: this is a hack. We should update each model to register the # stacked params and get it from there instead in a future PR. # fused_name: List[shard_name] FUSED_LAYER_NAME_MAPPING = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"] } def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj proj_name = prefix.split(".")[-1] if proj_name in FUSED_LAYER_NAME_MAPPING: shard_prefixes = [ prefix.replace(proj_name, shard_proj_name) for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] ] is_skipped = None for shard_prefix in shard_prefixes: is_shard_skipped = shard_prefix in ignored_layers if is_skipped is None: is_skipped = is_shard_skipped elif is_shard_skipped != is_skipped: raise ValueError( f"Detected some but not all shards of {prefix} " "are quantized. All shards of fused layers " "to have the same precision.") else: is_skipped = prefix in ignored_layers assert is_skipped is not None return is_skipped def get_pack_factor(num_bits): assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" return 32 // num_bits def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): assert q_w.shape == w_ref.shape orig_device = q_w.device k_size, _ = q_w.shape g_idx = torch.zeros((k_size, ), dtype=torch.int32) for i in range(k_size): g_idx[i] = i // group_size # Simulate act_order by doing a random permutation on K rand_perm = torch.randperm(k_size) g_idx = g_idx[rand_perm].contiguous() q_w = q_w[rand_perm, :].contiguous() w_ref = w_ref[rand_perm, :].contiguous() return ( w_ref.to(device=orig_device), q_w.to(device=orig_device), g_idx.to(device=orig_device), rand_perm.to(device=orig_device), ) def quantize_weights(w: torch.Tensor, quant_type: ScalarType, group_size: int, zero_points: bool = False): assert quant_type.is_integer(), \ "Floating point quantization may work but has not been tested" orig_device = w.device orig_type = w.dtype size_k, size_n = w.shape assert w.is_floating_point(), "w must be float" if group_size == -1: group_size = size_k assert group_size <= size_k # Reshape to [groupsize, -1] if group_size < size_k: w = w.reshape((-1, group_size, size_n)) w = w.permute(1, 0, 2) w = w.reshape((group_size, -1)) # Compute scale for each group max_val = torch.max(w, 0, keepdim=True).values min_val = torch.min(w, 0, keepdim=True).values max_q_val = quant_type.max() min_q_val = quant_type.min() if zero_points: assert not quant_type.is_signed() and quant_type.max() > 0 w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ .clamp(min_q_val, max_q_val).int() else: # If the bias is such that there are no possible negative/positive # values, set the max value to inf to avoid divide by 0 w_s = torch.max( abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) maybe_w_zp = None # Quantize w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) w_q = torch.clamp(w_q, min_q_val, max_q_val) # Compute ref (dequantized) w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s 
    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if zero_points:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s.to(device=orig_device),
        maybe_w_zp,
    )


def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
                          group_size: int, act_order: bool):
    size_k, _ = w.shape

    assert w.is_floating_point(), "w must be float"
    assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
        f"Unsupported gptq type = {quant_type}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)

        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)

    return w_ref, w_q, w_s, g_idx, rand_perm


# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
        f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size < size_k:
        # Reshape to [groupsize, -1]
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

        max_q_val = 2**num_bits - 1
        half_q_val = (max_q_val + 1) // 2

        # Compute scale for each group
        s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_group *= 2 / max_q_val  # 2 => symmetric

        # Quantize
        q_w = torch.round(w / s_group).int()
        q_w += half_q_val
        q_w = torch.clamp(q_w, 0, max_q_val)
        # Compute ref (dequantized)
        w_ref = (q_w - half_q_val).half() * s_group

        # Restore original shapes
        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
        s_channel /= 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        s_group = (s_group.reshape(-1, size_n).contiguous() /
                   s_channel).to(dtype=torch.half)
    else:
        max_q_val = 2**(num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_channel /= max_q_val

        # Quantize
        q_w = torch.round(w / s_channel).int()
        q_w = torch.clamp(q_w, -max_q_val, max_q_val)
        # Compute ref (dequantized)
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)
        # Divide by 2 ** (8 - num_bits) to offset the right shift in unpacking
        s_channel /= (2**(8 - num_bits))
        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s_group.to(device=orig_device),
        s_channel.to(device=orig_device),
    )


def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    orig_device = q_w.device

    sort_indices = torch.argsort(g_idx).to(
        dtype=torch.int32)  # Sort based on g_idx

    g_idx = g_idx[sort_indices].contiguous()
    q_w = q_w[sort_indices, :].contiguous()

    return (
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        sort_indices.to(device=orig_device),
    )


def pack_rows(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[i::pack_factor, :] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    return q_res


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k, size_n // pack_factor
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor)

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    return pack_rows(q_w, num_bits, size_k, size_n)


def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
    q_w = q_w.reshape((-1, size_n)).contiguous()

    return pack_cols(q_w, num_bits, size_k, size_n)
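

# ---------------------------------------------------------------------------
# Usage sketches. The `_demo_*` helpers below are illustrative additions, not
# part of the original utility API; the shapes, group sizes and bit widths are
# arbitrary example values chosen for a CPU-friendly smoke run.
# ---------------------------------------------------------------------------
def _demo_gptq_quantize_and_pack():
    """Sketch: 4-bit GPTQ-style quantization of a random fp32 matrix,
    followed by row-wise packing into int32 words."""
    size_k, size_n, group_size = 256, 512, 128
    w = torch.randn((size_k, size_n), dtype=torch.float)

    # uint4b8 = unsigned 4-bit storage with a bias of 8 (GPTQ convention).
    w_ref, q_w, w_s, g_idx, rand_perm = gptq_quantize_weights(
        w, scalar_types.uint4b8, group_size, act_order=False)

    # Pack 32 // 4 = 8 quantized values into each int32 along K.
    q_packed = gptq_pack(q_w, 4, size_k, size_n)
    assert q_packed.shape == (size_k // get_pack_factor(4), size_n)

    # The dequantized reference should track the original weights closely.
    return (w - w_ref).abs().max().item()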
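

def _demo_pack_unpack_roundtrip():
    """Sketch: pack_cols followed by unpack_cols is lossless for values that
    fit in num_bits."""
    size_k, size_n, num_bits = 128, 256, 4
    q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

    packed = pack_cols(q_w, num_bits, size_k, size_n)
    unpacked = unpack_cols(packed, num_bits, size_k, size_n)
    assert torch.equal(q_w, unpacked)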
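

def _demo_qqq_quantize():
    """Sketch: per-group QQQ quantization. Assumes 4-bit is listed in
    MARLIN_QQQ_SUPPORTED_NUM_BITS."""
    size_k, size_n, group_size = 256, 512, 128
    w = torch.randn((size_k, size_n), dtype=torch.float)

    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w, 4, group_size)
    # Grouped case: fp16 group scales fused with the per-channel scales,
    # plus fp32 per-channel scales.
    assert s_group.shape == (size_k // group_size, size_n)
    assert s_channel.shape == (1, size_n)
    return (w - w_ref).abs().max().item()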
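

def _demo_is_layer_skipped():
    """Sketch: a fused qkv_proj is treated as skipped only when every shard
    appears in ignored_layers."""
    ignored = [
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
    ]
    assert is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignored)
    assert not is_layer_skipped("model.layers.1.self_attn.qkv_proj", ignored)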
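

if __name__ == "__main__":
    # Smoke-run the sketches above on CPU.
    print("gptq max abs error:", _demo_gptq_quantize_and_pack())
    print("qqq max abs error:", _demo_qqq_quantize())
    _demo_pack_unpack_roundtrip()
    _demo_is_layer_skipped()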