# bitsandbytes.py

from contextlib import suppress
from typing import Any, Dict, List, NamedTuple, Optional, TypeVar

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (ColumnParallelLinear, LinearBase,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class BitsandBytesConfig(QuantizationConfig):
    """Config class for BitsandBytes.

    Reference: https://arxiv.org/abs/2208.07339
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        from_float: bool,
        quant_mode: str,  # llm_int8, smoothquant, weight_only
    ) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.from_float = from_float
        self.quant_mode = quant_mode

        if quant_mode == "weight_only" and self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"BNB weight_only, but got {self.weight_bits} bits.")
        if quant_mode in ["llm_int8", "smoothquant"] and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 8-bit weight quantization is supported for "
                "BNB llm_int8 or smoothquant, "
                f"but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"BitsandBytesConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"from_float={self.from_float}, "
                f"quant_mode={self.quant_mode})")

    def get_name(self) -> str:
        return "bitsandbytes"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    def get_min_capability(self) -> int:
        # The BitsandBytes kernels require Turing (compute capability 7.5)
        # or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsandBytesConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        try:
            from_float = cls.get_from_keys(config, ["from_float"])
        except Exception:
            from_float = False
        try:
            quant_mode = cls.get_from_keys(config, ["quant_mode"])
        except Exception:
            quant_mode = "weight_only"
        return cls(weight_bits, group_size, zero_point, from_float, quant_mode)
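
    # Illustrative only (not part of the original file): the kind of
    # quantize_config.json that `from_config` accepts for the 4-bit
    # weight-only path. Key names mirror the lookups above; the values are
    # assumed for the example.
    #
    #     {
    #         "bits": 4,
    #         "group_size": 128,
    #         "zero_point": true,
    #         "quant_mode": "weight_only"
    #     }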

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["BNBLinearMethod"]:
        if isinstance(layer, LinearBase):
            return BNBLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class BNBLinearMethod(LinearMethodBase):
    """Linear method for BitsandBytes.

    Args:
        quant_config: The BitsandBytes quantization config.
    """

    def __init__(self, quant_config: BitsandBytesConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if self.quant_config.quant_mode == "weight_only" and \
                input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if self.quant_config.quant_mode == "weight_only" and \
                output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.quant_mode == "weight_only" and \
                not self.quant_config.from_float:
            qweight = Parameter(
                torch.empty(
                    input_size_per_partition,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            qzeros = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qzeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            scales = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition,
                    dtype=params_dtype,
                ),
                requires_grad=False,
            )
            set_weight_attrs(scales, {
                "input_dim": 0,
                "output_dim": 1,
            })
            layer.register_parameter("qweight", qweight)
            set_weight_attrs(qweight, extra_weight_attrs)
            layer.register_parameter("qzeros", qzeros)
            set_weight_attrs(qzeros, extra_weight_attrs)
            layer.register_parameter("scales", scales)
            set_weight_attrs(scales, extra_weight_attrs)
        else:
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter("weight", weight)
            set_weight_attrs(weight, extra_weight_attrs)
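
    # Shape summary for the 4-bit weight-only parameters created above (a
    # reading aid, not part of the original file). With pack_factor = 32 // 4
    # = 8 and group_size g:
    #   qweight: (input_size_per_partition, output_size_per_partition // 8), int32
    #   qzeros:  (input_size_per_partition // g, output_size_per_partition // 8), int32
    #   scales:  (input_size_per_partition // g, output_size_per_partition), params_dtype
    # For example, a 4096 x 4096 partition with g = 128 gives qweight
    # (4096, 512), qzeros (32, 512) and scales (32, 4096).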

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.quant_config.quant_mode == "weight_only":
            qweight = layer.qweight
            scales_zeros = layer.scales_zeros
            pack_factor = self.quant_config.pack_factor
            out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
            reshaped_x = x.reshape(-1, x.shape[-1])
            out = ops.autoquant_s4_f16_gemm(reshaped_x, qweight, scales_zeros)
            if bias is not None:
                out = out + bias
            return out.reshape(out_shape)
        else:
            weight = layer.weight
            state = layer.state
            if weight.CB is not None:
                state.CB = weight.CB
                state.SCB = weight.SCB
                weight.CB = None
                weight.SCB = None
            import bitsandbytes as bnb
            out = bnb.matmul(x, weight, bias=bias, state=state)
            if not state.has_fp16_weights and \
                    state.CB is not None and state.CxB is not None:
                # The 8-bit row-major weight was converted to Turing/Ampere
                # format on the first inference pass; the row-major copy is
                # no longer needed.
                del state.CB
                weight.data = state.CxB
            return out
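
# Reading aid (not part of the original file): in the non-weight_only branch
# of `apply`, `weight` is a bitsandbytes Int8Params and `state` a
# bnb.MatmulLtState, both installed by `replace_quant_params` below.
# bnb.matmul performs the LLM.int8() mixed-precision matmul: with quant_mode
# "llm_int8" activation outliers above state.threshold (6.0) are handled in
# fp16, while "smoothquant" sets the threshold to 0.0 so the whole matmul
# stays in int8.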

T = TypeVar("T", bound="torch.nn.Module")


class QParams(NamedTuple):
    """A class to hold the quantization parameters."""

    scales: torch.Tensor
    zero_points: Optional[torch.Tensor]


@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor,
                                 n_bits: int = 4,
                                 group_size: int = 128):
    """Calculate quantization parameters for each group using min and max
    values."""
    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
    w_group_wise = w.reshape(outc, -1, group_size)
    w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
    w_max = w_group_wise.max(dim=-1, keepdim=True)[0]
    q_max = 2**n_bits - 1
    q_min = 0
    scales = (w_max - w_min)
    scales = scales.clamp_(min=1e-5).div_(q_max)
    # zero_points = (-w_min / scales).round().clamp(q_min, q_max)
    zero_points = (-torch.round(w_min / scales)).clamp_(q_min, q_max)
    return QParams(scales=scales, zero_points=zero_points)
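
# Worked example (illustrative, not part of the original file): for a group
# with w_min = -0.30 and w_max = 0.45 at n_bits = 4 (q_max = 15):
#   scale      = (0.45 - (-0.30)) / 15 = 0.05
#   zero_point = clamp(-round(-0.30 / 0.05), 0, 15) = 6
# so a weight of exactly 0.0 maps to the integer code 6.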


def convert_s4(qw: torch.Tensor,
               qz: torch.Tensor,
               s: torch.Tensor,
               group_size: int = 128):
    assert qw.is_contiguous()
    assert qz.is_contiguous()
    assert s.is_contiguous()
    _qw = torch.zeros_like(qw)
    _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
    _ws = torch.zeros_like(s)
    ops.autoquant_convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
                                  qw.size(-1) * 8, qw.size(0), group_size)
    return _qw, _sz


def tp_m_s4(x: torch.Tensor, tp: int = 1):
    return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
                                                        1).contiguous()


def quant(weight: torch.Tensor,
          qparams: Optional[QParams] = None) -> torch.Tensor:
    """Quantize the given weight tensor to integer codes (no bit packing).

    Args:
        weight (torch.Tensor): The weight tensor with shape
            (out_features, in_features).
        qparams (Optional[QParams]): A namedtuple containing 'scales'
            and 'zero_points'.

    Returns:
        torch.Tensor: The quantized weight tensor as int32 codes.
    """
    if qparams is None:
        qparams = cal_qparams_per_group_minmax(weight)
    scales = qparams.scales
    zero_points = qparams.zero_points
    out_c, in_c = weight.shape
    # Reshape the weights if using per_group quantization.
    # per tensor scales shape: [1]
    # per channel scales shape: [out_c, 1]
    # per group scales shape: [out_c, in_c // group_size, 1]
    if len(scales.shape) > 2:
        # scales shape: [out_c, in_c // group_size, 1]
        weight = weight.reshape(out_c, scales.shape[1], -1)
    if zero_points is None:
        real_qweight = (weight / scales).round()
    else:
        real_qweight = ((weight + (scales * zero_points)) / scales).round()
    if len(scales.shape) > 2:
        real_qweight = real_qweight.reshape(out_c, in_c)
    return real_qweight.to(torch.int32)
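
# Note (illustrative, not part of the original file): with the asymmetric
# scheme above, q = round(w / scale + zero_point), so the dequantized value
# is approximately scale * (q - zero_point). Continuing the worked example,
# w = 0.12 with scale = 0.05 and zero_point = 6 gives q = round(2.4 + 6) = 8,
# which dequantizes to 0.05 * (8 - 6) = 0.10.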


# core quantization method (simulated quantization)
def quantize_tensor(
    weight,
    n_bits=4,
    group_size=128,
):
    pack_num = 32 // n_bits
    pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
    org_weight_shape = weight.shape
    out_features = org_weight_shape[0]
    in_features = org_weight_shape[1]
    qparams = cal_qparams_per_group_minmax(weight, n_bits, group_size)
    i32_w = quant(weight, qparams)
    i32_w = i32_w.t().contiguous()
    w_pack_oc = out_features // (32 // n_bits)
    w_inc = in_features
    pack_int_w = torch.zeros((w_inc, w_pack_oc),
                             dtype=torch.int32,
                             device=weight.device)
    for col in range(pack_int_w.shape[1]):
        for i in range(pack_num):
            pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
            pack_int_w[:, col] |= pack_int_w_col << (i * n_bits)
    qweight = pack_int_w
    scales = qparams.scales.squeeze(-1).t().contiguous()
    qzeros = None
    if qparams.zero_points is not None:
        zeros = qparams.zero_points.to(torch.int32)
        zeros = zeros.squeeze(-1).t().contiguous()
        z_inc = in_features // group_size
        z_oc = out_features // (32 // n_bits)
        pack_int_zeros = torch.zeros((z_inc, z_oc),
                                     dtype=torch.int32,
                                     device=weight.device)
        for col in range(pack_int_zeros.shape[1]):
            for i in range(pack_num):
                qzero_col = zeros[:, col * pack_num + pack_order[i]]
                pack_int_zeros[:, col] |= qzero_col << (i * n_bits)
        qzeros = pack_int_zeros
    return qweight, scales, qzeros
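
# Shape sketch (illustrative, not part of the original file): for a weight of
# shape (out_features, in_features) = (11008, 4096) with n_bits = 4 and
# group_size = 128 (pack_num = 8), quantize_tensor returns
#   qweight: (4096, 11008 // 8) = (4096, 1376), int32, eight 4-bit codes per
#            int32 packed in the interleaved pack_order [0, 2, 4, 6, 1, 3, 5, 7]
#   scales:  (4096 // 128, 11008) = (32, 11008)
#   qzeros:  (32, 1376), int32, packed the same way as qweight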


def replace_quant_params(model,
                         quant_config,
                         modules_to_not_convert="lm_head"):
    """
    modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
        Name of the module(s) that should not be converted to the quantized
        format. In practice the `lm_head` is kept in full precision for
        numerical stability reasons.
    """
    if not isinstance(modules_to_not_convert, list):
        modules_to_not_convert = [modules_to_not_convert]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_quant_params(module, quant_config, modules_to_not_convert)
        if isinstance(module, (ColumnParallelLinear, QKVParallelLinear,
                               RowParallelLinear)) \
                and name not in modules_to_not_convert:
            if quant_config.from_float:
                module.linear_weights.pop("weight")
                param = module._parameters["weight"]
                if quant_config.quant_mode in ("llm_int8", "smoothquant"):
                    import bitsandbytes as bnb
                    new_value = bnb.nn.Int8Params(param.data,
                                                  requires_grad=False,
                                                  has_fp16_weights=False)
                    state = bnb.MatmulLtState()
                    if quant_config.quant_mode == "smoothquant":
                        state.threshold = 0.0
                    else:
                        state.threshold = 6.0
                    state.has_fp16_weights = False
                    state.memory_efficient_backward = False
                    state.use_pool = True
                    module._parameters["weight"] = new_value
                    module.linear_weights["weight"] = new_value
                    module.linear_weights["state"] = state
                    set_weight_attrs(
                        new_value, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    del param
                    torch.cuda.empty_cache()
                elif quant_config.quant_mode == "weight_only":
                    data_fp = param.cuda()
                    _qweight, _scales, _qzeros = quantize_tensor(
                        data_fp, n_bits=4, group_size=128)
                    qweight, scales_zeros = convert_s4(_qweight, _qzeros,
                                                       _scales)
                    torch.cuda.synchronize()
                    param_qweight = Parameter(qweight, requires_grad=False)
                    param_scales_zeros = Parameter(scales_zeros,
                                                   requires_grad=False)
                    module.register_parameter("qweight", param_qweight)
                    module.register_parameter("scales_zeros",
                                              param_scales_zeros)
                    set_weight_attrs(
                        param_qweight, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    set_weight_attrs(param_scales_zeros, {
                        "input_dim": 0,
                        "output_dim": 1,
                    })
                    module.linear_weights["qweight"] = param_qweight
                    module.linear_weights["scales_zeros"] = param_scales_zeros
                    del _qzeros
                    del _scales
                    del param
                    delattr(module, "weight")
                    torch.cuda.empty_cache()
            else:  # load packed int4 weight
                module.linear_weights.pop("qweight")
                module.linear_weights.pop("qzeros")
                module.linear_weights.pop("scales")
                _qweight = module._parameters["qweight"]
                _qzeros = module._parameters["qzeros"]
                _scales = module._parameters["scales"]
                qweight, scales_zeros = convert_s4(_qweight.data,
                                                   _qzeros.data,
                                                   _scales.data)
                param_qweight = Parameter(qweight, requires_grad=False)
                param_scales_zeros = Parameter(scales_zeros,
                                               requires_grad=False)
                del _qweight
                del _qzeros
                del _scales
                delattr(module, "qweight")
                delattr(module, "qzeros")
                delattr(module, "scales")
                module.register_parameter("qweight", param_qweight)
                module.register_parameter("scales_zeros", param_scales_zeros)
                set_weight_attrs(
                    param_qweight, {
                        "input_dim": 0,
                        "output_dim": 1,
                        "packed_dim": 1,
                        "pack_factor": quant_config.pack_factor,
                    })
                set_weight_attrs(param_scales_zeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                })
                module.linear_weights["qweight"] = param_qweight
                module.linear_weights["scales_zeros"] = param_scales_zeros
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
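
# Usage sketch (illustrative, not part of the original file): quantizing an
# already loaded fp16 model in place with the 4-bit weight-only mode. `model`
# is assumed to be an Aphrodite model whose linear layers are the parallel
# linear classes imported above.
#
#     quant_config = BitsandBytesConfig(weight_bits=4,
#                                       group_size=128,
#                                       zero_point=True,
#                                       from_float=True,
#                                       quant_mode="weight_only")
#     replace_quant_params(model, quant_config,
#                          modules_to_not_convert="lm_head")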