from typing import Optional, Tuple

import torch

from aphrodite import _custom_ops as ops
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                          permute_param_layout_)
from aphrodite.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
    check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
    query_marlin_supported_quant_types)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

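
# MarlinLinearKernel adapts GPTQ-style quantized weights (packed qweight,
# group scales, optional g_idx for activation reordering) to the Marlin
# mixed-precision GEMM kernel: it validates the layer config, repacks the
# weights into Marlin's tile layout after loading, and dispatches the
# forward pass to apply_gptq_marlin_linear.
#
# Rough usage sketch (illustrative only; the selection helper and parameter
# names below are assumptions, not defined in this file):
#
#     kernel_cls = choose_mp_linear_kernel(config)   # hypothetical helper
#     kernel = kernel_cls(config,
#                         w_q_param_name="qweight",
#                         w_s_param_name="scales",
#                         w_gidx_param_name="g_idx")
#     ...                                             # load checkpoint weights
#     kernel.process_weights_after_loading(layer)     # repack for Marlin
#     y = kernel.apply_weights(layer, x, bias)        # quantized matmul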
class MarlinLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        # Marlin kernels require NVIDIA compute capability 8.0+ (Ampere).
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          "MarlinLinearKernel. Will be added when AWQMarlin "\
                          "is migrated over to using the MPLinearKernel backend."

        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, f"Quant type ({c.weight_type}) not supported by "\
                          f"Marlin, supported types are: {quant_types}"

        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Marlin, supported group sizes are: "\
                          f"{MARLIN_SUPPORTED_GROUP_SIZES}"

        return check_marlin_supports_shape(c.partition_weight_shape[0],
                                           c.partition_weight_shape[1],
                                           c.full_weight_shape[1],
                                           c.group_size)

    # NOTE: this assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config

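        # With activation reordering (g_idx) and tensor-parallel row sharding,
        # each partition only sees a slice of the K (input) dimension; Marlin
        # needs to know whether K is "full" to handle the reordering correctly.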
        row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
                                               device)

        # Default names, since marlin requires empty parameters for these.
        # TODO: remove this requirement from marlin (allow optional tensors).
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"

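        # GPTQ act_order: if the checkpoint provides g_idx (per-row group
        # indices), sort it and keep the sort indices so the packed weights
        # can be permuted to match; otherwise install empty placeholder
        # tensors, which the Marlin op currently still expects.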
        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            pass
            # TODO (lucas): add the following when AWQMarlin is migrated over
            # to using the MPLinearKernel backend
            # self._transform_param(layer, self.w_zp_name, lambda x: \
            #     marlin_zero_points(
            #         x,
            #         size_k=c.partition_weight_shape[0],
            #         size_n=c.partition_weight_shape[1],
            #         num_bits=c.weight_type.size_bits))
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

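        # Repack the packed quantized weight into Marlin's tiled layout.
        # The parameter is first normalized to {input_dim=0, output_dim=1,
        # packed_dim=0}, then gptq_marlin_repack applies the g_idx sort
        # permutation (if any) and rewrites the data into the format the
        # Marlin GEMM kernel consumes.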
        def transform_w_q(x):
            assert isinstance(x, BaseAphroditeParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.gptq_marlin_repack(x.data.contiguous(),
                                            perm=layer.g_idx_sort_indices,
                                            size_k=c.partition_weight_shape[0],
                                            size_n=c.partition_weight_shape[1],
                                            num_bits=c.weight_type.size_bits)
            return x

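        # Permute the group scales into the interleaved order expected by the
        # Marlin kernel for the given group size.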
        def transform_w_s(x):
            assert isinstance(x, BaseAphroditeParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(x.data.contiguous(),
                                           size_k=c.partition_weight_shape[0],
                                           size_n=c.partition_weight_shape[1],
                                           group_size=c.group_size)
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

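    # Forward pass: computes y = x @ W (+ bias) for this shard using the
    # repacked Marlin weights, where W denotes the dequantized
    # [input_size_per_partition, output_size_per_partition] weight.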
    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin.
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias)