# bitsandbytes.py

  1. import torch
  2. from torch.nn.parameter import Parameter
  3. from typing import List, Dict, Any, Optional, TypeVar, NamedTuple
  4. from contextlib import suppress
  5. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  6. set_weight_attrs)
  7. from aphrodite.quantization.base_config import (QuantizationConfig)
  8. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  9. QKVParallelLinear,
  10. RowParallelLinear)
  11. HAS_QUANTS = False
  12. with suppress(ImportError):
  13. from aphrodite._quant_C import quant_ops as ops
  14. HAS_QUANTS = True
  15. class BitsandBytesConfig(QuantizationConfig):
  16. """Config class for BitsandBytes.
  17. Reference: https://arxiv.org/abs/2208.07339
  18. """
  19. def __init__(
  20. self,
  21. weight_bits: int,
  22. group_size: int,
  23. zero_point: bool,
  24. from_float: bool,
  25. quant_mode: str, # llm_int8, smoothquant, weight_only
  26. ) -> None:
  27. if not HAS_QUANTS:
  28. raise ImportError("Could not find the quantization kernels.")
  29. self.weight_bits = weight_bits
  30. self.group_size = group_size
  31. self.zero_point = zero_point
  32. self.from_float = from_float
  33. self.quant_mode = quant_mode
  34. if quant_mode == "weight_only" and self.weight_bits != 4:
  35. raise ValueError(
  36. "Currently, only 4-bit weight quantization is supported for "
  37. f"BNB weight_only, but got {self.weight_bits} bits.")
  38. if quant_mode in ["llm_int8", "smoothquant"] and self.weight_bits != 8:
  39. raise ValueError(
  40. "Currently, only 8-bit weight quantization is supported for "
  41. "BNB llm_int8 or smoothquant, "
  42. f"but got {self.weight_bits} bits.")
  43. self.pack_factor = 32 // self.weight_bits
  44. def __repr__(self) -> str:
  45. return (f"BitsandBytesConfig(weight_bits={self.weight_bits}, "
  46. f"group_size={self.group_size}, "
  47. f"zero_point={self.zero_point}, "
  48. f"from_float={self.from_float}, "
  49. f"quant_mode={self.quant_mode})")
  50. def get_name(self) -> str:
  51. return "bitsandbytes"
  52. def get_supported_act_dtypes(self) -> List[torch.dtype]:
  53. return [torch.half, torch.bfloat16]
  54. def get_min_capability(self) -> int:
  55. # The BitsandBytes kernel only supports Ampere or newer GPUs.
  56. return 75
  57. def merge_weight(self) -> bool:
  58. return True
  59. def rope_style(self) -> Optional[bool]:
  60. return None
  61. def quant_vocab(self) -> List[bool]:
  62. return [False, False]
  63. def support_fused_moe(self) -> bool:
  64. return False
  65. @staticmethod
  66. def get_config_filenames() -> List[str]:
  67. return [
  68. "quant_config.json",
  69. "quantize_config.json",
  70. ]
  71. @classmethod
  72. def from_config(cls, config: Dict[str, Any]) -> "BitsandBytesConfig":
  73. weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
  74. group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
  75. zero_point = cls.get_from_keys(config, ["zero_point"])
  76. try:
  77. from_float = cls.get_from_keys(config, ["from_float"])
  78. except Exception:
  79. from_float = False
  80. try:
  81. quant_mode = cls.get_from_keys(config, ["quant_mode"])
  82. except Exception:
  83. quant_mode = "weight_only"
  84. return cls(weight_bits, group_size, zero_point, from_float, quant_mode)
  85. def get_linear_method(self) -> "BNBLinearMethod":
  86. return BNBLinearMethod(self)
  87. def get_scaled_act_names(self) -> List[str]:
  88. return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
  89. class BNBLinearMethod(LinearMethodBase):
  90. """Linear method for BitsandBytes.
  91. Args:
  92. quant_config: The BitsandBytes quantization config.
  93. """
  94. def __init__(self, quant_config: BitsandBytesConfig):
  95. self.quant_config = quant_config
  96. def create_weights(self, input_size_per_partition: int,
  97. output_partition_sizes: List[int], input_size: int,
  98. output_size: int,
  99. params_dtype: torch.dtype) -> Dict[str, Any]:
  100. if self.quant_config.quant_mode == "weight_only" and \
  101. input_size_per_partition % self.quant_config.group_size != 0:
  102. raise ValueError(
  103. "The input size is not aligned with the quantized "
  104. "weight shape. This can be caused by too large "
  105. "tensor parallel size.")
  106. output_size_per_partition = sum(output_partition_sizes)
  107. if self.quant_config.quant_mode == "weight_only" and \
  108. output_size_per_partition % self.quant_config.pack_factor != 0:
  109. raise ValueError(
  110. "The output size is not aligned with the quantized "
  111. "weight shape. This can be caused by too large "
  112. "tensor parallel size.")
  113. if self.quant_config.quant_mode == "weight_only" and \
  114. not self.quant_config.from_float:
  115. qweight = Parameter(
  116. torch.empty(
  117. input_size_per_partition,
  118. output_size_per_partition // self.quant_config.pack_factor,
  119. dtype=torch.int32,
  120. ),
  121. requires_grad=False,
  122. )
  123. set_weight_attrs(
  124. qweight, {
  125. "input_dim": 0,
  126. "output_dim": 1,
  127. "packed_dim": 1,
  128. "pack_factor": self.quant_config.pack_factor,
  129. })
  130. qzeros = Parameter(
  131. torch.empty(
  132. input_size_per_partition // self.quant_config.group_size,
  133. output_size_per_partition // self.quant_config.pack_factor,
  134. dtype=torch.int32,
  135. ),
  136. requires_grad=False,
  137. )
  138. set_weight_attrs(
  139. qzeros, {
  140. "input_dim": 0,
  141. "output_dim": 1,
  142. "packed_dim": 1,
  143. "pack_factor": self.quant_config.pack_factor,
  144. })
  145. scales = Parameter(
  146. torch.empty(
  147. input_size_per_partition // self.quant_config.group_size,
  148. output_size_per_partition,
  149. dtype=params_dtype,
  150. ),
  151. requires_grad=False,
  152. )
  153. set_weight_attrs(scales, {
  154. "input_dim": 0,
  155. "output_dim": 1,
  156. })
  157. return {
  158. "qweight": qweight,
  159. "qzeros": qzeros,
  160. "scales": scales,
  161. }
  162. else:
  163. weight = Parameter(torch.empty(output_size_per_partition,
  164. input_size_per_partition,
  165. dtype=params_dtype),
  166. requires_grad=False)
  167. set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
  168. return {"weight": weight}
  169. def apply_weights(self,
  170. weights: Dict[str, torch.Tensor],
  171. x: torch.Tensor,
  172. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  173. if self.quant_config.quant_mode == "weight_only":
  174. qweight = weights["qweight"].data
  175. scales_zeros = weights["scales_zeros"].data
  176. pack_factor = self.quant_config.pack_factor
  177. out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
  178. reshaped_x = x.reshape(-1, x.shape[-1])
  179. out = ops.autoquant_s4_f16_gemm(reshaped_x, qweight, scales_zeros)
  180. if bias is not None:
  181. out = out + bias
  182. return out.reshape(out_shape)
  183. else:
  184. weight = weights["weight"]
  185. state = weights["state"]
  186. if weight.CB is not None:
  187. state.CB = weight.CB
  188. state.SCB = weight.SCB
  189. weight.CB = None
  190. weight.SCB = None
  191. import bitsandbytes as bnb
  192. out = bnb.matmul(x, weight, bias=bias, state=state)
  193. if not state.has_fp16_weights and \
  194. state.CB is not None and state.CxB is not None:
  195. # we converted 8-bit row major to turing/ampere format
  196. # in the first inference pass
  197. # we no longer need the row-major weight
  198. del state.CB
  199. weight.data = state.CxB
  200. return out
  201. def apply_moe_weights(self, w1: Dict[str,
  202. torch.Tensor], w2: Dict[str,
  203. torch.Tensor],
  204. x: torch.Tensor, gating_output: torch.Tensor,
  205. topk: int, renormalize: bool) -> torch.Tensor:
  206. raise NotImplementedError
# Type variable bound to torch.nn.Module (given as a string forward
# reference). Nothing in the visible part of this file references it.
T = TypeVar("T", bound="torch.nn.Module")
  208. class QParams(NamedTuple):
  209. """A class to hold the quantization parameters."""
  210. scales: torch.Tensor
  211. zero_points: Optional[torch.Tensor]
  212. @torch.no_grad()
  213. def cal_qparams_per_group_minmax(w: torch.Tensor,
  214. n_bits: int = 4,
  215. group_size: int = 128):
  216. """Calculate quantization parameters for each group using min and max
  217. values."""
  218. outc, inc = w.shape
  219. assert inc >= group_size, \
  220. 'Input channels should be greater than or equal to group_size.'
  221. assert inc % group_size == 0, \
  222. 'Input channels should be divisible by group_size.'
  223. w_group_wise = w.reshape(outc, -1, group_size)
  224. w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
  225. w_max = w_group_wise.max(dim=-1, keepdim=True)[0]
  226. q_max = 2**n_bits - 1
  227. q_min = 0
  228. scales = (w_max - w_min)
  229. scales = scales.clamp_(min=1e-5).div_(q_max)
  230. # zero_points = (-w_min / scales).round().clamp(q_min, q_max)
  231. zero_points = (-torch.round(w_min / scales)).clamp_(q_min, q_max)
  232. return QParams(scales=scales, zero_points=zero_points)
  233. def convert_s4(qw: torch.Tensor,
  234. qz: torch.Tensor,
  235. s: torch.Tensor,
  236. group_size: int = 128):
  237. assert qw.is_contiguous()
  238. assert qz.is_contiguous()
  239. assert s.is_contiguous()
  240. _qw = torch.zeros_like(qw)
  241. _sz = torch.zeros_like(s, dtype=torch.int32) # half2
  242. _ws = torch.zeros_like(s)
  243. ops.autoquant_convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
  244. qw.size(-1) * 8, qw.size(0), group_size)
  245. return _qw, _sz
  246. def tp_m_s4(x: torch.Tensor, tp: int = 1):
  247. return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
  248. 1).contiguous()
  249. def quant(weight: torch.Tensor,
  250. qparams: Optional[QParams] = None) -> torch.Tensor:
  251. """Perform fake quantization on the given weight tensor.
  252. Args:
  253. weight (torch.Tensor): The weight tensor with shape
  254. (out_features, in_features).
  255. qparams (Optional[QParams]): A namedtuple containing 'scales'
  256. and 'zero_points'.
  257. Returns:
  258. torch.Tensor: The fake quantized weight tensor.
  259. """
  260. if qparams is None:
  261. qparams = cal_qparams_per_group_minmax(weight)
  262. scales = qparams.scales
  263. zero_points = qparams.zero_points
  264. out_c, in_c = weight.shape
  265. # Reshape the weights if using per_group quantization
  266. # per tensor scales shape: [1]
  267. # per channel scales shape: [out_c, 1]
  268. # per group scales shape: [out_c, in_c//group_size, 1]
  269. if len(scales.shape) > 2:
  270. # scales shape: [out_c, in_c//group_size, 1]
  271. weight = weight.reshape(out_c, scales.shape[1], -1)
  272. if zero_points is None:
  273. real_qweight = (weight / scales).round()
  274. else:
  275. real_qweight = ((weight + (scales * zero_points)) / scales).round()
  276. if len(scales.shape) > 2:
  277. real_qweight = real_qweight.reshape(out_c, in_c)
  278. return real_qweight.to(torch.int32)
  279. # core quantization method (simulated quantization)
  280. def quantize_tensor(
  281. weight,
  282. n_bits=4,
  283. group_size=128,
  284. ):
  285. pack_num = 32 // n_bits
  286. pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
  287. org_weight_shape = weight.shape
  288. out_features = org_weight_shape[0]
  289. in_features = org_weight_shape[1]
  290. qparams = cal_qparams_per_group_minmax(weight, n_bits)
  291. i32_w = quant(weight, qparams)
  292. i32_w = i32_w.t().contiguous()
  293. w_pack_oc = out_features // (32 // n_bits)
  294. w_inc = in_features
  295. pack_int_w = torch.zeros((w_inc, w_pack_oc),
  296. dtype=torch.int32,
  297. device=weight.device)
  298. for col in range(pack_int_w.shape[1]):
  299. for i in range(pack_num):
  300. pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
  301. pack_int_w[:, col] |= pack_int_w_col << (i * n_bits)
  302. qweight = pack_int_w
  303. scales = qparams.scales.squeeze(-1).t().contiguous()
  304. if qparams.zero_points is not None:
  305. zeros = qparams.zero_points.to(torch.int32)
  306. zeros = zeros.squeeze(-1).t().contiguous()
  307. z_inc = in_features // group_size
  308. z_oc = out_features // (32 // n_bits)
  309. pack_int_zeros = torch.zeros((z_inc, z_oc),
  310. dtype=torch.int32,
  311. device=weight.device)
  312. for col in range(pack_int_zeros.shape[1]):
  313. for i in range(pack_num):
  314. qzero_col = zeros[:, col * pack_num + pack_order[i]]
  315. pack_int_zeros[:, col] |= qzero_col << (i * n_bits)
  316. qzeros = pack_int_zeros
  317. return qweight, scales, qzeros
  318. def replace_quant_params(model,
  319. quant_config,
  320. modules_to_not_convert="lm_head"):
  321. """
  322. modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
  323. Name of the module to not convert in `Linear8bitLt`.
  324. In practice we keep the `lm_head` in full precision
  325. for numerical stability reasons.
  326. """
  327. if not isinstance(modules_to_not_convert, list):
  328. modules_to_not_convert = [modules_to_not_convert]
  329. for name, module in model.named_children():
  330. if len(list(module.children())) > 0:
  331. replace_quant_params(module, quant_config, modules_to_not_convert)
  332. if isinstance(
  333. module,
  334. (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear)) \
  335. and name not in modules_to_not_convert:
  336. if quant_config.from_float:
  337. module.linear_weights.pop("weight")
  338. param = module._parameters["weight"]
  339. if quant_config.quant_mode in ("llm_int8", "smoothquant"):
  340. import bitsandbytes as bnb
  341. new_value = bnb.nn.Int8Params(param.data,
  342. requires_grad=False,
  343. has_fp16_weights=False)
  344. state = bnb.MatmulLtState()
  345. if quant_config.quant_mode == "smoothquant":
  346. state.threshold = 0.0
  347. else:
  348. state.threshold = 6.0
  349. state.has_fp16_weights = False
  350. state.memory_efficient_backward = False
  351. state.use_pool = True
  352. module._parameters["weight"] = new_value
  353. module.linear_weights["weight"] = new_value
  354. module.linear_weights["state"] = state
  355. set_weight_attrs(
  356. new_value, {
  357. "input_dim": 0,
  358. "output_dim": 1,
  359. "packed_dim": 1,
  360. "pack_factor": quant_config.pack_factor,
  361. })
  362. del param
  363. torch.cuda.empty_cache()
  364. elif quant_config.quant_mode == "weight_only":
  365. data_fp = param.cuda()
  366. _qweight, _scales, _qzeros = quantize_tensor(
  367. data_fp, n_bits=4, group_size=128)
  368. qweight, scales_zeros = convert_s4(_qweight, _qzeros,
  369. _scales)
  370. torch.cuda.synchronize()
  371. param_qweight = Parameter(qweight, requires_grad=False)
  372. param_scales_zeros = Parameter(scales_zeros,
  373. requires_grad=False)
  374. module.register_parameter("qweight", param_qweight)
  375. module.register_parameter("scales_zeros",
  376. param_scales_zeros)
  377. set_weight_attrs(
  378. param_qweight, {
  379. "input_dim": 0,
  380. "output_dim": 1,
  381. "packed_dim": 1,
  382. "pack_factor": quant_config.pack_factor,
  383. })
  384. set_weight_attrs(param_scales_zeros, {
  385. "input_dim": 0,
  386. "output_dim": 1,
  387. })
  388. module.linear_weights["qweight"] = param_qweight
  389. module.linear_weights["scales_zeros"] = param_scales_zeros
  390. del _qzeros
  391. del _scales
  392. del param
  393. delattr(module, "weight")
  394. torch.cuda.empty_cache()
  395. else: # load packed int4 weight
  396. module.linear_weights.pop("qweight")
  397. module.linear_weights.pop("qzeros")
  398. module.linear_weights.pop("scales")
  399. _qweight = module._parameters["qweight"]
  400. _qzeros = module._parameters["qzeros"]
  401. _scales = module._parameters["scales"]
  402. qweight, scales_zeros = convert_s4(_qweight.data, _qzeros.data,
  403. _scales.data)
  404. param_qweight = Parameter(qweight, requires_grad=False)
  405. param_scales_zeros = Parameter(scales_zeros,
  406. requires_grad=False)
  407. del _qweight
  408. del _qzeros
  409. del _scales
  410. delattr(module, "qweight")
  411. delattr(module, "qzeros")
  412. delattr(module, "scales")
  413. module.register_parameter("qweight", param_qweight)
  414. module.register_parameter("scales_zeros", param_scales_zeros)
  415. set_weight_attrs(
  416. param_qweight, {
  417. "input_dim": 0,
  418. "output_dim": 1,
  419. "packed_dim": 1,
  420. "pack_factor": quant_config.pack_factor,
  421. })
  422. set_weight_attrs(param_scales_zeros, {
  423. "input_dim": 0,
  424. "output_dim": 1,
  425. })
  426. module.linear_weights["qweight"] = param_qweight
  427. module.linear_weights["scales_zeros"] = param_scales_zeros
  428. torch.cuda.synchronize()
  429. torch.cuda.empty_cache()