123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- from typing import Any, Dict, List, Optional
- import torch
- from torch.nn.parameter import Parameter
- from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
- set_weight_attrs)
- from aphrodite.quantization.base_config import QuantizationConfig
class BitsAndBytesConfig(QuantizationConfig):
    """Quantization config for BitsAndBytes.

    The config itself holds no state; it only advertises the quantizer's
    name, supported dtypes and config file, and hands out the linear method
    used for quantized linear layers.

    Reference: https://arxiv.org/abs/2305.14314
    """

    def __init__(self) -> None:
        """Create the config. There are no options to store."""

    def __repr__(self) -> str:
        return "BitsAndBytesConfig"

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "bitsandbytes"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this quantizer accepts."""
        supported = [torch.float32, torch.float16, torch.bfloat16]
        return supported

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method.

        NOTE(review): 70 presumably encodes CUDA compute capability 7.0;
        confirm against the base class contract.
        """
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names searched for this quantization config."""
        return ["adapter_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
        """Build a config instance; the contents of ``config`` are unused."""
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
        """Return the linear method for ``layer``, or None if unsupported.

        Only ``LinearBase`` layers are handled; ``prefix`` is not consulted.
        """
        if not isinstance(layer, LinearBase):
            return None
        return BitsAndBytesLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return the activation function names reported as scaled."""
        scaled_activations = [
            "gelu",
            "gelu_fast",
            "gelu_new",
            "gelu_pytorch_tanh",
        ]
        return scaled_activations
class BitsAndBytesLinearMethod(LinearMethodBase):
    """Linear method for BitsAndBytes.

    Args:
        quant_config: The BitsAndBytes quantization config.
    """

    # Minimum supported bitsandbytes release as numeric components (0.42).
    _MIN_BNB_RELEASE = [0, 42]

    def __init__(self, quant_config: BitsAndBytesConfig):
        """Check that a usable bitsandbytes is installed and keep the config.

        Raises:
            ImportError: If bitsandbytes is missing or older than 0.42.0.
        """
        try:
            import bitsandbytes

            # Compare numerically: comparing the raw version strings is
            # lexicographic, so "0.100.0" would sort before "0.42.0" and
            # "0.9.0" after it.
            installed = self._parse_release(bitsandbytes.__version__)
            if installed < self._MIN_BNB_RELEASE:
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.42.0.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.42.0 via "
                              "`pip install bitsandbytes>=0.42.0` to use "
                              "bitsandbytes quantizer.") from err

        self.quant_config = quant_config

    @staticmethod
    def _parse_release(version: str) -> List[int]:
        """Return the leading numeric dot-separated components of a version.

        Parsing stops at the first component that does not start with a
        digit, or right after a component with a non-numeric suffix
        (e.g. "0.43.0rc1" -> [0, 43, 0], "0.44.0.dev0" -> [0, 44, 0]).
        A string with no leading numeric component yields an empty list,
        which compares lower than any real release.
        """
        release: List[int] = []
        for piece in version.split("."):
            end = 0
            while end < len(piece) and piece[end] in "0123456789":
                end += 1
            if end == 0:
                break
            release.append(int(piece[:end]))
            if end < len(piece):
                break
        return release

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the packed ``qweight`` parameter and register it.

        The parameter is an uninitialized uint8 tensor of shape
        ``[input_size_per_partition * sum(output_partition_sizes) //
        quant_ratio, 1]``, where ``quant_ratio`` is the bit width of
        ``params_dtype`` divided by 8. ``input_size`` and ``output_size``
        are not used here.

        Raises:
            ValueError: If the total element count is not divisible by
                ``quant_ratio``.
        """
        uint8_bits = torch.iinfo(torch.uint8).bits
        if params_dtype.is_floating_point:
            quant_ratio = torch.finfo(params_dtype).bits // uint8_bits
        else:
            quant_ratio = torch.iinfo(params_dtype).bits // uint8_bits

        total_elements = input_size_per_partition * sum(
            output_partition_sizes)
        if total_elements % quant_ratio != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. ")

        qweight = Parameter(
            torch.empty(
                total_elements // quant_ratio,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                # In bitsandbytes, a tensor of shape [n, m] is quantized to
                # [n * m / pack_ratio, 1], so the output_dim is 0.
                "output_dim": 0,
                "pack_factor": quant_ratio,
                "use_bitsandbytes": True,
            })
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the layer's 4-bit weight shards.

        The computation runs in bfloat16 and the result is cast back to the
        dtype of ``x`` before ``bias`` (if given) is added in place.

        Reads ``layer.qweight.bnb_quant_state`` and
        ``layer.qweight.bnb_shard_offsets``. Neither is set in this class.
        NOTE(review): presumably the weight loader attaches them, with the
        quant states keyed by shard index 0..n-1 and ``offsets`` holding
        n+1 row boundaries into ``qweight`` - confirm against the loader.
        """
        # only load the bitsandbytes module when needed
        from bitsandbytes import matmul_4bit

        original_type = x.dtype
        bf_x = x.to(torch.bfloat16)

        qweight = layer.qweight
        quant_states = qweight.bnb_quant_state
        offsets = qweight.bnb_shard_offsets

        # The output width is the sum of every shard's first dimension.
        out_dim_0 = x.shape[0]
        out_dim_1 = sum(quant_state.shape[0]
                        for quant_state in quant_states.values())
        out = torch.empty(out_dim_0,
                          out_dim_1,
                          dtype=torch.bfloat16,
                          device=x.device)

        current_index = 0
        for i in range(len(quant_states)):
            output_size = quant_states[i].shape[0]
            # It is more efficient to use out kwarg like
            # matmul_4bit(..., out = ...).  Infeasible now due to the bug
            # https://github.com/TimDettmers/bitsandbytes/issues/1235.
            # Need to change after the bug is fixed.
            out[:, current_index:current_index + output_size] = matmul_4bit(
                bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])

            current_index += output_size

        out = out.to(original_type)

        if bias is not None:
            out += bias

        return out
|