# bitsandbytes.py
from typing import Any, Dict, List, NamedTuple, Optional, TypeVar

import torch
from torch.nn.parameter import Parameter

from aphrodite._C import ops
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear,
                                              set_weight_attrs)
from aphrodite.modeling.layers.quantization.base_config import (
    QuantizationConfig)


class BitsandBytesConfig(QuantizationConfig):
    """Config class for BitsandBytes.

    Reference: https://arxiv.org/abs/2208.07339
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        from_float: bool,
        quant_mode: str,  # llm_int8, smoothquant, weight_only
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.from_float = from_float
        self.quant_mode = quant_mode

        if quant_mode == "weight_only" and self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"BNB weight_only, but got {self.weight_bits} bits.")
        if quant_mode in ["llm_int8", "smoothquant"] and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 8-bit weight quantization is supported for "
                "BNB llm_int8 or smoothquant, "
                f"but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"BitsandBytesConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"from_float={self.from_float}, "
                f"quant_mode={self.quant_mode})")

    def get_name(self) -> str:
        return "bitsandbytes"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    def get_min_capability(self) -> int:
        # The BitsandBytes kernel requires compute capability 7.5
        # (Turing) or newer GPUs.
        return 75

    def merge_weight(self) -> bool:
        return True

    def rope_style(self) -> Optional[bool]:
        return None

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsandBytesConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        try:
            from_float = cls.get_from_keys(config, ["from_float"])
        except Exception:
            from_float = False
        try:
            quant_mode = cls.get_from_keys(config, ["quant_mode"])
        except Exception:
            quant_mode = "weight_only"
        return cls(weight_bits, group_size, zero_point, from_float, quant_mode)

    def get_linear_method(self) -> "BNBLinearMethod":
        return BNBLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class BNBLinearMethod(LinearMethodBase):
    """Linear method for BitsandBytes.

    Args:
        quant_config: The BitsandBytes quantization config.
    """

    def __init__(self, quant_config: BitsandBytesConfig):
        self.quant_config = quant_config

    def create_weights(self, input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        if self.quant_config.quant_mode == "weight_only" and \
                input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if self.quant_config.quant_mode == "weight_only" and \
                output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if self.quant_config.quant_mode == "weight_only" and \
                not self.quant_config.from_float:
            # Pre-quantized checkpoint: allocate packed int4 weights,
            # packed zero points, and per-group scales.
            qweight = Parameter(
                torch.empty(
                    input_size_per_partition,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            qzeros = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qzeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            scales = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition,
                    dtype=params_dtype,
                ),
                requires_grad=False,
            )
            set_weight_attrs(scales, {
                "input_dim": 0,
                "output_dim": 1,
            })
            return {
                "qweight": qweight,
                "qzeros": qzeros,
                "scales": scales,
            }
        else:
            # from_float or llm_int8/smoothquant: start from a regular
            # floating-point weight; it is quantized later by
            # replace_quant_params().
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            return {"weight": weight}

    def apply_weights(self,
                      weights: Dict[str, torch.Tensor],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.quant_config.quant_mode == "weight_only":
            qweight = weights["qweight"].data
            # "scales_zeros" is the fused scale/zero-point tensor produced
            # by convert_s4() in replace_quant_params().
            scales_zeros = weights["scales_zeros"].data
            pack_factor = self.quant_config.pack_factor
            out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
            reshaped_x = x.reshape(-1, x.shape[-1])
            out = ops.autoquant_s4_f16_gemm(reshaped_x, qweight, scales_zeros)
            if bias is not None:
                out = out + bias
            return out.reshape(out_shape)
        else:
            weight = weights["weight"]
            state = weights["state"]
            if weight.CB is not None:
                state.CB = weight.CB
                state.SCB = weight.SCB
                weight.CB = None
                weight.SCB = None
            import bitsandbytes as bnb
            out = bnb.matmul(x, weight, bias=bias, state=state)
            if not state.has_fp16_weights and \
                    state.CB is not None and state.CxB is not None:
                # We converted the 8-bit row-major weight to turing/ampere
                # format in the first inference pass, so the row-major copy
                # is no longer needed.
                del state.CB
                weight.data = state.CxB
            return out
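
# Shape sketch for the weight_only path (hypothetical sizes, not part of the
# original module): with 4-bit weights, pack_factor == 8, so a packed int32
# qweight of shape [in, out // 8] expands back to an output of width `out`.
#
#   x       : [batch, seq, 4096]                      (half)
#   qweight : [4096, 11008 // 8] == [4096, 1376]      (int32)
#   out     : [batch * seq, 1376 * 8] -> reshaped to [batch, seq, 11008]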


T = TypeVar("T", bound="torch.nn.Module")


class QParams(NamedTuple):
    """A class to hold the quantization parameters."""
    scales: torch.Tensor
    zero_points: Optional[torch.Tensor]


@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor,
                                 n_bits: int = 4,
                                 group_size: int = 128):
    """Calculate quantization parameters for each group using min and max
    values."""
    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
    w_group_wise = w.reshape(outc, -1, group_size)
    w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
    w_max = w_group_wise.max(dim=-1, keepdim=True)[0]
    q_max = 2**n_bits - 1
    q_min = 0
    scales = (w_max - w_min)
    scales = scales.clamp_(min=1e-5).div_(q_max)
    # zero_points = (-w_min / scales).round().clamp(q_min, q_max)
    zero_points = (-torch.round(w_min / scales)).clamp_(q_min, q_max)
    return QParams(scales=scales, zero_points=zero_points)
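
# Worked example (hypothetical numbers, not part of the original module): with
# n_bits=4 (q_max == 15), a group whose w_min = -0.30 and w_max = 0.45 gets
#   scale      = (0.45 - (-0.30)) / 15 = 0.05
#   zero_point = clamp(round(0.30 / 0.05), 0, 15) = 6
# so the integer codes 0..15 cover roughly [-0.30, 0.45] after dequantization.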


def convert_s4(qw: torch.Tensor,
               qz: torch.Tensor,
               s: torch.Tensor,
               group_size: int = 128):
    assert qw.is_contiguous()
    assert qz.is_contiguous()
    assert s.is_contiguous()
    _qw = torch.zeros_like(qw)
    _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
    _ws = torch.zeros_like(s)
    ops.autoquant_convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
                                  qw.size(-1) * 8, qw.size(0), group_size)
    return _qw, _sz


def tp_m_s4(x: torch.Tensor, tp: int = 1):
    return x.view(x.size(0) // 32, tp, -1,
                  128).permute(0, 2, 3, 1).contiguous()
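
# Usage sketch (assumed tensors, not part of the original module): convert_s4()
# repacks the int4 weights for the autoquant s4 GEMM kernel and fuses scales
# and zero points into a single "scales_zeros" tensor, which is what
# BNBLinearMethod.apply_weights() consumes.
#
#   qweight, scales_zeros = convert_s4(packed_w, packed_z, scales)
#   out = ops.autoquant_s4_f16_gemm(x_2d, qweight, scales_zeros)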


def quant(weight: torch.Tensor,
          qparams: Optional[QParams] = None) -> torch.Tensor:
    """Perform fake quantization on the given weight tensor.

    Args:
        weight (torch.Tensor): The weight tensor with shape
            (out_features, in_features).
        qparams (Optional[QParams]): A namedtuple containing 'scales'
            and 'zero_points'.

    Returns:
        torch.Tensor: The quantized integer weight codes as ``torch.int32``.
    """
    if qparams is None:
        qparams = cal_qparams_per_group_minmax(weight)
    scales = qparams.scales
    zero_points = qparams.zero_points
    out_c, in_c = weight.shape
    # Reshape the weights if using per_group quantization
    # per tensor scales shape: [1]
    # per channel scales shape: [out_c, 1]
    # per group scales shape: [out_c, in_c//group_size, 1]
    if len(scales.shape) > 2:
        # scales shape: [out_c, in_c//group_size, 1]
        weight = weight.reshape(out_c, scales.shape[1], -1)
    if zero_points is None:
        real_qweight = (weight / scales).round()
    else:
        real_qweight = ((weight + (scales * zero_points)) / scales).round()
    if len(scales.shape) > 2:
        real_qweight = real_qweight.reshape(out_c, in_c)
    return real_qweight.to(torch.int32)
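
# Round-trip sketch (derived from the formulas above; variable names are
# placeholders): quant() maps w -> round(w / s + z), and the matching
# dequantization is (q - z) * s.
#
#   qp = cal_qparams_per_group_minmax(w, n_bits=4, group_size=128)
#   q = quant(w, qp)  # int32 codes, approximately in [0, 15]
#   w_hat = ((q.reshape(w.shape[0], -1, 128).float() - qp.zero_points)
#            * qp.scales).reshape(w.shape)  # approximate reconstruction of w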


# core quantization method (simulated quantization)
def quantize_tensor(
    weight,
    n_bits=4,
    group_size=128,
):
    pack_num = 32 // n_bits
    # Interleaved order in which the 4-bit codes are packed into each int32.
    pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
    org_weight_shape = weight.shape
    out_features = org_weight_shape[0]
    in_features = org_weight_shape[1]
    qparams = cal_qparams_per_group_minmax(weight, n_bits, group_size)
    i32_w = quant(weight, qparams)
    i32_w = i32_w.t().contiguous()
    w_pack_oc = out_features // (32 // n_bits)
    w_inc = in_features
    pack_int_w = torch.zeros((w_inc, w_pack_oc),
                             dtype=torch.int32,
                             device=weight.device)
    for col in range(pack_int_w.shape[1]):
        for i in range(pack_num):
            pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
            pack_int_w[:, col] |= pack_int_w_col << (i * n_bits)
    qweight = pack_int_w
    scales = qparams.scales.squeeze(-1).t().contiguous()
    qzeros = None
    if qparams.zero_points is not None:
        zeros = qparams.zero_points.to(torch.int32)
        zeros = zeros.squeeze(-1).t().contiguous()
        z_inc = in_features // group_size
        z_oc = out_features // (32 // n_bits)
        pack_int_zeros = torch.zeros((z_inc, z_oc),
                                     dtype=torch.int32,
                                     device=weight.device)
        for col in range(pack_int_zeros.shape[1]):
            for i in range(pack_num):
                qzero_col = zeros[:, col * pack_num + pack_order[i]]
                pack_int_zeros[:, col] |= qzero_col << (i * n_bits)
        qzeros = pack_int_zeros
    return qweight, scales, qzeros
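
# Packing layout sketch (hypothetical values, not part of the original module):
# with n_bits=4, eight codes q0..q7 share one int32, stored in the interleaved
# order pack_order = [0, 2, 4, 6, 1, 3, 5, 7]:
#   bits [3:0]   -> q0    bits [7:4]   -> q2    bits [11:8]  -> q4
#   bits [15:12] -> q6    bits [19:16] -> q1    bits [23:20] -> q3
#   bits [27:24] -> q5    bits [31:28] -> q7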


def replace_quant_params(model,
                         quant_config,
                         modules_to_not_convert="lm_head"):
    """
    modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
        Name of the module(s) not to quantize. In practice, the `lm_head`
        is kept in full precision for numerical stability reasons.
    """
    if not isinstance(modules_to_not_convert, list):
        modules_to_not_convert = [modules_to_not_convert]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_quant_params(module, quant_config, modules_to_not_convert)
        if isinstance(module, (ColumnParallelLinear, QKVParallelLinear,
                               RowParallelLinear)) \
                and name not in modules_to_not_convert:
            if quant_config.from_float:
                # Quantize the floating-point weight in place.
                module.linear_weights.pop("weight")
                param = module._parameters["weight"]
                if quant_config.quant_mode in ("llm_int8", "smoothquant"):
                    import bitsandbytes as bnb
                    new_value = bnb.nn.Int8Params(param.data,
                                                  requires_grad=False,
                                                  has_fp16_weights=False)
                    state = bnb.MatmulLtState()
                    if quant_config.quant_mode == "smoothquant":
                        state.threshold = 0.0
                    else:
                        state.threshold = 6.0
                    state.has_fp16_weights = False
                    state.memory_efficient_backward = False
                    state.use_pool = True
                    module._parameters["weight"] = new_value
                    module.linear_weights["weight"] = new_value
                    module.linear_weights["state"] = state
                    set_weight_attrs(
                        new_value, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    del param
                    torch.cuda.empty_cache()
                elif quant_config.quant_mode == "weight_only":
                    data_fp = param.cuda()
                    _qweight, _scales, _qzeros = quantize_tensor(
                        data_fp, n_bits=4, group_size=128)
                    qweight, scales_zeros = convert_s4(_qweight, _qzeros,
                                                       _scales)
                    torch.cuda.synchronize()
                    param_qweight = Parameter(qweight, requires_grad=False)
                    param_scales_zeros = Parameter(scales_zeros,
                                                   requires_grad=False)
                    module.register_parameter("qweight", param_qweight)
                    module.register_parameter("scales_zeros",
                                              param_scales_zeros)
                    set_weight_attrs(
                        param_qweight, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    set_weight_attrs(param_scales_zeros, {
                        "input_dim": 0,
                        "output_dim": 1,
                    })
                    module.linear_weights["qweight"] = param_qweight
                    module.linear_weights["scales_zeros"] = param_scales_zeros
                    del _qzeros
                    del _scales
                    del param
                    delattr(module, "weight")
                    torch.cuda.empty_cache()
            else:  # load packed int4 weight
                module.linear_weights.pop("qweight")
                module.linear_weights.pop("qzeros")
                module.linear_weights.pop("scales")
                _qweight = module._parameters["qweight"]
                _qzeros = module._parameters["qzeros"]
                _scales = module._parameters["scales"]
                qweight, scales_zeros = convert_s4(_qweight.data, _qzeros.data,
                                                   _scales.data)
                param_qweight = Parameter(qweight, requires_grad=False)
                param_scales_zeros = Parameter(scales_zeros,
                                               requires_grad=False)
                del _qweight
                del _qzeros
                del _scales
                delattr(module, "qweight")
                delattr(module, "qzeros")
                delattr(module, "scales")
                module.register_parameter("qweight", param_qweight)
                module.register_parameter("scales_zeros", param_scales_zeros)
                set_weight_attrs(
                    param_qweight, {
                        "input_dim": 0,
                        "output_dim": 1,
                        "packed_dim": 1,
                        "pack_factor": quant_config.pack_factor,
                    })
                set_weight_attrs(param_scales_zeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                })
                module.linear_weights["qweight"] = param_qweight
                module.linear_weights["scales_zeros"] = param_scales_zeros
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
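
# Usage sketch (assumed call site, not part of the original module): after the
# float checkpoint is loaded, the weights are converted in place. `model` is a
# placeholder for an already-constructed Aphrodite model.
#
#   quant_config = BitsandBytesConfig(weight_bits=8, group_size=128,
#                                     zero_point=False, from_float=True,
#                                     quant_mode="llm_int8")
#   replace_quant_params(model, quant_config,
#                        modules_to_not_convert="lm_head")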