from functools import partial
from typing import Optional, Tuple

import torch

from aphrodite import _custom_ops as ops
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                          permute_param_layout_)
from aphrodite.quantization.utils.machete_utils import (
    MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
    query_machete_supported_quant_types)
from aphrodite.quantization.utils.quant_utils import (
    pack_weights_into_int32, unpack_weights_into_int32)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MacheteLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.has_g_idx and \
                c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return False, "Act reordering currently not supported by "\
                          "Machete when the input features are partitioned "\
                          "across devices"

        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          "Compressed Tensors + Machete. (The kernel "\
                          "supports them, but CompressedTensorsWNA16 does "\
                          "not, so support has not been added to "\
                          "MacheteWNA16Kernel yet.)"

        if c.weight_type not in query_machete_supported_quant_types(
                c.zero_points):
            return False, f"Quant type ({c.weight_type}) not supported by "\
                          "Machete, supported types are: "\
                          f"{query_machete_supported_quant_types(c.zero_points)}"

        if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Machete, supported group sizes are: "\
                          f"{MACHETE_SUPPORTED_GROUP_SIZES}"

        return check_machete_supports_shape(c.partition_weight_shape[0],
                                            c.partition_weight_shape[1])

    # note: assumes that
    #   `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #   `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config

        if c.has_g_idx:
            assert self.w_gidx_name is not None
            # argsort of the group indices gives the permutation that sorts
            # input channels by group; applying it to the activations here
            # (and to the weight rows below) removes act-order handling from
            # the kernel itself.
            perm = torch.argsort(getattr(layer, self.w_gidx_name))\
                .to(torch.int)

            self.act_perm = lambda x: x[:, perm]
            # use the faster `ops.permute_cols` kernel if possible
            if c.act_type in [torch.float16, torch.bfloat16] \
                    and c.partition_weight_shape[0] % 8 == 0:
                self.act_perm = partial(ops.permute_cols, perm=perm)

        def transform_w_q(x):
            assert isinstance(x, BaseAphroditeParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
                # Reorder the packed weight rows to match the activation
                # permutation: unpack, permute along the input dim, repack.
                x_unpacked = unpack_weights_into_int32(x.data,
                                                       c.weight_type,
                                                       packed_dim=0)
                x_perm = x_unpacked[perm, :]
                x.data = pack_weights_into_int32(x_perm,
                                                 c.weight_type,
                                                 packed_dim=0)
            # `.t().contiguous().t()` produces a column-major layout with the
            # same logical shape, which `machete_prepack_B` expects.
            x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
                                           self.config.weight_type)
            return x

        def transform_w_s(x):
            assert isinstance(x, BaseAphroditeParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        # Repack weights and scales for Machete
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)

        # Flatten any leading batch dims into a single 2D GEMM, then restore
        # the original shape (with the partitioned output dim) at the end.
        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

        if c.has_g_idx:
            x_2d = self.act_perm(x_2d)

        output = ops.machete_gemm(a=x_2d,
                                  b_q=w_q,
                                  b_type=c.weight_type,
                                  b_zeros=None,
                                  b_scales=w_s,
                                  b_group_size=c.group_size)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
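

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the kernel itself):
# how a caller might probe whether Machete can implement a layer before
# selecting this kernel. It assumes `MPLinearLayerConfig` is a dataclass with
# exactly the fields accessed above, and that `scalar_types` lives at the
# (hypothetical) import path `aphrodite.scalar_type`.
if __name__ == "__main__":
    from aphrodite.scalar_type import scalar_types  # assumed import path

    cfg = MPLinearLayerConfig(
        full_weight_shape=(4096, 4096),       # (in_features, out_features)
        partition_weight_shape=(4096, 4096),  # unpartitioned in this sketch
        weight_type=scalar_types.uint4b8,     # 4-bit weights, zero-point of 8
        act_type=torch.float16,
        group_size=128,
        zero_points=False,
        has_g_idx=False,
    )
    ok, reason = MacheteLinearKernel.can_implement(cfg)
    print(f"Machete usable: {ok}" + (f" ({reason})" if reason else ""))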