# bitsandbytes.py

import torch
from torch.nn.parameter import Parameter
from typing import List, Dict, Any, Optional, TypeVar, NamedTuple
from contextlib import suppress

from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs,
                                              ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class BitsandBytesConfig(QuantizationConfig):
    """Config class for BitsandBytes.

    Reference: https://arxiv.org/abs/2208.07339
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        from_float: bool,
        quant_mode: str,  # llm_int8, smoothquant, weight_only
    ) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.from_float = from_float
        self.quant_mode = quant_mode

        if quant_mode == "weight_only" and self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"BNB weight_only, but got {self.weight_bits} bits.")
        if quant_mode in ["llm_int8", "smoothquant"] and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 8-bit weight quantization is supported for "
                "BNB llm_int8 or smoothquant, "
                f"but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"BitsandBytesConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"from_float={self.from_float}, "
                f"quant_mode={self.quant_mode})")

    def get_name(self) -> str:
        return "bitsandbytes"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    def get_min_capability(self) -> int:
        # The BitsandBytes kernels require Turing or newer GPUs
        # (compute capability 7.5+).
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsandBytesConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        try:
            from_float = cls.get_from_keys(config, ["from_float"])
        except Exception:
            from_float = False
        try:
            quant_mode = cls.get_from_keys(config, ["quant_mode"])
        except Exception:
            quant_mode = "weight_only"
        return cls(weight_bits, group_size, zero_point, from_float, quant_mode)

    def get_linear_method(self) -> "BNBLinearMethod":
        return BNBLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
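

# Example (illustrative only): a minimal `quant_config.json` for 4-bit
# weight-only quantization might look like
#     {"bits": 4, "group_size": 128, "zero_point": true}
# and would be loaded as
#     cfg = BitsandBytesConfig.from_config(
#         {"bits": 4, "group_size": 128, "zero_point": True})
#     assert cfg.pack_factor == 8  # eight int4 values packed per int32
# The accepted keys are exactly those listed in `from_config` above.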
class BNBLinearMethod(LinearMethodBase):
    """Linear method for BitsandBytes.

    Args:
        quant_config: The BitsandBytes quantization config.
    """

    def __init__(self, quant_config: BitsandBytesConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if self.quant_config.quant_mode == "weight_only" and \
                input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if self.quant_config.quant_mode == "weight_only" and \
                output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.quant_mode == "weight_only" and \
                not self.quant_config.from_float:
            qweight = Parameter(
                torch.empty(
                    input_size_per_partition,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            qzeros = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qzeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            scales = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition,
                    dtype=params_dtype,
                ),
                requires_grad=False,
            )
            set_weight_attrs(scales, {
                "input_dim": 0,
                "output_dim": 1,
            })
            layer.register_parameter("qweight", qweight)
            set_weight_attrs(qweight, extra_weight_attrs)
            layer.register_parameter("qzeros", qzeros)
            set_weight_attrs(qzeros, extra_weight_attrs)
            layer.register_parameter("scales", scales)
            set_weight_attrs(scales, extra_weight_attrs)
        else:
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter("weight", weight)
            set_weight_attrs(weight, extra_weight_attrs)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.quant_config.quant_mode == "weight_only":
            qweight = layer.qweight
            scales_zeros = layer.scales_zeros
            pack_factor = self.quant_config.pack_factor
            out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
            reshaped_x = x.reshape(-1, x.shape[-1])
            out = ops.autoquant_s4_f16_gemm(reshaped_x, qweight, scales_zeros)
            if bias is not None:
                out = out + bias
            return out.reshape(out_shape)
        else:
            weight = layer.weight
            state = layer.state
            if weight.CB is not None:
                state.CB = weight.CB
                state.SCB = weight.SCB
                weight.CB = None
                weight.SCB = None
            import bitsandbytes as bnb
            out = bnb.matmul(x, weight, bias=bias, state=state)
            if not state.has_fp16_weights and \
                    state.CB is not None and state.CxB is not None:
                # The 8-bit row-major weight was converted to the
                # Turing/Ampere layout during the first inference pass,
                # so the row-major copy is no longer needed.
                del state.CB
                weight.data = state.CxB
            return out
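

# Shape sketch for the weight_only path above (illustrative only): for a
# layer with in_features=K and out_features=N quantized to 4 bits,
#     layer.qweight      : int32, [K, N // 8]  (8 packed int4 values per int32)
#     layer.scales_zeros : produced by convert_s4() below and consumed
#                          directly by ops.autoquant_s4_f16_gemm
#     x                  : [..., K]  ->  out : [..., N]
# The `scales_zeros` attribute is installed by replace_quant_params() at the
# bottom of this file, not by create_weights() alone.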
T = TypeVar("T", bound="torch.nn.Module")


class QParams(NamedTuple):
    """A class to hold the quantization parameters."""

    scales: torch.Tensor
    zero_points: Optional[torch.Tensor]


@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor,
                                 n_bits: int = 4,
                                 group_size: int = 128):
    """Calculate quantization parameters for each group using min and max
    values."""
    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
    w_group_wise = w.reshape(outc, -1, group_size)
    w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
    w_max = w_group_wise.max(dim=-1, keepdim=True)[0]

    q_max = 2**n_bits - 1
    q_min = 0
    scales = (w_max - w_min)
    scales = scales.clamp_(min=1e-5).div_(q_max)
    # zero_points = (-w_min / scales).round().clamp(q_min, q_max)
    zero_points = (-torch.round(w_min / scales)).clamp_(q_min, q_max)
    return QParams(scales=scales, zero_points=zero_points)
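

# Worked example (illustrative only): for a weight of shape [out_c, in_c] =
# [4096, 4096] with n_bits=4 and group_size=128, each row is split into
# 4096 / 128 = 32 groups, so
#     scales      : [4096, 32, 1]  (asymmetric scale (w_max - w_min) / 15)
#     zero_points : [4096, 32, 1]  (integer offsets clamped to [0, 15])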
def convert_s4(qw: torch.Tensor,
               qz: torch.Tensor,
               s: torch.Tensor,
               group_size: int = 128):
    assert qw.is_contiguous()
    assert qz.is_contiguous()
    assert s.is_contiguous()
    _qw = torch.zeros_like(qw)
    _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
    _ws = torch.zeros_like(s)
    ops.autoquant_convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
                                  qw.size(-1) * 8, qw.size(0), group_size)
    return _qw, _sz


def tp_m_s4(x: torch.Tensor, tp: int = 1):
    return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
                                                        1).contiguous()


def quant(weight: torch.Tensor,
          qparams: Optional[QParams] = None) -> torch.Tensor:
    """Perform fake quantization on the given weight tensor.

    Args:
        weight (torch.Tensor): The weight tensor with shape
            (out_features, in_features).
        qparams (Optional[QParams]): A namedtuple containing 'scales'
            and 'zero_points'.

    Returns:
        torch.Tensor: The quantized weight as int32 integer codes
            (not yet bit-packed).
    """
    if qparams is None:
        qparams = cal_qparams_per_group_minmax(weight)
    scales = qparams.scales
    zero_points = qparams.zero_points
    out_c, in_c = weight.shape
    # Reshape the weights if using per_group quantization.
    # per tensor scales shape: [1]
    # per channel scales shape: [out_c, 1]
    # per group scales shape: [out_c, in_c//group_size, 1]
    if len(scales.shape) > 2:
        # scales shape: [out_c, in_c//group_size, 1]
        weight = weight.reshape(out_c, scales.shape[1], -1)
    if zero_points is None:
        real_qweight = (weight / scales).round()
    else:
        real_qweight = ((weight + (scales * zero_points)) / scales).round()
    if len(scales.shape) > 2:
        real_qweight = real_qweight.reshape(out_c, in_c)
    return real_qweight.to(torch.int32)
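

# Sanity-check sketch (illustrative only): the integer codes returned by
# quant() can be dequantized with the same QParams to approximate the
# original weight, e.g.
#     w = torch.randn(256, 256)
#     qp = cal_qparams_per_group_minmax(w, n_bits=4, group_size=128)
#     q = quant(w, qp).reshape(256, -1, 128).float()
#     w_hat = ((q - qp.zero_points) * qp.scales).reshape(256, 256)
#     # w_hat is close to w, up to 4-bit rounding error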
# core quantization method (simulated quantization)
def quantize_tensor(
    weight,
    n_bits=4,
    group_size=128,
):
    """Group-quantize `weight` and pack it into int32 tensors.

    Returns a tuple (qweight, scales, qzeros); see `replace_quant_params`
    for how these are converted with `convert_s4` before use.
    """
    pack_num = 32 // n_bits
    # Interleaved packing order used by the int4 kernel; assumes n_bits == 4.
    pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
    org_weight_shape = weight.shape
    out_features = org_weight_shape[0]
    in_features = org_weight_shape[1]

    qparams = cal_qparams_per_group_minmax(weight, n_bits, group_size)
    i32_w = quant(weight, qparams)
    i32_w = i32_w.t().contiguous()
    w_pack_oc = out_features // (32 // n_bits)
    w_inc = in_features
    pack_int_w = torch.zeros((w_inc, w_pack_oc),
                             dtype=torch.int32,
                             device=weight.device)
    for col in range(pack_int_w.shape[1]):
        for i in range(pack_num):
            pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
            pack_int_w[:, col] |= pack_int_w_col << (i * n_bits)
    qweight = pack_int_w
    scales = qparams.scales.squeeze(-1).t().contiguous()
    if qparams.zero_points is not None:
        zeros = qparams.zero_points.to(torch.int32)
        zeros = zeros.squeeze(-1).t().contiguous()
        z_inc = in_features // group_size
        z_oc = out_features // (32 // n_bits)
        pack_int_zeros = torch.zeros((z_inc, z_oc),
                                     dtype=torch.int32,
                                     device=weight.device)
        for col in range(pack_int_zeros.shape[1]):
            for i in range(pack_num):
                qzero_col = zeros[:, col * pack_num + pack_order[i]]
                pack_int_zeros[:, col] |= qzero_col << (i * n_bits)
        qzeros = pack_int_zeros
    return qweight, scales, qzeros
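

# Usage sketch (illustrative only): packing a [out_features, in_features] =
# [4096, 11008] fp16 weight with n_bits=4, group_size=128 yields
#     qweight : int32, [11008, 4096 // 8]       (input-major, 8 int4 per int32)
#     scales  : fp16,  [11008 // 128, 4096]
#     qzeros  : int32, [11008 // 128, 4096 // 8]
# which matches the qweight/qzeros/scales parameters created by
# BNBLinearMethod.create_weights() for the weight_only path.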
def replace_quant_params(model,
                         quant_config,
                         modules_to_not_convert="lm_head"):
    """Replace the weights of parallel linear modules with quantized params.

    Args:
        modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
            Name of the module to not convert in `Linear8bitLt`.
            In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
    """
    if not isinstance(modules_to_not_convert, list):
        modules_to_not_convert = [modules_to_not_convert]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_quant_params(module, quant_config, modules_to_not_convert)
        if isinstance(
                module,
                (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear)) \
                and name not in modules_to_not_convert:
            if quant_config.from_float:
                module.linear_weights.pop("weight")
                param = module._parameters["weight"]
                if quant_config.quant_mode in ("llm_int8", "smoothquant"):
                    import bitsandbytes as bnb
                    new_value = bnb.nn.Int8Params(param.data,
                                                  requires_grad=False,
                                                  has_fp16_weights=False)
                    state = bnb.MatmulLtState()
                    if quant_config.quant_mode == "smoothquant":
                        state.threshold = 0.0
                    else:
                        state.threshold = 6.0
                    state.has_fp16_weights = False
                    state.memory_efficient_backward = False
                    state.use_pool = True
                    module._parameters["weight"] = new_value
                    module.linear_weights["weight"] = new_value
                    module.linear_weights["state"] = state
                    set_weight_attrs(
                        new_value, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    del param
                    torch.cuda.empty_cache()
                elif quant_config.quant_mode == "weight_only":
                    data_fp = param.cuda()
                    _qweight, _scales, _qzeros = quantize_tensor(
                        data_fp, n_bits=4, group_size=128)
                    qweight, scales_zeros = convert_s4(_qweight, _qzeros,
                                                       _scales)
                    torch.cuda.synchronize()
                    param_qweight = Parameter(qweight, requires_grad=False)
                    param_scales_zeros = Parameter(scales_zeros,
                                                   requires_grad=False)
                    module.register_parameter("qweight", param_qweight)
                    module.register_parameter("scales_zeros",
                                              param_scales_zeros)
                    set_weight_attrs(
                        param_qweight, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    set_weight_attrs(param_scales_zeros, {
                        "input_dim": 0,
                        "output_dim": 1,
                    })
                    module.linear_weights["qweight"] = param_qweight
                    module.linear_weights["scales_zeros"] = param_scales_zeros
                    del _qzeros
                    del _scales
                    del param
                    delattr(module, "weight")
                    torch.cuda.empty_cache()
            else:  # load packed int4 weight
                module.linear_weights.pop("qweight")
                module.linear_weights.pop("qzeros")
                module.linear_weights.pop("scales")
                _qweight = module._parameters["qweight"]
                _qzeros = module._parameters["qzeros"]
                _scales = module._parameters["scales"]
                qweight, scales_zeros = convert_s4(_qweight.data, _qzeros.data,
                                                   _scales.data)
                param_qweight = Parameter(qweight, requires_grad=False)
                param_scales_zeros = Parameter(scales_zeros,
                                               requires_grad=False)
                del _qweight
                del _qzeros
                del _scales
                delattr(module, "qweight")
                delattr(module, "qzeros")
                delattr(module, "scales")
                module.register_parameter("qweight", param_qweight)
                module.register_parameter("scales_zeros", param_scales_zeros)
                set_weight_attrs(
                    param_qweight, {
                        "input_dim": 0,
                        "output_dim": 1,
                        "packed_dim": 1,
                        "pack_factor": quant_config.pack_factor,
                    })
                set_weight_attrs(param_scales_zeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                })
                module.linear_weights["qweight"] = param_qweight
                module.linear_weights["scales_zeros"] = param_scales_zeros
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
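

# End-to-end sketch (illustrative only; assumes a CUDA device and that the
# aphrodite quant kernels are built): after a model's parallel linear layers
# have been created with BNBLinearMethod, an fp16 checkpoint can be converted
# in place with
#     quant_config = BitsandBytesConfig(weight_bits=4, group_size=128,
#                                       zero_point=True, from_float=True,
#                                       quant_mode="weight_only")
#     replace_quant_params(model, quant_config,
#                          modules_to_not_convert="lm_head")
# after which each converted module exposes `qweight` and `scales_zeros`, and
# BNBLinearMethod.apply_weights() dispatches to the int4 GEMM kernel.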