from typing import Optional, Tuple

import torch

from aphrodite import _custom_ops as ops
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                          permute_param_layout_)
from aphrodite.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
    check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
    query_marlin_supported_quant_types)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MarlinLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          "MarlinLinearKernel. Will be added when AWQMarlin "\
                          "is migrated over to using MPLinearKernel backend"

        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, f"Quant type ({c.weight_type}) not supported by "\
                          f"Marlin, supported types are: {quant_types}"

        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Marlin, supported group sizes are: "\
                          f"{MARLIN_SUPPORTED_GROUP_SIZES}"

        return check_marlin_supports_shape(c.partition_weight_shape[0],
                                           c.partition_weight_shape[1],
                                           c.full_weight_shape[1],
                                           c.group_size)

    # note: assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config

        row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
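        # (The workspace returned by `marlin_make_workspace` is a zeroed
        #  scratch buffer, sized from the output partition, that the marlin
        #  kernel uses to coordinate its global reduction across blocks.)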
        self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
                                               device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            pass
            # TODO (lucas): add the following when AWQMarlin is migrated over to
            #  using MPLinearKernel backend
            # self._transform_param(layer, self.w_zp_name, lambda x: \
            #     marlin_zero_points(
            #         x,
            #         size_k=c.partition_weight_shape[0],
            #         size_n=c.partition_weight_shape[1],
            #         num_bits=c.weight_type.size_bits))
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

        def transform_w_q(x):
            assert isinstance(x, BaseAphroditeParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.gptq_marlin_repack(x.data.contiguous(),
                                            perm=layer.g_idx_sort_indices,
                                            size_k=c.partition_weight_shape[0],
                                            size_n=c.partition_weight_shape[1],
                                            num_bits=c.weight_type.size_bits)
            return x

        def transform_w_s(x):
            assert isinstance(x, BaseAphroditeParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(x.data.contiguous(),
                                           size_k=c.partition_weight_shape[0],
                                           size_n=c.partition_weight_shape[1],
                                           group_size=c.group_size)
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        #  None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias)
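
# Illustrative sketch (not part of the kernel): roughly how a caller could
# probe `can_implement` before constructing the kernel. The field names match
# `MPLinearLayerConfig` as used above; the concrete values (a 4-bit
# GPTQ-style weight type, group size 128, no zero points or act-order) and the
# `scalar_types` import path are assumptions for illustration only.
#
#     from aphrodite.scalar_type import scalar_types  # import path assumed
#
#     cfg = MPLinearLayerConfig(
#         full_weight_shape=(4096, 4096),
#         partition_weight_shape=(4096, 4096),
#         weight_type=scalar_types.uint4b8,
#         act_type=torch.float16,  # field assumed; not referenced in this file
#         group_size=128,
#         zero_points=False,
#         has_g_idx=False,
#     )
#     can_use, reason = MarlinLinearKernel.can_implement(cfg)
#     if not can_use:
#         raise ValueError(f"Marlin kernel unavailable: {reason}")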