from typing import Callable, List, Optional, Set

import torch
from loguru import logger

from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                          ChannelQuantScaleParameter,
                                          GroupQuantScaleParameter,
                                          PackedAphroditeParameter,
                                          RowAphroditeParameter)
from aphrodite.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from aphrodite.quantization.compressed_tensors.utils import ActivationOrdering
from aphrodite.quantization.kernels import (MPLinearLayerConfig,
                                            choose_mp_linear_kernel)
from aphrodite.quantization.utils.marlin_utils import (
    marlin_repeat_scales_on_all_ranks)
from aphrodite.scalar_type import scalar_types

__all__ = ["CompressedTensorsWNA16"]

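# The supported weight types encode values with an implicit bias (8 for
# 4-bit, 128 for 8-bit), i.e. symmetric quantization with no explicit
# zero-point tensor (hence zero_points=False in the kernel config below).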
WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
    8: scalar_types.uint8b128,
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())


class CompressedTensorsWNA16(CompressedTensorsScheme):
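    """Weight-only 4/8-bit quantization with 16-bit activations (WNA16),
    loaded from compressed-tensors checkpoints. Weights are stored packed
    into int32 with per-channel or per-group scales and optional
    activation reordering (GPTQ-style g_idx).

    Example (illustrative):
        scheme = CompressedTensorsWNA16(strategy="group", num_bits=4,
                                        group_size=128)
    """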
    # Tracks which kernel backends have already been logged, so each
    # backend is reported at most once per process (shared across instances).
    _kernel_backends_being_used: Set[str] = set()

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None,
                 actorder: Optional[ActivationOrdering] = None):
        # Number of quantized weights packed into one 32-bit word
        # (8 for 4-bit, 4 for 8-bit).
        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size == -1 and self.strategy != "channel":
            raise ValueError("Marlin kernels require group quantization or "
                             "channelwise quantization, but found no group "
                             "size and strategy is not channelwise.")

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_BITS}")

        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]

    @classmethod
    def get_min_capability(cls) -> int:
        # Ampere (compute capability 8.0) and up
        return 80

    def create_weights(self, layer: torch.nn.Module, output_size: int,
                       input_size: int, output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(input_size_per_partition,
                                    output_size_per_partition),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=self.group_size,
            zero_points=False,
            has_g_idx=self.has_g_idx,
        )
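
        # choose_mp_linear_kernel selects a supported mixed-precision matmul
        # backend for this config and the current hardware (e.g. a
        # Marlin-style kernel); the exact set of backends is an
        # implementation detail of aphrodite.quantization.kernels.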
        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info(
                f"Using {kernel_type.__name__} for CompressedTensorsWNA16")
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # If group_size is -1 we are in the channelwise case: a single scale
        # per output channel, spanning the whole input dimension.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = (input_size != input_size_per_partition)
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel)

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size
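
        # Worked example (illustrative, assuming partitioned scales with
        # input_size_per_partition=4096 and group_size=128): each output
        # channel gets 4096 // 128 = 32 group scales. In the channelwise
        # case group_size == input_size, so this collapses to one scale
        # per output channel.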

        weight = PackedAphroditeParameter(input_dim=1,
                                          output_dim=0,
                                          weight_loader=weight_loader,
                                          packed_factor=self.pack_factor,
                                          packed_dim=1,
                                          data=torch.empty(
                                              output_size_per_partition,
                                              input_size_per_partition //
                                              self.pack_factor,
                                              dtype=torch.int32,
                                          ))
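        # weight_packed has shape [output_size_per_partition,
        # input_size_per_partition // pack_factor]; each int32 word packs
        # `pack_factor` quantized values along the input dimension.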

        weight_scale_args = {
            "weight_loader": weight_loader,
            "data": torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            ),
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                      **weight_scale_args)
        else:
            weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                    input_dim=1,
                                                    **weight_scale_args)

        # A length-2 tensor recording the original (unpacked) shape of the
        # weights.
        weight_shape = BaseAphroditeParameter(data=torch.empty(
            2, dtype=torch.int64),
                                              weight_loader=weight_loader)

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

        # Group index for activation reordering: maps each input channel to
        # its quantization group.
        if self.has_g_idx:
            weight_g_idx = RowAphroditeParameter(data=torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
                                                 input_dim=0,
                                                 weight_loader=weight_loader)
            layer.register_parameter("weight_g_idx", weight_g_idx)

        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="weight_packed",
                                  w_s_param_name="weight_scale",
                                  w_zp_param_name=None,
                                  w_gidx_param_name="weight_g_idx")

    # Checkpoints are serialized in compressed-tensors format, which may
    # differ from the layout the kernel expects; repacking is handled here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return self.kernel.apply_weights(layer, x, bias)
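
# Typical lifecycle (illustrative sketch): the owning linear layer calls
# create_weights() at construction time, the checkpoint loader fills the
# registered parameters, process_weights_after_loading() repacks them into
# the chosen kernel's layout, and apply_weights() runs the matmul:
#
#     scheme = CompressedTensorsWNA16(strategy="group", num_bits=4,
#                                     group_size=128)
#     scheme.create_weights(layer, ..., weight_loader=loader)
#     # ... checkpoint weights are loaded ...
#     scheme.process_weights_after_loading(layer)
#     y = scheme.apply_weights(layer, x, bias=None)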