# bitsandbytes.py

from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig


class BitsAndBytesConfig(QuantizationConfig):
    """Config class for BitsAndBytes Quantization.

    Reference: https://arxiv.org/abs/2305.14314
    """

    def __init__(self) -> None:
        pass

    def __repr__(self) -> str:
        return "BitsAndBytesConfig"

    @classmethod
    def get_name(cls) -> str:
        return "bitsandbytes"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
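        # 70 corresponds to CUDA compute capability 7.0 (Volta or newer).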
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "adapter_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
        if isinstance(layer, LinearBase):
            return BitsAndBytesLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class BitsAndBytesLinearMethod(LinearMethodBase):
    """Linear method for BitsAndBytes.

    Args:
        quant_config: The BitsAndBytes quantization config.
    """

    def __init__(self, quant_config: BitsAndBytesConfig):
        try:
            import bitsandbytes
            if bitsandbytes.__version__ < "0.42.0":
                raise ImportError("bitsandbytes version is too old. Please "
                                  "install bitsandbytes>=0.42.0.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.42.0 via "
                              "`pip install bitsandbytes>=0.42.0` to use "
                              "the bitsandbytes quantizer.") from err
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if params_dtype.is_floating_point:
            quant_ratio = torch.finfo(params_dtype).bits // torch.iinfo(
                torch.uint8).bits
        else:
            quant_ratio = torch.iinfo(params_dtype).bits // torch.iinfo(
                torch.uint8).bits
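
        # quant_ratio is the element size of params_dtype in bytes, e.g.
        # 16 // 8 = 2 for torch.float16, so the packed uint8 buffer created
        # below holds n * m // quant_ratio entries for an [n, m] shard.
        # Illustrative sizes (not from the original source): a [4096, 4096]
        # fp16 shard packs into [4096 * 4096 // 2, 1] = [8388608, 1] uint8,
        # i.e. two 4-bit values per byte.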
        if input_size_per_partition * sum(
                output_partition_sizes) % quant_ratio != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape.")

        qweight = Parameter(
            torch.empty(
                input_size_per_partition * sum(output_partition_sizes) //
                quant_ratio,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                # In bitsandbytes, a tensor of shape [n, m] is quantized to
                # [n*m/pack_ratio, 1], so the output_dim is also 0.
                "output_dim": 0,
                "pack_factor": quant_ratio,
                "use_bitsandbytes": True,
            })
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
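
    # NOTE: apply() below assumes that the model loader attaches two extra
    # attributes to qweight once the quantized checkpoint is loaded:
    #   bnb_quant_state   - dict mapping shard index -> bitsandbytes QuantState
    #   bnb_shard_offsets - start/end offsets of each packed shard in qweight
    # (description inferred from how apply() consumes these attributes).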
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Only load the bitsandbytes module when needed.
        from bitsandbytes import matmul_4bit

        original_type = x.dtype
        bf_x = x.to(torch.bfloat16)

        qweight = layer.qweight
        quant_states = qweight.bnb_quant_state
        offsets = qweight.bnb_shard_offsets
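
        # Each output partition i (e.g. the fused q/k/v shards of a merged
        # linear layer) occupies the packed byte range
        # qweight[offsets[i]:offsets[i + 1]] and carries its own QuantState,
        # so the matmul below is performed shard by shard.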
        out_dim_0 = x.shape[0]
        out_dim_1 = sum(
            quant_state.shape[0] for quant_state in quant_states.values())
        out = torch.empty(out_dim_0,
                          out_dim_1,
                          dtype=torch.bfloat16,
                          device=x.device)

        current_index = 0
        for i in range(len(quant_states)):
            output_size = quant_states[i].shape[0]
            # It would be more efficient to use the out kwarg, i.e.
            # matmul_4bit(..., out=...). Infeasible now due to the bug
            # https://github.com/TimDettmers/bitsandbytes/issues/1235.
            # Needs to change after the bug is fixed.
            out[:, current_index:current_index + output_size] = matmul_4bit(
                bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
            current_index += output_size

        out = out.to(original_type)

        if bias is not None:
            out += bias
        return out
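

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal smoke test of the bitsandbytes 4-bit matmul path used by
# BitsAndBytesLinearMethod.apply(), assuming a CUDA device and
# bitsandbytes>=0.42.0 are available; tensor sizes are arbitrary examples.
if __name__ == "__main__":
    from bitsandbytes import matmul_4bit
    from bitsandbytes.functional import quantize_4bit

    weight = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
    x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")

    # Quantize the weight to NF4: returns a packed uint8 tensor of shape
    # [256 * 128 // 2, 1] plus the QuantState needed to dequantize it.
    packed, quant_state = quantize_4bit(weight, quant_type="nf4")
    out = matmul_4bit(x, packed.t(), quant_state)
    print(out.shape)  # expected: torch.Size([4, 256])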