
import torch
from torch.nn.parameter import Parameter
from typing import List, Dict, Any, Optional, TypeVar, NamedTuple
from contextlib import suppress

from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True

class BitsandBytesConfig(QuantizationConfig):
    """Config class for BitsandBytes.

    Reference: https://arxiv.org/abs/2208.07339
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        from_float: bool,
        quant_mode: str,  # llm_int8, smoothquant, weight_only
    ) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.from_float = from_float
        self.quant_mode = quant_mode
        if quant_mode == "weight_only" and self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"BNB weight_only, but got {self.weight_bits} bits.")
        if quant_mode in ["llm_int8", "smoothquant"] and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 8-bit weight quantization is supported for "
                "BNB llm_int8 or smoothquant, "
                f"but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"BitsandBytesConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"from_float={self.from_float}, "
                f"quant_mode={self.quant_mode})")

    def get_name(self) -> str:
        return "bitsandbytes"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    def get_min_capability(self) -> int:
        # The BitsandBytes kernels require Turing (compute capability 7.5)
        # or newer GPUs.
        return 75

    def merge_weight(self) -> bool:
        return True

    def rope_style(self) -> Optional[bool]:
        return None

    def quant_vocab(self) -> List[bool]:
        return [False, False]

    def support_fused_moe(self) -> bool:
        return False

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsandBytesConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        try:
            from_float = cls.get_from_keys(config, ["from_float"])
        except Exception:
            from_float = False
        try:
            quant_mode = cls.get_from_keys(config, ["quant_mode"])
        except Exception:
            quant_mode = "weight_only"
        return cls(weight_bits, group_size, zero_point, from_float, quant_mode)

    def get_linear_method(self) -> "BNBLinearMethod":
        return BNBLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
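

# A minimal usage sketch for the config above. The dict stands in for a parsed
# quant_config.json / quantize_config.json; the field values are hypothetical
# examples, only the key names are what from_config() looks up. Constructing
# the config still requires the quantization kernels, since __init__ checks
# HAS_QUANTS.
def _example_bitsandbytes_config() -> "BitsandBytesConfig":
    example_cfg = {
        "bits": 4,             # 4-bit weights -> pack_factor = 32 // 4 = 8
        "group_size": 128,     # per-group min/max quantization
        "zero_point": True,
        "from_float": True,    # quantize an fp16/bf16 checkpoint on load
        "quant_mode": "weight_only",
    }
    return BitsandBytesConfig.from_config(example_cfg)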


class BNBLinearMethod(LinearMethodBase):
    """Linear method for BitsandBytes.

    Args:
        quant_config: The BitsandBytes quantization config.
    """

    def __init__(self, quant_config: BitsandBytesConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if self.quant_config.quant_mode == "weight_only" and \
                input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if self.quant_config.quant_mode == "weight_only" and \
                output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if self.quant_config.quant_mode == "weight_only" and \
                not self.quant_config.from_float:
            qweight = Parameter(
                torch.empty(
                    input_size_per_partition,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            qzeros = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qzeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            scales = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition,
                    dtype=params_dtype,
                ),
                requires_grad=False,
            )
            set_weight_attrs(scales, {
                "input_dim": 0,
                "output_dim": 1,
            })
            layer.register_parameter("qweight", qweight)
            set_weight_attrs(qweight, extra_weight_attrs)
            layer.register_parameter("qzeros", qzeros)
            set_weight_attrs(qzeros, extra_weight_attrs)
            layer.register_parameter("scales", scales)
            set_weight_attrs(scales, extra_weight_attrs)
        else:
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter("weight", weight)
            set_weight_attrs(weight, extra_weight_attrs)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.quant_config.quant_mode == "weight_only":
            qweight = layer.qweight.data
            scales_zeros = layer.scales_zeros.data
            pack_factor = self.quant_config.pack_factor
            out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
            reshaped_x = x.reshape(-1, x.shape[-1])
            out = ops.autoquant_s4_f16_gemm(reshaped_x, qweight, scales_zeros)
            if bias is not None:
                out = out + bias
            return out.reshape(out_shape)
        else:
            weight = layer.weight
            state = layer.state
            if weight.CB is not None:
                state.CB = weight.CB
                state.SCB = weight.SCB
                weight.CB = None
                weight.SCB = None
            import bitsandbytes as bnb
            out = bnb.matmul(x, weight, bias=bias, state=state)
            if not state.has_fp16_weights and \
                    state.CB is not None and state.CxB is not None:
                # The 8-bit row-major weight was converted to the
                # Turing/Ampere layout during the first inference pass, so
                # the row-major copy is no longer needed.
                del state.CB
                weight.data = state.CxB
            return out

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        raise NotImplementedError
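

# Shape sketch for the weight_only path of create_weights() above. The
# partition sizes here are hypothetical; the point is the pack_factor and
# group_size arithmetic (eight int4 codes per int32 word, one scale and one
# zero-point per 128-element input group).
def _example_weight_only_shapes(input_size_per_partition: int = 4096,
                                output_size_per_partition: int = 4096,
                                weight_bits: int = 4,
                                group_size: int = 128):
    pack_factor = 32 // weight_bits
    qweight_shape = (input_size_per_partition,
                     output_size_per_partition // pack_factor)
    qzeros_shape = (input_size_per_partition // group_size,
                    output_size_per_partition // pack_factor)
    scales_shape = (input_size_per_partition // group_size,
                    output_size_per_partition)
    return qweight_shape, qzeros_shape, scales_shape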


T = TypeVar("T", bound="torch.nn.Module")


class QParams(NamedTuple):
    """A class to hold the quantization parameters."""

    scales: torch.Tensor
    zero_points: Optional[torch.Tensor]


@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor,
                                 n_bits: int = 4,
                                 group_size: int = 128):
    """Calculate quantization parameters for each group using min and max
    values."""
    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
    w_group_wise = w.reshape(outc, -1, group_size)
    w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
    w_max = w_group_wise.max(dim=-1, keepdim=True)[0]
    q_max = 2**n_bits - 1
    q_min = 0
    scales = (w_max - w_min)
    scales = scales.clamp_(min=1e-5).div_(q_max)
    # zero_points = (-w_min / scales).round().clamp(q_min, q_max)
    zero_points = (-torch.round(w_min / scales)).clamp_(q_min, q_max)
    return QParams(scales=scales, zero_points=zero_points)
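

# A small sketch of what cal_qparams_per_group_minmax() returns. The tensor
# sizes are hypothetical and chosen only to keep the example tiny: for an
# (out_features, in_features) weight and group_size=128, scales and
# zero_points both have shape (out_features, in_features // group_size, 1),
# and the asymmetric 4-bit codes live in [0, 15].
def _example_group_qparams():
    w = torch.randn(8, 256)  # (out_features, in_features)
    qparams = cal_qparams_per_group_minmax(w, n_bits=4, group_size=128)
    return qparams.scales.shape, qparams.zero_points.shape  # both (8, 2, 1)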


def convert_s4(qw: torch.Tensor,
               qz: torch.Tensor,
               s: torch.Tensor,
               group_size: int = 128):
    assert qw.is_contiguous()
    assert qz.is_contiguous()
    assert s.is_contiguous()
    _qw = torch.zeros_like(qw)
    _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
    _ws = torch.zeros_like(s)
    ops.autoquant_convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
                                  qw.size(-1) * 8, qw.size(0), group_size)
    return _qw, _sz


def tp_m_s4(x: torch.Tensor, tp: int = 1):
    return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
                                                        1).contiguous()


def quant(weight: torch.Tensor,
          qparams: Optional[QParams] = None) -> torch.Tensor:
    """Perform fake quantization on the given weight tensor.

    Args:
        weight (torch.Tensor): The weight tensor with shape
            (out_features, in_features).
        qparams (Optional[QParams]): A namedtuple containing 'scales'
            and 'zero_points'.

    Returns:
        torch.Tensor: The fake quantized weight tensor.
    """
    if qparams is None:
        qparams = cal_qparams_per_group_minmax(weight)
    scales = qparams.scales
    zero_points = qparams.zero_points
    out_c, in_c = weight.shape
    # Reshape the weights if using per_group quantization
    # per tensor scales shape: [1]
    # per channel scales shape: [out_c, 1]
    # per group scales shape: [out_c, in_c//group_size, 1]
    if len(scales.shape) > 2:
        # scales shape: [out_c, in_c//group_size, 1]
        weight = weight.reshape(out_c, scales.shape[1], -1)
    if zero_points is None:
        real_qweight = (weight / scales).round()
    else:
        real_qweight = ((weight + (scales * zero_points)) / scales).round()
    if len(scales.shape) > 2:
        real_qweight = real_qweight.reshape(out_c, in_c)
    return real_qweight.to(torch.int32)
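

# Round-trip sketch for quant(): the integer code is round(w / s + z), so an
# approximate reconstruction is (q - z) * s, with elementwise error bounded by
# about half a quantization step (scales / 2). Shapes are hypothetical; this
# helper is only illustrative and is not used by the kernels.
def _example_fake_quant_roundtrip():
    w = torch.randn(8, 256)
    qparams = cal_qparams_per_group_minmax(w, n_bits=4, group_size=128)
    q = quant(w, qparams)  # int32 codes in [0, 15], shape (8, 256)
    q_grouped = q.reshape(8, -1, 128).to(w.dtype)
    w_hat = (q_grouped - qparams.zero_points) * qparams.scales
    return (w - w_hat.reshape(8, 256)).abs().max()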


# core quantization method (simulated quantization)
def quantize_tensor(
    weight,
    n_bits=4,
    group_size=128,
):
    pack_num = 32 // n_bits
    pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
    org_weight_shape = weight.shape
    out_features = org_weight_shape[0]
    in_features = org_weight_shape[1]
    qparams = cal_qparams_per_group_minmax(weight, n_bits)
    i32_w = quant(weight, qparams)
    i32_w = i32_w.t().contiguous()
    w_pack_oc = out_features // (32 // n_bits)
    w_inc = in_features
    pack_int_w = torch.zeros((w_inc, w_pack_oc),
                             dtype=torch.int32,
                             device=weight.device)
    for col in range(pack_int_w.shape[1]):
        for i in range(pack_num):
            pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
            pack_int_w[:, col] |= pack_int_w_col << (i * n_bits)
    qweight = pack_int_w
    scales = qparams.scales.squeeze(-1).t().contiguous()
    if qparams.zero_points is not None:
        zeros = qparams.zero_points.to(torch.int32)
        zeros = zeros.squeeze(-1).t().contiguous()
        z_inc = in_features // group_size
        z_oc = out_features // (32 // n_bits)
        pack_int_zeros = torch.zeros((z_inc, z_oc),
                                     dtype=torch.int32,
                                     device=weight.device)
        for col in range(pack_int_zeros.shape[1]):
            for i in range(pack_num):
                qzero_col = zeros[:, col * pack_num + pack_order[i]]
                pack_int_zeros[:, col] |= qzero_col << (i * n_bits)
        qzeros = pack_int_zeros
    return qweight, scales, qzeros
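

# Packing sketch: quantize_tensor() stores eight 4-bit codes per int32 column,
# interleaved according to pack_order = [0, 2, 4, 6, 1, 3, 5, 7]. The helper
# below packs a single group of eight codes the same way, purely to show the
# bit layout; the input codes are hypothetical.
def _example_pack_eight_int4(codes=(1, 2, 3, 4, 5, 6, 7, 8)):
    pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
    word = 0
    for i in range(8):
        # Nibble i of the packed word holds code number pack_order[i].
        word |= (codes[pack_order[i]] & 0xF) << (i * 4)
    return word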


def replace_quant_params(model,
                         quant_config,
                         modules_to_not_convert="lm_head"):
    """Replace the weights of parallel linear modules with quantized ones.

    modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
        Name of the module that should not be converted and is instead
        kept in full precision. In practice we keep the `lm_head` in full
        precision for numerical stability reasons.
    """
    if not isinstance(modules_to_not_convert, list):
        modules_to_not_convert = [modules_to_not_convert]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_quant_params(module, quant_config, modules_to_not_convert)
        if isinstance(
                module,
                (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear)) \
                and name not in modules_to_not_convert:
            if quant_config.from_float:
                module.linear_weights.pop("weight")
                param = module._parameters["weight"]
                if quant_config.quant_mode in ("llm_int8", "smoothquant"):
                    import bitsandbytes as bnb
                    new_value = bnb.nn.Int8Params(param.data,
                                                  requires_grad=False,
                                                  has_fp16_weights=False)
                    state = bnb.MatmulLtState()
                    if quant_config.quant_mode == "smoothquant":
                        state.threshold = 0.0
                    else:
                        state.threshold = 6.0
                    state.has_fp16_weights = False
                    state.memory_efficient_backward = False
                    state.use_pool = True
                    module._parameters["weight"] = new_value
                    module.linear_weights["weight"] = new_value
                    module.linear_weights["state"] = state
                    set_weight_attrs(
                        new_value, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    del param
                    torch.cuda.empty_cache()
                elif quant_config.quant_mode == "weight_only":
                    data_fp = param.cuda()
                    _qweight, _scales, _qzeros = quantize_tensor(
                        data_fp, n_bits=4, group_size=128)
                    qweight, scales_zeros = convert_s4(_qweight, _qzeros,
                                                       _scales)
                    torch.cuda.synchronize()
                    param_qweight = Parameter(qweight, requires_grad=False)
                    param_scales_zeros = Parameter(scales_zeros,
                                                   requires_grad=False)
                    module.register_parameter("qweight", param_qweight)
                    module.register_parameter("scales_zeros",
                                              param_scales_zeros)
                    set_weight_attrs(
                        param_qweight, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    set_weight_attrs(param_scales_zeros, {
                        "input_dim": 0,
                        "output_dim": 1,
                    })
                    module.linear_weights["qweight"] = param_qweight
                    module.linear_weights["scales_zeros"] = param_scales_zeros
                    del _qzeros
                    del _scales
                    del param
                    delattr(module, "weight")
                    torch.cuda.empty_cache()
            else:  # load packed int4 weight
                module.linear_weights.pop("qweight")
                module.linear_weights.pop("qzeros")
                module.linear_weights.pop("scales")
                _qweight = module._parameters["qweight"]
                _qzeros = module._parameters["qzeros"]
                _scales = module._parameters["scales"]
                qweight, scales_zeros = convert_s4(_qweight.data,
                                                   _qzeros.data, _scales.data)
                param_qweight = Parameter(qweight, requires_grad=False)
                param_scales_zeros = Parameter(scales_zeros,
                                               requires_grad=False)
                del _qweight
                del _qzeros
                del _scales
                delattr(module, "qweight")
                delattr(module, "qzeros")
                delattr(module, "scales")
                module.register_parameter("qweight", param_qweight)
                module.register_parameter("scales_zeros", param_scales_zeros)
                set_weight_attrs(
                    param_qweight, {
                        "input_dim": 0,
                        "output_dim": 1,
                        "packed_dim": 1,
                        "pack_factor": quant_config.pack_factor,
                    })
                set_weight_attrs(param_scales_zeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                })
                module.linear_weights["qweight"] = param_qweight
                module.linear_weights["scales_zeros"] = param_scales_zeros
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
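

# A minimal end-to-end sketch, assuming a hypothetical `model` whose linear
# layers are the ColumnParallelLinear / QKVParallelLinear / RowParallelLinear
# classes imported above. The config values are illustrative: with
# from_float=True, replace_quant_params() quantizes the fp16/bf16 weights in
# place while it walks the module tree.
def _example_quantize_model(model: torch.nn.Module) -> torch.nn.Module:
    quant_config = BitsandBytesConfig(weight_bits=4,
                                      group_size=128,
                                      zero_point=True,
                                      from_float=True,
                                      quant_mode="weight_only")
    replace_quant_params(model, quant_config,
                         modules_to_not_convert="lm_head")
    return model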