bitsandbytes.py

import torch
from torch.nn.parameter import Parameter
from typing import List, Dict, Any, Optional, TypeVar, NamedTuple

from aphrodite._C import ops
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs,
                                              ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.quantization.base_config import (
    QuantizationConfig)


class BitsandBytesConfig(QuantizationConfig):
    """Config class for BitsandBytes.

    Reference: https://arxiv.org/abs/2208.07339
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        from_float: bool,
        quant_mode: str,  # llm_int8, smoothquant, weight_only
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.from_float = from_float
        self.quant_mode = quant_mode
        if quant_mode == "weight_only" and self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"BNB weight_only, but got {self.weight_bits} bits.")
        if quant_mode in ["llm_int8", "smoothquant"] and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 8-bit weight quantization is supported for "
                "BNB llm_int8 or smoothquant, "
                f"but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"BitsandBytesConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"from_float={self.from_float}, "
                f"quant_mode={self.quant_mode})")

    def get_name(self) -> str:
        return "bitsandbytes"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    def get_min_capability(self) -> int:
        # The BitsandBytes kernel only supports Turing or newer GPUs
        # (compute capability >= 7.5).
        return 75

    def merge_weight(self) -> bool:
        return True

    def rope_style(self) -> Optional[bool]:
        return None

    def quant_vocab(self) -> List[bool]:
        return [False, False]

    def support_fused_moe(self) -> bool:
        return False

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsandBytesConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        try:
            from_float = cls.get_from_keys(config, ["from_float"])
        except Exception:
            from_float = False
        try:
            quant_mode = cls.get_from_keys(config, ["quant_mode"])
        except Exception:
            quant_mode = "weight_only"
        return cls(weight_bits, group_size, zero_point, from_float, quant_mode)

    def get_linear_method(self) -> "BNBLinearMethod":
        return BNBLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class BNBLinearMethod(LinearMethodBase):
    """Linear method for BitsandBytes.

    Args:
        quant_config: The BitsandBytes quantization config.
    """

    def __init__(self, quant_config: BitsandBytesConfig):
        self.quant_config = quant_config

    def create_weights(self, input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        if self.quant_config.quant_mode == "weight_only" and \
                input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if self.quant_config.quant_mode == "weight_only" and \
                output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if self.quant_config.quant_mode == "weight_only" and \
                not self.quant_config.from_float:
            qweight = Parameter(
                torch.empty(
                    input_size_per_partition,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            qzeros = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qzeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            scales = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition,
                    dtype=params_dtype,
                ),
                requires_grad=False,
            )
            set_weight_attrs(scales, {
                "input_dim": 0,
                "output_dim": 1,
            })
            return {
                "qweight": qweight,
                "qzeros": qzeros,
                "scales": scales,
            }
        else:
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            return {"weight": weight}

    def apply_weights(self,
                      weights: Dict[str, torch.Tensor],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.quant_config.quant_mode == "weight_only":
            # "scales_zeros" is registered by replace_quant_params() after
            # the packed weights have been converted with convert_s4().
            qweight = weights["qweight"].data
            scales_zeros = weights["scales_zeros"].data
            pack_factor = self.quant_config.pack_factor
            out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
            reshaped_x = x.reshape(-1, x.shape[-1])
            out = ops.autoquant_s4_f16_gemm(reshaped_x, qweight, scales_zeros)
            if bias is not None:
                out = out + bias
            return out.reshape(out_shape)
        else:
            weight = weights["weight"]
            state = weights["state"]
            if weight.CB is not None:
                state.CB = weight.CB
                state.SCB = weight.SCB
                weight.CB = None
                weight.SCB = None
            # Lazy import: bitsandbytes is only needed for the
            # llm_int8 / smoothquant paths.
            import bitsandbytes as bnb
            out = bnb.matmul(x, weight, bias=bias, state=state)
            if not state.has_fp16_weights and \
                    state.CB is not None and state.CxB is not None:
                # The 8-bit row-major weight was converted to the
                # Turing/Ampere format in the first inference pass;
                # the row-major copy is no longer needed.
                del state.CB
                weight.data = state.CxB
            return out

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        raise NotImplementedError
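
# Shape sketch for the weight_only path of apply_weights (illustrative,
# assuming 4-bit weights, so pack_factor == 8):
#   x            -> [..., in_features]
#   qweight      -> [in_features, out_features // 8]  (int32, 8 nibbles each)
#   scales_zeros -> [in_features // group_size, out_features]
#                   (int32 entries, each holding a half2 scale/zero pair)
#   output       -> [..., out_features]
# The packed tensors are produced by quantize_tensor() / convert_s4() in
# replace_quant_params() below.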


T = TypeVar("T", bound="torch.nn.Module")


class QParams(NamedTuple):
    """A class to hold the quantization parameters."""
    scales: torch.Tensor
    zero_points: Optional[torch.Tensor]


@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor,
                                 n_bits: int = 4,
                                 group_size: int = 128):
    """Calculate quantization parameters for each group using min and max
    values."""
    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
    w_group_wise = w.reshape(outc, -1, group_size)
    w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
    w_max = w_group_wise.max(dim=-1, keepdim=True)[0]
    q_max = 2**n_bits - 1
    q_min = 0
    scales = (w_max - w_min)
    scales = scales.clamp_(min=1e-5).div_(q_max)
    # zero_points = (-w_min / scales).round().clamp(q_min, q_max)
    zero_points = (-torch.round(w_min / scales)).clamp_(q_min, q_max)
    return QParams(scales=scales, zero_points=zero_points)
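
# Worked example (asymmetric min/max quantization, n_bits=4 so q_max=15):
# for a group with w_min = -0.30 and w_max = 0.45,
#   scale      = (0.45 - (-0.30)) / 15 = 0.05
#   zero_point = -round(-0.30 / 0.05) = 6
# so quant() below maps w_min to code 0 and w_max to code 15.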


def convert_s4(qw: torch.Tensor,
               qz: torch.Tensor,
               s: torch.Tensor,
               group_size: int = 128):
    assert qw.is_contiguous()
    assert qz.is_contiguous()
    assert s.is_contiguous()
    _qw = torch.zeros_like(qw)
    _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
    _ws = torch.zeros_like(s)
    ops.autoquant_convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
                                  qw.size(-1) * 8, qw.size(0), group_size)
    return _qw, _sz


def tp_m_s4(x: torch.Tensor, tp: int = 1):
    return x.view(x.size(0) // 32, tp, -1,
                  128).permute(0, 2, 3, 1).contiguous()
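
# Note (a sketch of intent, inferred from the shapes above and the kernel
# name; the exact layout is defined in aphrodite._C.ops): convert_s4
# reorders the packed int4 weights for ops.autoquant_s4_f16_gemm and fuses
# each (scale, zero_point) pair into one int32 holding a half2, which is
# the "scales_zeros" tensor used by BNBLinearMethod.apply_weights.
# tp_m_s4 appears to be a layout helper for tensor-parallel sharding of
# that packed format.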


def quant(weight: torch.Tensor,
          qparams: Optional[QParams] = None) -> torch.Tensor:
    """Perform fake quantization on the given weight tensor.

    Args:
        weight (torch.Tensor): The weight tensor with shape
            (out_features, in_features).
        qparams (Optional[QParams]): A namedtuple containing 'scales'
            and 'zero_points'.

    Returns:
        torch.Tensor: The quantized weight as int32 integer codes.
    """
    if qparams is None:
        qparams = cal_qparams_per_group_minmax(weight)
    scales = qparams.scales
    zero_points = qparams.zero_points
    out_c, in_c = weight.shape
    # Reshape the weights if using per_group quantization
    # per tensor scales shape: [1]
    # per channel scales shape: [out_c, 1]
    # per group scales shape: [out_c, in_c//group_size, 1]
    if len(scales.shape) > 2:
        # scales shape: [out_c, in_c//group_size, 1]
        weight = weight.reshape(out_c, scales.shape[1], -1)
    if zero_points is None:
        real_qweight = (weight / scales).round()
    else:
        real_qweight = ((weight + (scales * zero_points)) / scales).round()
    if len(scales.shape) > 2:
        real_qweight = real_qweight.reshape(out_c, in_c)
    return real_qweight.to(torch.int32)
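
# Continuing the cal_qparams_per_group_minmax example (scale 0.05,
# zero_point 6): quant() maps a weight of 0.10 to
# round((0.10 + 0.05 * 6) / 0.05) = 8, while w_min and w_max land on the
# endpoints 0 and 15 of the 4-bit range.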


# core quantization method (simulated quantization)
def quantize_tensor(
    weight,
    n_bits=4,
    group_size=128,
):
    pack_num = 32 // n_bits
    pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
    org_weight_shape = weight.shape
    out_features = org_weight_shape[0]
    in_features = org_weight_shape[1]
    qparams = cal_qparams_per_group_minmax(weight, n_bits, group_size)
    i32_w = quant(weight, qparams)
    i32_w = i32_w.t().contiguous()
    w_pack_oc = out_features // (32 // n_bits)
    w_inc = in_features
    pack_int_w = torch.zeros((w_inc, w_pack_oc),
                             dtype=torch.int32,
                             device=weight.device)
    for col in range(pack_int_w.shape[1]):
        for i in range(pack_num):
            pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
            pack_int_w[:, col] |= pack_int_w_col << (i * n_bits)
    qweight = pack_int_w
    scales = qparams.scales.squeeze(-1).t().contiguous()
    qzeros = None  # stays None if no zero points are used
    if qparams.zero_points is not None:
        zeros = qparams.zero_points.to(torch.int32)
        zeros = zeros.squeeze(-1).t().contiguous()
        z_inc = in_features // group_size
        z_oc = out_features // (32 // n_bits)
        pack_int_zeros = torch.zeros((z_inc, z_oc),
                                     dtype=torch.int32,
                                     device=weight.device)
        for col in range(pack_int_zeros.shape[1]):
            for i in range(pack_num):
                qzero_col = zeros[:, col * pack_num + pack_order[i]]
                pack_int_zeros[:, col] |= qzero_col << (i * n_bits)
        qzeros = pack_int_zeros
    return qweight, scales, qzeros
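
# Packing sketch: with n_bits=4, each int32 column packs pack_num = 8
# consecutive output channels. pack_order = [0, 2, 4, 6, 1, 3, 5, 7] stores
# the even channels in the low nibbles and the odd channels in the high
# nibbles, so channel codes [a0..a7] are laid out in one int32 as nibbles
# (low -> high): a0, a2, a4, a6, a1, a3, a5, a7. This interleaving is the
# input convert_s4() expects before the kernel-specific reordering
# (assumption based on how replace_quant_params() chains the two calls).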


def replace_quant_params(model,
                         quant_config,
                         modules_to_not_convert="lm_head"):
    """Replace the weights of quantizable linear modules in-place.

    modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
        Name of the modules not to convert to `Linear8bitLt`.
        In practice, we keep the `lm_head` in full precision
        for numerical stability reasons.
    """
    if not isinstance(modules_to_not_convert, list):
        modules_to_not_convert = [modules_to_not_convert]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_quant_params(module, quant_config, modules_to_not_convert)
        if isinstance(module, (ColumnParallelLinear, QKVParallelLinear,
                               RowParallelLinear)) \
                and name not in modules_to_not_convert:
            if quant_config.from_float:
                module.linear_weights.pop("weight")
                param = module._parameters["weight"]
                if quant_config.quant_mode in ("llm_int8", "smoothquant"):
                    import bitsandbytes as bnb
                    new_value = bnb.nn.Int8Params(param.data,
                                                  requires_grad=False,
                                                  has_fp16_weights=False)
                    state = bnb.MatmulLtState()
                    if quant_config.quant_mode == "smoothquant":
                        state.threshold = 0.0
                    else:
                        state.threshold = 6.0
                    state.has_fp16_weights = False
                    state.memory_efficient_backward = False
                    state.use_pool = True
                    module._parameters["weight"] = new_value
                    module.linear_weights["weight"] = new_value
                    module.linear_weights["state"] = state
                    set_weight_attrs(
                        new_value, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    del param
                    torch.cuda.empty_cache()
                elif quant_config.quant_mode == "weight_only":
                    data_fp = param.cuda()
                    _qweight, _scales, _qzeros = quantize_tensor(
                        data_fp, n_bits=4, group_size=128)
                    qweight, scales_zeros = convert_s4(_qweight, _qzeros,
                                                       _scales)
                    torch.cuda.synchronize()
                    param_qweight = Parameter(qweight, requires_grad=False)
                    param_scales_zeros = Parameter(scales_zeros,
                                                   requires_grad=False)
                    module.register_parameter("qweight", param_qweight)
                    module.register_parameter("scales_zeros",
                                              param_scales_zeros)
                    set_weight_attrs(
                        param_qweight, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    set_weight_attrs(param_scales_zeros, {
                        "input_dim": 0,
                        "output_dim": 1,
                    })
                    module.linear_weights["qweight"] = param_qweight
                    module.linear_weights["scales_zeros"] = param_scales_zeros
                    del _qzeros
                    del _scales
                    del param
                    delattr(module, "weight")
                    torch.cuda.empty_cache()
            else:  # load packed int4 weight
                module.linear_weights.pop("qweight")
                module.linear_weights.pop("qzeros")
                module.linear_weights.pop("scales")
                _qweight = module._parameters["qweight"]
                _qzeros = module._parameters["qzeros"]
                _scales = module._parameters["scales"]
                qweight, scales_zeros = convert_s4(_qweight.data,
                                                   _qzeros.data, _scales.data)
                param_qweight = Parameter(qweight, requires_grad=False)
                param_scales_zeros = Parameter(scales_zeros,
                                               requires_grad=False)
                del _qweight
                del _qzeros
                del _scales
                delattr(module, "qweight")
                delattr(module, "qzeros")
                delattr(module, "scales")
                module.register_parameter("qweight", param_qweight)
                module.register_parameter("scales_zeros", param_scales_zeros)
                set_weight_attrs(
                    param_qweight, {
                        "input_dim": 0,
                        "output_dim": 1,
                        "packed_dim": 1,
                        "pack_factor": quant_config.pack_factor,
                    })
                set_weight_attrs(param_scales_zeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                })
                module.linear_weights["qweight"] = param_qweight
                module.linear_weights["scales_zeros"] = param_scales_zeros
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
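
# Usage sketch (illustrative; model loading is handled elsewhere in the
# engine and is not shown here): after loading a float16 checkpoint with
# quant_config.from_float=True, a single call
#
#   replace_quant_params(model, quant_config)
#
# walks the module tree, converts every ColumnParallelLinear /
# QKVParallelLinear / RowParallelLinear except `lm_head`, and registers the
# packed parameters ("qweight"/"scales_zeros", or the bnb Int8Params plus
# MatmulLtState) that BNBLinearMethod.apply_weights expects at runtime.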