bitsandbytes.py

from contextlib import suppress
from typing import Any, Dict, List, NamedTuple, Optional, TypeVar

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (ColumnParallelLinear, LinearBase,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class BitsandBytesConfig(QuantizationConfig):
    """Config class for BitsandBytes.

    Reference: https://arxiv.org/abs/2208.07339
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        from_float: bool,
        quant_mode: str,  # llm_int8, smoothquant, weight_only
    ) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.from_float = from_float
        self.quant_mode = quant_mode

        if quant_mode == "weight_only" and self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"BNB weight_only, but got {self.weight_bits} bits.")
        if quant_mode in ["llm_int8", "smoothquant"] and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 8-bit weight quantization is supported for "
                "BNB llm_int8 or smoothquant, "
                f"but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"BitsandBytesConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"from_float={self.from_float}, "
                f"quant_mode={self.quant_mode})")

    def get_name(self) -> str:
        return "bitsandbytes"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]
    def get_min_capability(self) -> int:
        # The BitsandBytes kernels require Turing (compute capability 7.5)
        # or newer GPUs.
        return 75
    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsandBytesConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        try:
            from_float = cls.get_from_keys(config, ["from_float"])
        except Exception:
            from_float = False
        try:
            quant_mode = cls.get_from_keys(config, ["quant_mode"])
        except Exception:
            quant_mode = "weight_only"
        return cls(weight_bits, group_size, zero_point, from_float, quant_mode)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["BNBLinearMethod"]:
        if isinstance(layer, LinearBase):
            return BNBLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class BNBLinearMethod(LinearMethodBase):
    """Linear method for BitsandBytes.

    Args:
        quant_config: The BitsandBytes quantization config.
    """

    def __init__(self, quant_config: BitsandBytesConfig):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if self.quant_config.quant_mode == "weight_only" and \
                input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if self.quant_config.quant_mode == "weight_only" and \
                output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if self.quant_config.quant_mode == "weight_only" and \
                not self.quant_config.from_float:
            qweight = Parameter(
                torch.empty(
                    input_size_per_partition,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            qzeros = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qzeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            scales = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition,
                    dtype=params_dtype,
                ),
                requires_grad=False,
            )
            set_weight_attrs(scales, {
                "input_dim": 0,
                "output_dim": 1,
            })
            layer.register_parameter("qweight", qweight)
            set_weight_attrs(qweight, extra_weight_attrs)
            layer.register_parameter("qzeros", qzeros)
            set_weight_attrs(qzeros, extra_weight_attrs)
            layer.register_parameter("scales", scales)
            set_weight_attrs(scales, extra_weight_attrs)
        else:
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter("weight", weight)
            set_weight_attrs(weight, extra_weight_attrs)
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.quant_config.quant_mode == "weight_only":
            qweight = layer.qweight
            scales_zeros = layer.scales_zeros
            pack_factor = self.quant_config.pack_factor
            out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
            reshaped_x = x.reshape(-1, x.shape[-1])
            out = ops.autoquant_s4_f16_gemm(reshaped_x, qweight, scales_zeros)
            if bias is not None:
                out = out + bias
            return out.reshape(out_shape)
        else:
            weight = layer.weight
            state = layer.state
            if weight.CB is not None:
                state.CB = weight.CB
                state.SCB = weight.SCB
                weight.CB = None
                weight.SCB = None
            import bitsandbytes as bnb
            out = bnb.matmul(x, weight, bias=bias, state=state)
            if not state.has_fp16_weights and \
                    state.CB is not None and state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format
                # in the first inference pass
                # we no longer need the row-major weight
                del state.CB
                weight.data = state.CxB
            return out
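
# Shape sketch for the weight_only path above (illustrative, assuming 4-bit
# weights, i.e. pack_factor == 8): layer.qweight is an int32 tensor of shape
# [in_features, out_features // 8] with eight 4-bit values packed per int32,
# so for an input x of shape [..., in_features] the GEMM output is reshaped
# back to x.shape[:-1] + (qweight.shape[-1] * 8,) == [..., out_features].
# The llm_int8/smoothquant branch instead defers to bitsandbytes.matmul and
# caches the Turing/Ampere-format weight in state.CxB after the first pass.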


T = TypeVar("T", bound="torch.nn.Module")


class QParams(NamedTuple):
    """A class to hold the quantization parameters."""

    scales: torch.Tensor
    zero_points: Optional[torch.Tensor]


@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor,
                                 n_bits: int = 4,
                                 group_size: int = 128):
    """Calculate quantization parameters for each group using min and max
    values."""
    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
    w_group_wise = w.reshape(outc, -1, group_size)
    w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
    w_max = w_group_wise.max(dim=-1, keepdim=True)[0]
    q_max = 2**n_bits - 1
    q_min = 0
    scales = (w_max - w_min)
    scales = scales.clamp_(min=1e-5).div_(q_max)
    # zero_points = (-w_min / scales).round().clamp(q_min, q_max)
    zero_points = (-torch.round(w_min / scales)).clamp_(q_min, q_max)
    return QParams(scales=scales, zero_points=zero_points)
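
# Worked example (illustrative numbers): with n_bits=4 each group maps onto
# the integer range [0, 15]. If one group of 128 weights spans
# w_min = -0.30 and w_max = 0.45, then
#   scale      = (0.45 - (-0.30)) / 15 = 0.05
#   zero_point = clamp(-round(-0.30 / 0.05), 0, 15) = 6
# so a weight w quantizes (see `quant` below) to round(w / 0.05 + 6),
# giving 0 for w_min and 15 for w_max.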


def convert_s4(qw: torch.Tensor,
               qz: torch.Tensor,
               s: torch.Tensor,
               group_size: int = 128):
    assert qw.is_contiguous()
    assert qz.is_contiguous()
    assert s.is_contiguous()
    _qw = torch.zeros_like(qw)
    _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
    _ws = torch.zeros_like(s)
    ops.autoquant_convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
                                  qw.size(-1) * 8, qw.size(0), group_size)
    return _qw, _sz
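
# Note on convert_s4 (hedged; the kernel itself is opaque from this file):
# as used below, autoquant_convert_s4_k_m8 re-packs the int4 qweight into the
# layout expected by autoquant_s4_f16_gemm and fuses scales and zero-points
# into the single int32-typed "scales_zeros" tensor (two fp16 values per
# element, hence the "half2" comment). Only the returned (_qw, _sz) pair is
# consumed by the rest of this module.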


def tp_m_s4(x: torch.Tensor, tp: int = 1):
    # Reorder a packed s4 weight tensor into a tensor-parallel-friendly
    # layout (helper; not called elsewhere in this file).
    return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
                                                        1).contiguous()


def quant(weight: torch.Tensor,
          qparams: Optional[QParams] = None) -> torch.Tensor:
    """Perform fake quantization on the given weight tensor.

    Args:
        weight (torch.Tensor): The weight tensor with shape
            (out_features, in_features).
        qparams (Optional[QParams]): A namedtuple containing 'scales'
            and 'zero_points'.

    Returns:
        torch.Tensor: The fake quantized weight tensor.
    """
    if qparams is None:
        qparams = cal_qparams_per_group_minmax(weight)
    scales = qparams.scales
    zero_points = qparams.zero_points
    out_c, in_c = weight.shape
    # Reshape the weights if using per_group quantization
    # per tensor scales shape: [1]
    # per channel scales shape: [out_c, 1]
    # per group scales shape: [out_c, in_c//group_size, 1]
    if len(scales.shape) > 2:
        # scales shape: [out_c, in_c//group_size, 1]
        weight = weight.reshape(out_c, scales.shape[1], -1)
    if zero_points is None:
        real_qweight = (weight / scales).round()
    else:
        real_qweight = ((weight + (scales * zero_points)) / scales).round()
    if len(scales.shape) > 2:
        real_qweight = real_qweight.reshape(out_c, in_c)
    return real_qweight.to(torch.int32)
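
# Round-trip sketch (illustrative): for the asymmetric scheme above, the
# approximate dequantization of an integer q is
#   w_hat = (q - zero_point) * scale
# e.g. with scale = 0.05 and zero_point = 6 (the worked example after
# cal_qparams_per_group_minmax), q = 15 recovers w_hat = 0.45.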


# core quantization method (simulated quantization)
def quantize_tensor(
    weight,
    n_bits=4,
    group_size=128,
):
    pack_num = 32 // n_bits
    pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
    org_weight_shape = weight.shape
    out_features = org_weight_shape[0]
    in_features = org_weight_shape[1]
    qparams = cal_qparams_per_group_minmax(weight, n_bits)
    i32_w = quant(weight, qparams)
    i32_w = i32_w.t().contiguous()
    w_pack_oc = out_features // (32 // n_bits)
    w_inc = in_features
    pack_int_w = torch.zeros((w_inc, w_pack_oc),
                             dtype=torch.int32,
                             device=weight.device)
    for col in range(pack_int_w.shape[1]):
        for i in range(pack_num):
            pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
            pack_int_w[:, col] |= pack_int_w_col << (i * n_bits)
    qweight = pack_int_w
    scales = qparams.scales.squeeze(-1).t().contiguous()
    # Default to None so the return below is always well-defined
    # (cal_qparams_per_group_minmax always provides zero-points).
    qzeros = None
    if qparams.zero_points is not None:
        zeros = qparams.zero_points.to(torch.int32)
        zeros = zeros.squeeze(-1).t().contiguous()
        z_inc = in_features // group_size
        z_oc = out_features // (32 // n_bits)
        pack_int_zeros = torch.zeros((z_inc, z_oc),
                                     dtype=torch.int32,
                                     device=weight.device)
        for col in range(pack_int_zeros.shape[1]):
            for i in range(pack_num):
                qzero_col = zeros[:, col * pack_num + pack_order[i]]
                pack_int_zeros[:, col] |= qzero_col << (i * n_bits)
        qzeros = pack_int_zeros
    return qweight, scales, qzeros
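
# Packing sketch (illustrative): with n_bits=4, eight 4-bit values share one
# int32 and pack_order = [0, 2, 4, 6, 1, 3, 5, 7] interleaves even columns
# before odd ones. For output columns c0..c7 of i32_w the packed word is
#   word = c0 | (c2 << 4) | (c4 << 8) | (c6 << 12)
#        | (c1 << 16) | (c3 << 20) | (c5 << 24) | (c7 << 28)
# presumably matching the nibble layout the s4 kernels above expect.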


def replace_quant_params(model,
                         quant_config,
                         modules_to_not_convert="lm_head"):
    """Swap the weights of parallel linear layers for quantized ones in place.

    modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
        Name of the module(s) to leave unconverted (i.e. not replaced with
        `Linear8bitLt`/packed int4 weights). In practice we keep `lm_head`
        in full precision for numerical stability reasons.
    """
    if not isinstance(modules_to_not_convert, list):
        modules_to_not_convert = [modules_to_not_convert]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_quant_params(module, quant_config, modules_to_not_convert)
        if isinstance(
                module,
            (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear)) \
                and name not in modules_to_not_convert:
            if quant_config.from_float:
                module.linear_weights.pop("weight")
                param = module._parameters["weight"]
                if quant_config.quant_mode in ("llm_int8", "smoothquant"):
                    import bitsandbytes as bnb
                    new_value = bnb.nn.Int8Params(param.data,
                                                  requires_grad=False,
                                                  has_fp16_weights=False)
                    state = bnb.MatmulLtState()
                    if quant_config.quant_mode == "smoothquant":
                        state.threshold = 0.0
                    else:
                        state.threshold = 6.0
                    state.has_fp16_weights = False
                    state.memory_efficient_backward = False
                    state.use_pool = True
                    module._parameters["weight"] = new_value
                    module.linear_weights["weight"] = new_value
                    module.linear_weights["state"] = state
                    set_weight_attrs(
                        new_value, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    del param
                    torch.cuda.empty_cache()
                elif quant_config.quant_mode == "weight_only":
                    data_fp = param.cuda()
                    _qweight, _scales, _qzeros = quantize_tensor(
                        data_fp, n_bits=4, group_size=128)
                    qweight, scales_zeros = convert_s4(_qweight, _qzeros,
                                                       _scales)
                    torch.cuda.synchronize()
                    param_qweight = Parameter(qweight, requires_grad=False)
                    param_scales_zeros = Parameter(scales_zeros,
                                                   requires_grad=False)
                    module.register_parameter("qweight", param_qweight)
                    module.register_parameter("scales_zeros",
                                              param_scales_zeros)
                    set_weight_attrs(
                        param_qweight, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    set_weight_attrs(param_scales_zeros, {
                        "input_dim": 0,
                        "output_dim": 1,
                    })
                    module.linear_weights["qweight"] = param_qweight
                    module.linear_weights["scales_zeros"] = param_scales_zeros
                    del _qzeros
                    del _scales
                    del param
                    delattr(module, "weight")
                    torch.cuda.empty_cache()
            else:  # load packed int4 weight
                module.linear_weights.pop("qweight")
                module.linear_weights.pop("qzeros")
                module.linear_weights.pop("scales")
                _qweight = module._parameters["qweight"]
                _qzeros = module._parameters["qzeros"]
                _scales = module._parameters["scales"]
                qweight, scales_zeros = convert_s4(_qweight.data, _qzeros.data,
                                                   _scales.data)
                param_qweight = Parameter(qweight, requires_grad=False)
                param_scales_zeros = Parameter(scales_zeros,
                                               requires_grad=False)
                del _qweight
                del _qzeros
                del _scales
                delattr(module, "qweight")
                delattr(module, "qzeros")
                delattr(module, "scales")
                module.register_parameter("qweight", param_qweight)
                module.register_parameter("scales_zeros", param_scales_zeros)
                set_weight_attrs(
                    param_qweight, {
                        "input_dim": 0,
                        "output_dim": 1,
                        "packed_dim": 1,
                        "pack_factor": quant_config.pack_factor,
                    })
                set_weight_attrs(param_scales_zeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                })
                module.linear_weights["qweight"] = param_qweight
                module.linear_weights["scales_zeros"] = param_scales_zeros
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
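
# Usage sketch (illustrative, not part of the module API; `model` is assumed
# to be an already-constructed Aphrodite model whose parallel linear layers
# expose `linear_weights`):
#
#   quant_config = BitsandBytesConfig(weight_bits=8, group_size=128,
#                                     zero_point=False, from_float=True,
#                                     quant_mode="llm_int8")
#   replace_quant_params(model, quant_config,
#                        modules_to_not_convert="lm_head")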