123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- from typing import List, Optional, Tuple
- import numpy
- import torch
- from aphrodite import _custom_ops as ops
- from aphrodite.platforms import current_platform
- from aphrodite.scalar_type import ScalarType, scalar_types
- from .quant_utils import pack_cols, unpack_cols
# Marlin kernel shape constants. Weight partitions must be multiples of the
# MIN_THREAD_* values (checked in verify_marlin_supports_shape), and the
# workspace size is derived from MIN_THREAD_N and MAX_PARALLEL.
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

# Quantization group sizes accepted by Marlin; -1 means channelwise
# (one group spanning the whole input dimension).
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# Default for the `use_fp32_reduce` argument of the apply_* helpers.
# If Marlin shows a performance problem, flip this to False: Marlin will
# then do its global reductions in fp16 rather than fp32, trading some
# precision for fewer memory movements.
USE_FP32_REDUCE_DEFAULT = True
# To limit binary size and compile time, the kernels built with a runtime
# zero-point support a different set of types than those built without one;
# only the common cases (AWQ and GPTQ) are covered.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
                                       device_capability: Optional[int] = None
                                       ):
    """Return the weight scalar types Marlin supports for this configuration.

    Args:
        has_zp: True for AWQ-style weights with a runtime zero-point, False
            for GPTQ-style weights with a symmetric bias.
        device_capability: Compute capability encoded as major * 10 + minor.
            When None, it is queried from the current platform.

    Returns:
        A list of supported scalar types; empty when device_capability < 80.
    """
    if device_capability is None:
        major, minor = current_platform.get_device_capability()
        device_capability = major * 10 + minor

    # Nothing is reported as supported below capability 8.0.
    if device_capability < 80:
        return []

    if not has_zp:
        # GPTQ style: unsigned weights + symmetric bias.
        # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
        # to add `scalar_types.float8_e4m3fn` here
        return [scalar_types.uint4b8, scalar_types.uint8b128]

    # AWQ style: unsigned weights + runtime zero-point.
    return [scalar_types.uint4, scalar_types.uint8]
def _check_marlin_supported(
        quant_type: ScalarType,
        group_size: Optional[int],
        has_zp: bool,
        device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
    """Check whether Marlin supports a quantization configuration.

    Args:
        quant_type: Weight scalar type to check.
        group_size: Quantization group size (None is always rejected).
        has_zp: Whether a runtime zero-point is used.
        device_capability: Compute capability as major * 10 + minor; queried
            from the current platform when None.

    Returns:
        ``(True, None)`` when supported, otherwise ``(False, reason)`` where
        ``reason`` is a human-readable explanation.
    """
    if device_capability is None:
        major, minor = current_platform.get_device_capability()
        device_capability = major * 10 + minor

    supported_types = query_marlin_supported_quant_types(
        has_zp, device_capability)
    if quant_type not in supported_types:
        reason = (f"Marlin does not support weight_bits = {quant_type}. "
                  f"Only types = {supported_types} "
                  f"are supported (for group_size = {group_size}, "
                  f"device_capability = {device_capability}, zp = {has_zp}).")
        return False, reason

    group_size_ok = (group_size is not None
                     and group_size in MARLIN_SUPPORTED_GROUP_SIZES)
    if not group_size_ok:
        reason = (f"Marlin does not support group_size = {group_size}. "
                  f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
                  "are supported.")
        return False, reason

    return True, None
def check_marlin_supported(quant_type: ScalarType,
                           group_size: int,
                           has_zp: bool = False,
                           device_capability: Optional[int] = None) -> bool:
    """Return True if Marlin supports the given quantization configuration.

    Boolean-only wrapper around ``_check_marlin_supported``; the
    explanatory message is discarded.
    """
    is_supported, _reason = _check_marlin_supported(
        quant_type, group_size, has_zp, device_capability)
    return is_supported
def verify_marlin_supported(quant_type: ScalarType,
                            group_size: int,
                            has_zp: bool = False) -> None:
    """Raise if Marlin does not support the given quantization configuration.

    The device capability is taken from the current platform (via
    ``_check_marlin_supported``).

    Raises:
        ValueError: With the explanation from ``_check_marlin_supported``.
    """
    ok, reason = _check_marlin_supported(quant_type, group_size, has_zp)
    if ok:
        return
    assert reason is not None
    raise ValueError(reason)
def verify_marlin_supports_shape(output_size_per_partition: int,
                                 input_size_per_partition: int,
                                 input_size: int, group_size: int) -> None:
    """Validate that a weight partition's shape is usable by Marlin.

    Args:
        output_size_per_partition: Output dimension of this rank's shard.
        input_size_per_partition: Input dimension of this rank's shard.
        input_size: Full (unsharded) input dimension.
        group_size: Quantization group size; -1 means channelwise.

    Raises:
        ValueError: If the output partition is not a multiple of
            GPTQ_MARLIN_MIN_THREAD_N, the input partition is not a multiple
            of GPTQ_MARLIN_MIN_THREAD_K, or, for grouped quantization with
            groups smaller than the full input, the input partition is not a
            multiple of group_size.
    """
    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(f"Weight output_size_per_partition = "
                         f"{output_size_per_partition} is not divisible by "
                         f"min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(f"Weight input_size_per_partition = "
                         f"{input_size_per_partition} is not divisible "
                         f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    # Channelwise (-1) quantization has no per-group divisibility
    # requirement. Say so explicitly instead of relying on `x % -1 == 0`.
    is_grouped = group_size != -1
    if (is_grouped and group_size < input_size
            and input_size_per_partition % group_size != 0):
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq.")
def marlin_make_workspace(output_size_per_partition: int,
                          device: torch.device) -> torch.Tensor:
    """Allocate the zero-filled int32 workspace passed to the Marlin GEMM.

    The buffer holds
    (output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N)
    * GPTQ_MARLIN_MAX_PARALLEL elements and does not require grad.
    """
    num_n_blocks = output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    workspace_len = num_n_blocks * GPTQ_MARLIN_MAX_PARALLEL
    return torch.zeros(workspace_len, dtype=torch.int, device=device,
                       requires_grad=False)
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    """Return the `is_k_full` flag for the Marlin GEMM.

    It is False only when act_order and is_row_parallel are both set.
    Presumably, act-order with row-parallel sharding leaves each rank with
    only part of the reordered K dimension (TODO confirm).
    """
    return not (act_order and is_row_parallel)
def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
    """Return whether every rank must hold the full set of scales.

    This is the case with act-ordering, or with channelwise scales
    (group_size == -1) in a RowParallelLinear.
    """
    if act_order:
        return act_order
    return group_size == -1 and is_row_parallel
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    """Return a zero-element int32 g_idx Parameter that does not need grad."""
    empty = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(empty, requires_grad=False)
def marlin_sort_g_idx(
        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sort g_idx ascending.

    Returns:
        A tuple ``(sorted_g_idx, sort_indices)``, where ``sort_indices`` is
        the int32 permutation produced by ``torch.argsort`` such that
        ``g_idx[sort_indices] == sorted_g_idx``.
    """
    order = torch.argsort(g_idx).to(torch.int)
    sorted_g_idx = g_idx[order]
    return sorted_g_idx, order
def get_scale_perms():
    """Build the column permutations Marlin applies to scales.

    Returns:
        ``(scale_perm, scale_perm_single)``:

        - ``scale_perm``: 64 indices, a transpose of an 8x8 block
          (``i + 8 * j``).
        - ``scale_perm_single``: 32 indices built from pairs of adjacent
          columns strided by 8.
    """
    scale_perm: List[int] = [i + 8 * j for i in range(8) for j in range(8)]
    single_offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    scale_perm_single: List[int] = [
        2 * i + off for i in range(4) for off in single_offsets
    ]
    return scale_perm, scale_perm_single
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int) -> torch.Tensor:
    """Permute scales into the column order Marlin expects.

    Two cases are handled:

    - Grouped scales (``group_size != -1`` and ``group_size < size_k``) use
      the 64-wide ``scale_perm``.
    - Everything else (channelwise, or a single group covering K) uses the
      32-wide ``scale_perm_single``.

    The result is returned as a contiguous ``(-1, size_n)`` tensor.
    """
    scale_perm, scale_perm_single = get_scale_perms()
    is_grouped = group_size != -1 and group_size < size_k
    perm = scale_perm if is_grouped else scale_perm_single
    permuted = s.reshape((-1, len(perm)))[:, perm]
    return permuted.reshape((-1, size_n)).contiguous()
def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Permute, interleave and pack zero-points into Marlin's layout.

    Args:
        zp: Unpacked zero-points. Its element count must be divisible by 64
            (the width of ``scale_perm``) and by ``size_n``.
        size_k: Forwarded to ``pack_cols``.
        size_n: Number of columns; the permuted values are reshaped to
            ``(-1, size_n)`` before packing.
        num_bits: Zero-point bit width; only 4 and 8 are supported.

    Returns:
        The zero-points packed by ``pack_cols``.

    Raises:
        ValueError: If ``num_bits`` is not 4 or 8. This is checked before any
            tensor work. ValueError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    # Pick the column interleave (used by the dequantize code) up front so an
    # unsupported bit width fails fast with a clear message.
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError(f"num_bits must be 4 or 8, got {num_bits}")

    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    scale_perm, _ = get_scale_perms()
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave the column dim, then pack it to int32
    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    return pack_cols(zp, num_bits, size_k, size_n)
def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                              size_n: int, num_bits: int) -> torch.Tensor:
    """Convert AWQ-packed zero-points into Marlin's packed layout.

    AWQ zero-points are quantized and packed on the column dim. Their values
    are also permuted based on AWQ's dequantizer. This undoes both steps,
    then re-applies the Marlin permutation and packing through
    ``marlin_zero_points``.

    Args:
        q_zp_packed: AWQ-packed zero-points.
        size_k: Forwarded to ``unpack_cols`` and ``marlin_zero_points``.
        size_n: Number of columns of the unpacked zero-points.
        num_bits: Zero-point bit width; only 4 and 8 are supported.

    Returns:
        The zero-points in Marlin's packed layout.

    Raises:
        ValueError: If ``num_bits`` is not 4 or 8. This is checked before
            unpacking. ValueError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    # Inverse of AWQ's interleave (argsort of a permutation is its inverse).
    # Chosen before unpacking so an unsupported width is rejected with a
    # clear message instead of reaching unpack_cols.
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise ValueError(f"num_bits must be 4 or 8, got {num_bits}")

    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()
    return marlin_zero_points(q_zp, size_k, size_n, num_bits)
# Newly generated tensors must be written into the tensors Aphrodite has
# already registered as parameters (those are not freed), rather than
# rebinding the attribute to a new object.
def replace_tensor(layer: torch.nn.Module, name: str,
                   new_t: torch.Tensor) -> None:
    """Overwrite ``layer.<name>`` in place with the contents of ``new_t``.

    The existing tensor object is kept: it is resized to ``new_t.shape``
    with ``resize_()``, so the same buffer is reused, and then filled with
    ``copy_()``.
    """
    target = getattr(layer, name)
    target.resize_(new_t.shape)
    target.copy_(new_t)
def apply_gptq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_zp: torch.Tensor,
        g_idx: torch.Tensor,
        g_idx_sort_indices: torch.Tensor,
        workspace: torch.Tensor,
        wtype: ScalarType,
        output_size_per_partition: int,
        input_size_per_partition: int,
        is_k_full: bool,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    """Run a GPTQ-style Marlin GEMM (no runtime zero-point).

    All leading dims of ``input`` are flattened into the GEMM's M dimension,
    and the result is reshaped back to ``input.shape[:-1] +
    (output_size_per_partition,)``. If ``bias`` is given, it is added in
    place to the GEMM output.
    """
    leading_dims = input.shape[:-1]
    x_2d = input.reshape(-1, input.shape[-1])

    out_2d = ops.gptq_marlin_gemm(
        x_2d,
        weight,
        weight_scale,
        weight_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        wtype,
        size_m=x_2d.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        is_k_full=is_k_full,
        has_zp=False,
        use_fp32_reduce=use_fp32_reduce,
    )

    if bias is not None:
        out_2d.add_(bias)

    return out_2d.reshape(leading_dims + (output_size_per_partition, ))
def apply_awq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_zp: torch.Tensor,
        g_idx: torch.Tensor,
        g_idx_sort_indices: torch.Tensor,
        workspace: torch.Tensor,
        quant_type: ScalarType,
        output_size_per_partition: int,
        input_size_per_partition: int,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    """Run an AWQ-style Marlin GEMM (runtime zero-point, K always full).

    All leading dims of ``input`` are flattened into the GEMM's M dimension,
    and the result is reshaped back to ``input.shape[:-1] +
    (output_size_per_partition,)``. If ``bias`` is given, it is added in
    place to the GEMM output.
    """
    leading_dims = input.shape[:-1]
    x_2d = input.reshape(-1, input.shape[-1])

    out_2d = ops.gptq_marlin_gemm(
        x_2d,
        weight,
        weight_scale,
        weight_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        size_m=x_2d.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        is_k_full=True,
        has_zp=True,
        use_fp32_reduce=use_fp32_reduce,
    )

    if bias is not None:
        out_2d.add_(bias)

    return out_2d.reshape(leading_dims + (output_size_per_partition, ))
|