- """This file is used for /tests and /benchmarks"""
- from typing import List
- import numpy
- import torch
- from aphrodite.quantization.qqq import MARLIN_QQQ_SUPPORTED_NUM_BITS
- from aphrodite.scalar_type import ScalarType, scalar_types
- SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
- SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
- # NOTE: this is a hack. We should update each model to register the
- # stacked params and get it from there instead in a future PR.
- # fused_name: List[shard_name]
- FUSED_LAYER_NAME_MAPPING = {
- "qkv_proj": ["q_proj", "k_proj", "v_proj"],
- "gate_up_proj": ["gate_proj", "up_proj"]
- }
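

# Returns True if the layer named by `prefix` is in the ignore list. For fused
# modules listed in FUSED_LAYER_NAME_MAPPING (e.g. qkv_proj), all constituent
# shards must be skipped (or not skipped) together.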
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]

    if proj_name in FUSED_LAYER_NAME_MAPPING:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
        ]

        is_skipped = None
        for shard_prefix in shard_prefixes:
            is_shard_skipped = shard_prefix in ignored_layers

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision.")
    else:
        is_skipped = prefix in ignored_layers

    assert is_skipped is not None
    return is_skipped
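

# Number of quantized values that fit in one 32-bit word,
# e.g. get_pack_factor(4) == 8 and get_pack_factor(8) == 4.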
def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits
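

# Randomly permutes the rows (K dim) of the quantized and reference weights to
# simulate GPTQ act_order, returning the permuted tensors together with the
# per-row group indices and the permutation itself.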
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    g_idx = torch.zeros((k_size, ), dtype=torch.int32)
    for i in range(k_size):
        g_idx[i] = i // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
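

# Group-quantizes `w` along K: symmetric when zero_points=False (scale only),
# asymmetric when zero_points=True (scale plus integer zero point). Returns the
# dequantized reference, the quantized weights, the scales, and the zero points
# (None when zero_points=False).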
def quantize_weights(w: torch.Tensor,
                     quant_type: ScalarType,
                     group_size: int,
                     zero_points: bool = False):
    assert quant_type.is_integer(), \
        "Floating point quantization may work but has not been tested"

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Reshape to [groupsize, -1]
    if group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    if zero_points:
        assert not quant_type.is_signed() and quant_type.max() > 0
        w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
        maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
            .clamp(min_q_val, max_q_val).int()
    else:
        # If the bias is such that there are no possible negative/positive
        # values, set the max value to inf to avoid divide by 0
        w_s = torch.max(
            abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
            abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
        maybe_w_zp = None

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)

    w_s = w_s.reshape((-1, size_n)).contiguous()

    if zero_points:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s.to(device=orig_device),
        maybe_w_zp,
    )
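

# GPTQ-style quantization: group-quantize with quantize_weights and, when
# act_order is set, shuffle the rows via permute_rows so that kernels
# supporting activation reordering can be exercised in tests.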
def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
                          group_size: int, act_order: bool):
    size_k, _ = w.shape

    assert w.is_floating_point(), "w must be float"
    assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
        f"Unsupported gptq type = {quant_type}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)

        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)

    return w_ref, w_q, w_s, g_idx, rand_perm
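

# Illustrative use of gptq_quantize_weights (assumed shapes/dtypes, not
# executed here): quantize a half-precision [size_k, size_n] weight to
# 4-bit with groups of 128 and a random act_order permutation.
#
#   w = torch.randn((4096, 4096), dtype=torch.half)
#   w_ref, w_q, w_s, g_idx, rand_perm = gptq_quantize_weights(
#       w, scalar_types.uint4b8, group_size=128, act_order=True)

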
# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
        f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size < size_k:
        # Reshape to [groupsize, -1]
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

        max_q_val = 2**num_bits - 1
        half_q_val = (max_q_val + 1) // 2

        # Compute scale for each group
        s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_group *= 2 / max_q_val  # 2 => symmetric

        # Quantize
        q_w = torch.round(w / s_group).int()
        q_w += half_q_val
        q_w = torch.clamp(q_w, 0, max_q_val)
        # Compute ref (dequantized)
        w_ref = (q_w - half_q_val).half() * s_group

        # Restore original shapes
        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
        s_channel /= 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        s_group = (s_group.reshape(-1, size_n).contiguous() /
                   s_channel).to(dtype=torch.half)
    else:
        max_q_val = 2**(num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_channel /= max_q_val

        # Quantize
        q_w = torch.round(w / s_channel).int()
        q_w = torch.clamp(q_w, -max_q_val, max_q_val)
        # Compute ref (dequantized)
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)
        # Divide by 2 ** (8 - num_bits) to offset the right shift in unpacking
        s_channel /= (2**(8 - num_bits))
        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s_group.to(device=orig_device),
        s_channel.to(device=orig_device),
    )
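

# Sorts rows of q_w by ascending group index (g_idx), typically applied after
# an act_order permutation; also returns the sort indices used.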
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    orig_device = q_w.device

    sort_indices = torch.argsort(g_idx).to(
        dtype=torch.int32)  # Sort based on g_idx

    g_idx = g_idx[sort_indices].contiguous()
    q_w = q_w[sort_indices, :].contiguous()

    return (
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        sort_indices.to(device=orig_device),
    )
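

# Packs each run of `pack_factor` consecutive rows (along K) into one 32-bit
# word per column, giving an output of shape [size_k // pack_factor, size_n];
# row r * pack_factor + i occupies bits [num_bits * i, num_bits * (i + 1)).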
def pack_rows(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[i::pack_factor, :] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)

    return q_res
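

# Same idea as pack_rows but along N (columns): output shape is
# [size_k, size_n // pack_factor].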
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
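

# Inverse of pack_cols: expands each 32-bit word back into `pack_factor`
# values of `num_bits` bits along N, returning a [size_k, size_n] tensor.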
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k, size_n // pack_factor
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor)

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
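

# GPTQ-format checkpoints pack along K, so this is just pack_rows.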
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    return pack_rows(q_w, num_bits, size_k, size_n)
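

# AWQ packs along N and additionally interleaves columns within each packed
# word to match the layout expected by its dequantization code.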
def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
    q_w = q_w.reshape((-1, size_n)).contiguous()

    return pack_cols(q_w, num_bits, size_k, size_n)