bitsandbytes.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
  4. set_weight_attrs)
  5. from aphrodite.quantization.base_config import QuantizationConfig
  6. class BitsAndBytesConfig(QuantizationConfig):
  7. """Config class for BitsAndBytes Quantization.
  8. Reference: https://arxiv.org/abs/2305.14314
  9. """
  10. def __init__(
  11. self,
  12. load_in_8bit: bool = False,
  13. load_in_4bit: bool = True,
  14. bnb_4bit_compute_dtype: str = "float32",
  15. bnb_4bit_quant_type: str = "fp4",
  16. bnb_4bit_use_double_quant: bool = False,
  17. llm_int8_enable_fp32_cpu_offload: bool = False,
  18. llm_int8_has_fp16_weight: bool = False,
  19. llm_int8_skip_modules: Optional[Any] = None,
  20. llm_int8_threshold: float = 0.0,
  21. ) -> None:
  22. self.load_in_8bit = load_in_8bit
  23. self.load_in_4bit = load_in_4bit
  24. self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
  25. self.bnb_4bit_quant_type = bnb_4bit_quant_type
  26. self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
  27. self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
  28. self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
  29. self.llm_int8_skip_modules = llm_int8_skip_modules
  30. self.llm_int8_threshold = llm_int8_threshold
  31. def __repr__(self) -> str:
  32. return "BitsAndBytesConfig"
  33. @classmethod
  34. def get_name(self) -> str:
  35. return "bitsandbytes"
  36. @classmethod
  37. def get_supported_act_dtypes(self) -> List[torch.dtype]:
  38. return [torch.float32, torch.float16, torch.bfloat16]
  39. @classmethod
  40. def get_min_capability(cls) -> int:
  41. return 70
  42. @staticmethod
  43. def get_config_filenames() -> List[str]:
  44. return [
  45. "adapter_config.json",
  46. ]
  47. @classmethod
  48. def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
  49. def get_safe_value(config, keys, default_value=None):
  50. try:
  51. value = cls.get_from_keys(config, keys)
  52. return value if value is not None else default_value
  53. except ValueError:
  54. return default_value
  55. load_in_8bit = get_safe_value(config, ["load_in_8bit"],
  56. default_value=False)
  57. load_in_4bit = get_safe_value(config, ["load_in_4bit"],
  58. default_value=True)
  59. bnb_4bit_compute_dtype = get_safe_value(config,
  60. ["bnb_4bit_compute_dtype"],
  61. default_value="float32")
  62. bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"],
  63. default_value="fp4")
  64. bnb_4bit_use_double_quant = get_safe_value(
  65. config, ["bnb_4bit_use_double_quant"], default_value=False)
  66. llm_int8_enable_fp32_cpu_offload = get_safe_value(
  67. config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False)
  68. llm_int8_has_fp16_weight = get_safe_value(config,
  69. ["llm_int8_has_fp16_weight"],
  70. default_value=False)
  71. llm_int8_skip_modules = get_safe_value(config,
  72. ["llm_int8_skip_modules"],
  73. default_value=[])
  74. llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"],
  75. default_value=0.0)
  76. return cls(
  77. load_in_8bit=load_in_8bit,
  78. load_in_4bit=load_in_4bit,
  79. bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
  80. bnb_4bit_quant_type=bnb_4bit_quant_type,
  81. bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
  82. llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,
  83. llm_int8_has_fp16_weight=llm_int8_has_fp16_weight,
  84. llm_int8_skip_modules=llm_int8_skip_modules,
  85. llm_int8_threshold=llm_int8_threshold)
  86. def get_quant_method(self, layer: torch.nn.Module,
  87. prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
  88. if isinstance(layer, LinearBase):
  89. return BitsAndBytesLinearMethod(self)
  90. return None
  91. def get_scaled_act_names(self) -> List[str]:
  92. return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
  93. class BitsAndBytesLinearMethod(LinearMethodBase):
  94. """Linear method for BitsAndBytes.
  95. Args:
  96. quant_config: The BitsAndBytes quantization config.
  97. """
  98. def __init__(self, quant_config: BitsAndBytesConfig):
  99. try:
  100. import bitsandbytes
  101. if bitsandbytes.__version__ < "0.42.0":
  102. raise ImportError("bitsandbytes version is wrong. Please "
  103. "install bitsandbytes>=0.42.0.")
  104. except ImportError as err:
  105. raise ImportError("Please install bitsandbytes>=0.42.0 via "
  106. "`pip install bitsandbytes>=0.42.0` to use "
  107. "bitsandbytes quantizer.") from err
  108. self.quant_config = quant_config
  109. def create_weights(self, layer: torch.nn.Module,
  110. input_size_per_partition: int,
  111. output_partition_sizes: List[int], input_size: int,
  112. output_size: int, params_dtype: torch.dtype,
  113. **extra_weight_attrs):
  114. from bitsandbytes.nn import Int8Params
  115. def calculate_quant_ratio(dtype):
  116. if dtype.is_floating_point:
  117. return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
  118. else:
  119. return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
  120. def create_qweight_for_8bit():
  121. qweight = Int8Params(
  122. data=torch.empty(sum(output_partition_sizes),
  123. input_size_per_partition,
  124. dtype=torch.int8),
  125. has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
  126. requires_grad=False)
  127. set_weight_attrs(
  128. qweight, {
  129. "input_dim": 0,
  130. "output_dim": 0,
  131. "pack_factor": 1,
  132. "use_bitsandbytes_8bit": True,
  133. "generation": 0
  134. })
  135. return qweight
  136. def create_qweight_for_4bit():
  137. quant_ratio = calculate_quant_ratio(params_dtype)
  138. total_size = input_size_per_partition * sum(output_partition_sizes)
  139. if total_size % quant_ratio != 0:
  140. raise ValueError(
  141. "The input size is not aligned with the quantized "
  142. "weight shape.")
  143. qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio,
  144. 1,
  145. dtype=torch.uint8),
  146. requires_grad=False)
  147. set_weight_attrs(
  148. qweight, {
  149. "input_dim": 0,
  150. "output_dim": 0,
  151. "pack_factor": quant_ratio,
  152. "use_bitsandbytes_4bit": True
  153. })
  154. return qweight
  155. if self.quant_config.load_in_8bit:
  156. qweight = create_qweight_for_8bit()
  157. else:
  158. qweight = create_qweight_for_4bit()
  159. layer.register_parameter("qweight", qweight)
  160. set_weight_attrs(qweight, extra_weight_attrs)
  161. def apply(self,
  162. layer: torch.nn.Module,
  163. x: torch.Tensor,
  164. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  165. if self.quant_config.load_in_8bit:
  166. return self._apply_8bit_weight(layer, x, bias)
  167. else:
  168. return self._apply_4bit_weight(layer, x, bias)
  169. def _apply_8bit_weight(
  170. self,
  171. layer: torch.nn.Module,
  172. x: torch.Tensor,
  173. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  174. # only load the bitsandbytes module when needed
  175. from bitsandbytes import MatmulLtState, matmul
  176. original_type = x.dtype
  177. bf_x = x.to(torch.bfloat16)
  178. qweight = layer.qweight
  179. offsets = qweight.bnb_shard_offsets
  180. quant_states = qweight.bnb_quant_state
  181. matmul_states = qweight.matmul_state
  182. generation = qweight.generation
  183. out_dim_0 = x.shape[0]
  184. out_dim_1 = sum(
  185. [quant_state[1].shape[0] for quant_state in quant_states.items()])
  186. out = torch.empty(out_dim_0,
  187. out_dim_1,
  188. dtype=torch.float16,
  189. device=x.device)
  190. current_index = 0
  191. for i in range(len(quant_states)):
  192. output_size = quant_states[i].shape[0]
  193. # in profile_run or the first generation of inference,
  194. # create new matmul_states
  195. if generation == 0 or generation == 1:
  196. matmul_states[i] = MatmulLtState()
  197. matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
  198. matmul_states[i].SCB = quant_states[i]
  199. matmul_states[i].threshold = (
  200. self.quant_config.llm_int8_threshold)
  201. matmul_states[i].has_fp16_weights = (
  202. self.quant_config.llm_int8_has_fp16_weight)
  203. matmul_states[i].is_training = False
  204. if matmul_states[i].threshold > 0.0 and not matmul_states[
  205. i].has_fp16_weights:
  206. matmul_states[i].use_pool = True
  207. new_x = bf_x.unsqueeze(0)
  208. out[:, current_index:current_index + output_size] = matmul(
  209. new_x,
  210. qweight[offsets[i]:offsets[i + 1]],
  211. state=matmul_states[i])
  212. current_index += output_size
  213. # only update the matmul_states if it is not profile_run
  214. if (generation > 0
  215. and not self.quant_config.llm_int8_has_fp16_weight
  216. and matmul_states[i].CB is not None
  217. and matmul_states[i].CxB is not None):
  218. del matmul_states[i].CB
  219. qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB
  220. out = out.to(original_type)
  221. if bias is not None:
  222. out += bias
  223. qweight.generation += 1
  224. return out
  225. def _apply_4bit_weight(
  226. self,
  227. layer: torch.nn.Module,
  228. x: torch.Tensor,
  229. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  230. # only load the bitsandbytes module when needed
  231. from bitsandbytes import matmul_4bit
  232. original_type = x.dtype
  233. bf_x = x.to(torch.bfloat16)
  234. qweight = layer.qweight
  235. quant_states = qweight.bnb_quant_state
  236. offsets = qweight.bnb_shard_offsets
  237. out_dim_0 = x.shape[0]
  238. out_dim_1 = sum(
  239. [quant_state[1].shape[0] for quant_state in quant_states.items()])
  240. out = torch.empty(out_dim_0,
  241. out_dim_1,
  242. dtype=torch.bfloat16,
  243. device=x.device)
  244. current_index = 0
  245. for i in range(len(quant_states)):
  246. output_size = quant_states[i].shape[0]
  247. # It is more efficient to use out kwarg like
  248. # matmul_4bit(..., out = ...). Infeasible now due to the bug
  249. # https://github.com/TimDettmers/bitsandbytes/issues/1235.
  250. # Need to change after the bug is fixed.
  251. out[:, current_index:current_index + output_size] = matmul_4bit(
  252. bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
  253. current_index += output_size
  254. out = out.to(original_type)
  255. if bias is not None:
  256. out += bias
  257. return out