from typing import Any, Dict, List, NamedTuple, Optional, TypeVar

import torch
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import (ColumnParallelLinear, LinearBase,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig


class AutoQuantConfig(QuantizationConfig):
    """Config class for AutoQuant.

    Reference: https://arxiv.org/abs/2208.07339
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        from_float: bool,
        quant_mode: str,  # llm_int8, smoothquant, weight_only
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.from_float = from_float
        self.quant_mode = quant_mode
        if quant_mode == "weight_only" and self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AutoQuant weight_only, but got {self.weight_bits} bits.")
        if quant_mode in ["llm_int8", "smoothquant"] and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 8-bit weight quantization is supported for "
                "AutoQuant llm_int8 or smoothquant, "
                f"but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AutoQuantConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"from_float={self.from_float}, "
                f"quant_mode={self.quant_mode})")

    def get_name(self) -> str:
        return "autoquant"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    def get_min_capability(self) -> int:
        # The AutoQuant kernels require compute capability 7.5 (Turing)
        # or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AutoQuantConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        # `from_float` and `quant_mode` are optional in checkpoint configs;
        # fall back to the defaults when the keys are missing.
        try:
            from_float = cls.get_from_keys(config, ["from_float"])
        except Exception:
            from_float = False
        try:
            quant_mode = cls.get_from_keys(config, ["quant_mode"])
        except Exception:
            quant_mode = "weight_only"
        return cls(weight_bits, group_size, zero_point, from_float, quant_mode)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AutoQuantLinearMethod"]:
        if isinstance(layer, LinearBase):
            return AutoQuantLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
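

# Example (illustrative, hypothetical config contents): a weight-only
# quant_config.json such as
#     {"w_bit": 4, "q_group_size": 128, "zero_point": true}
# would be parsed by `from_config` into
#     AutoQuantConfig(weight_bits=4, group_size=128, zero_point=True,
#                     from_float=False, quant_mode="weight_only")
# with `from_float` and `quant_mode` taking their defaults because those
# keys are absent.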


class AutoQuantLinearMethod(LinearMethodBase):
    """Linear method for AutoQuant.

    Args:
        quant_config: The AutoQuant quantization config.
    """

    def __init__(self, quant_config: AutoQuantConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if self.quant_config.quant_mode == "weight_only" and \
                input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if self.quant_config.quant_mode == "weight_only" and \
                output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if self.quant_config.quant_mode == "weight_only" and \
                not self.quant_config.from_float:
            qweight = Parameter(
                torch.empty(
                    input_size_per_partition,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            qzeros = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition // self.quant_config.pack_factor,
                    dtype=torch.int32,
                ),
                requires_grad=False,
            )
            set_weight_attrs(
                qzeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            scales = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition,
                    dtype=params_dtype,
                ),
                requires_grad=False,
            )
            set_weight_attrs(scales, {
                "input_dim": 0,
                "output_dim": 1,
            })
            layer.register_parameter("qweight", qweight)
            set_weight_attrs(qweight, extra_weight_attrs)
            layer.register_parameter("qzeros", qzeros)
            set_weight_attrs(qzeros, extra_weight_attrs)
            layer.register_parameter("scales", scales)
            set_weight_attrs(scales, extra_weight_attrs)
        else:
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter("weight", weight)
            set_weight_attrs(weight, extra_weight_attrs)
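
    # Shape summary for the 4-bit weight_only path (pack_factor = 32 // 4 = 8):
    #   qweight: [input_size_per_partition, output_size_per_partition // 8],
    #            int32, eight 4-bit values packed per element
    #   qzeros:  [input_size_per_partition // group_size,
    #             output_size_per_partition // 8], int32
    #   scales:  [input_size_per_partition // group_size,
    #             output_size_per_partition], params_dtype
    # The llm_int8/smoothquant and from_float paths instead create a plain
    # [output, input] weight in params_dtype; replace_quant_params converts it
    # after loading when from_float is set.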

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.quant_config.quant_mode == "weight_only":
            # 4-bit weight-only path: run the W4A16 GEMM on the packed weight
            # and the fused scales_zeros buffer produced by convert_s4.
            qweight = layer.qweight
            scales_zeros = layer.scales_zeros
            pack_factor = self.quant_config.pack_factor
            out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
            reshaped_x = x.reshape(-1, x.shape[-1])
            out = ops.autoquant_s4_f16_gemm(reshaped_x, qweight, scales_zeros)
            if bias is not None:
                out = out + bias
            return out.reshape(out_shape)
        else:
            # llm_int8 / smoothquant path: delegate to the bitsandbytes int8
            # matmul using the layer's MatmulLtState.
            weight = layer.weight
            state = layer.state
            if weight.CB is not None:
                state.CB = weight.CB
                state.SCB = weight.SCB
                weight.CB = None
                weight.SCB = None
            import bitsandbytes as bnb
            out = bnb.matmul(x, weight, bias=bias, state=state)
            if not state.has_fp16_weights and \
                    state.CB is not None and state.CxB is not None:
                # We converted the 8-bit row-major weight to the turing/ampere
                # format in the first inference pass; the row-major copy is no
                # longer needed.
                del state.CB
                weight.data = state.CxB
            return out


T = TypeVar("T", bound="torch.nn.Module")


class QParams(NamedTuple):
    """A class to hold the quantization parameters."""
    scales: torch.Tensor
    zero_points: Optional[torch.Tensor]


@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor,
                                 n_bits: int = 4,
                                 group_size: int = 128):
    """Calculate quantization parameters for each group using min and max
    values."""
    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
    w_group_wise = w.reshape(outc, -1, group_size)
    w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
    w_max = w_group_wise.max(dim=-1, keepdim=True)[0]
    q_max = 2**n_bits - 1
    q_min = 0
    scales = (w_max - w_min)
    scales = scales.clamp_(min=1e-5).div_(q_max)
    # zero_points = (-w_min / scales).round().clamp(q_min, q_max)
    zero_points = (-torch.round(w_min / scales)).clamp_(q_min, q_max)
    return QParams(scales=scales, zero_points=zero_points)
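

# Worked example (illustrative): for a 4-bit group whose values span
# [-0.5, 0.5], q_max = 15, so
#     scale      = (0.5 - (-0.5)) / 15 ≈ 0.0667
#     zero_point = -round(-0.5 / 0.0667) ≈ 8
# i.e. the group is mapped onto the unsigned code range [0, 15] with code 8
# representing (approximately) zero.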


def convert_s4(qw: torch.Tensor,
               qz: torch.Tensor,
               s: torch.Tensor,
               group_size: int = 128):
    # Repack the packed int4 weight, per-group scales and zero points into the
    # layout expected by the s4 GEMM kernel; returns the repacked weight and a
    # fused scales_zeros buffer.
    assert qw.is_contiguous()
    assert qz.is_contiguous()
    assert s.is_contiguous()
    _qw = torch.zeros_like(qw)
    _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
    _ws = torch.zeros_like(s)
    ops.autoquant_convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
                                  qw.size(-1) * 8, qw.size(0), group_size)
    return _qw, _sz


def tp_m_s4(x: torch.Tensor, tp: int = 1):
    # Rearrange a packed s4 tensor for tensor-parallel partitioning.
    return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
                                                        1).contiguous()


def quant(weight: torch.Tensor,
          qparams: Optional[QParams] = None) -> torch.Tensor:
    """Perform fake quantization on the given weight tensor.

    Args:
        weight (torch.Tensor): The weight tensor with shape
            (out_features, in_features).
        qparams (Optional[QParams]): A namedtuple containing 'scales'
            and 'zero_points'.

    Returns:
        torch.Tensor: The quantized weight codes as torch.int32.
    """
    if qparams is None:
        qparams = cal_qparams_per_group_minmax(weight)
    scales = qparams.scales
    zero_points = qparams.zero_points
    out_c, in_c = weight.shape
    # Reshape the weights if using per_group quantization.
    # per tensor scales shape: [1]
    # per channel scales shape: [out_c, 1]
    # per group scales shape: [out_c, in_c//group_size, 1]
    if len(scales.shape) > 2:
        # scales shape: [out_c, in_c//group_size, 1]
        weight = weight.reshape(out_c, scales.shape[1], -1)
    if zero_points is None:
        real_qweight = (weight / scales).round()
    else:
        real_qweight = ((weight + (scales * zero_points)) / scales).round()
    if len(scales.shape) > 2:
        real_qweight = real_qweight.reshape(out_c, in_c)
    return real_qweight.to(torch.int32)
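

# Note: with the zero points produced by cal_qparams_per_group_minmax, the
# expression above reduces to round(weight / scales + zero_points), so each
# group's minimum maps near code 0 and its maximum near code 2**n_bits - 1.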


# core quantization method (simulated quantization)
def quantize_tensor(
    weight,
    n_bits=4,
    group_size=128,
):
    pack_num = 32 // n_bits
    pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
    org_weight_shape = weight.shape
    out_features = org_weight_shape[0]
    in_features = org_weight_shape[1]
    qparams = cal_qparams_per_group_minmax(weight, n_bits, group_size)
    i32_w = quant(weight, qparams)
    i32_w = i32_w.t().contiguous()
    w_pack_oc = out_features // (32 // n_bits)
    w_inc = in_features
    pack_int_w = torch.zeros((w_inc, w_pack_oc),
                             dtype=torch.int32,
                             device=weight.device)
    for col in range(pack_int_w.shape[1]):
        for i in range(pack_num):
            pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
            pack_int_w[:, col] |= pack_int_w_col << (i * n_bits)
    qweight = pack_int_w
    scales = qparams.scales.squeeze(-1).t().contiguous()
    qzeros = None
    if qparams.zero_points is not None:
        zeros = qparams.zero_points.to(torch.int32)
        zeros = zeros.squeeze(-1).t().contiguous()
        z_inc = in_features // group_size
        z_oc = out_features // (32 // n_bits)
        pack_int_zeros = torch.zeros((z_inc, z_oc),
                                     dtype=torch.int32,
                                     device=weight.device)
        for col in range(pack_int_zeros.shape[1]):
            for i in range(pack_num):
                qzero_col = zeros[:, col * pack_num + pack_order[i]]
                pack_int_zeros[:, col] |= qzero_col << (i * n_bits)
        qzeros = pack_int_zeros
    return qweight, scales, qzeros
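

# Usage sketch (mirrors the from_float path in replace_quant_params below);
# `linear` and `x` stand in for an existing fp16 linear layer and its input:
#     w = linear.weight.data.cuda()          # [out_features, in_features]
#     qweight, scales, qzeros = quantize_tensor(w, n_bits=4, group_size=128)
#     qweight, scales_zeros = convert_s4(qweight, qzeros, scales)
#     out = ops.autoquant_s4_f16_gemm(x, qweight, scales_zeros)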


def replace_quant_params(model,
                         quant_config,
                         modules_to_not_convert="lm_head"):
    """Replace the weights of the model's parallel linear layers in place.

    modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
        Name of the module that should not be converted.
        In practice we keep the `lm_head` in full precision
        for numerical stability reasons.
    """
    if not isinstance(modules_to_not_convert, list):
        modules_to_not_convert = [modules_to_not_convert]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_quant_params(module, quant_config, modules_to_not_convert)
        if isinstance(
                module,
                (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear)) \
                and name not in modules_to_not_convert:
            if quant_config.from_float:
                module.linear_weights.pop("weight")
                param = module._parameters["weight"]
                if quant_config.quant_mode in ("llm_int8", "smoothquant"):
                    import bitsandbytes as bnb
                    new_value = bnb.nn.Int8Params(param.data,
                                                  requires_grad=False,
                                                  has_fp16_weights=False)
                    state = bnb.MatmulLtState()
                    if quant_config.quant_mode == "smoothquant":
                        state.threshold = 0.0
                    else:
                        state.threshold = 6.0
                    state.has_fp16_weights = False
                    state.memory_efficient_backward = False
                    state.use_pool = True
                    module._parameters["weight"] = new_value
                    module.linear_weights["weight"] = new_value
                    module.linear_weights["state"] = state
                    set_weight_attrs(
                        new_value, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    del param
                    torch.cuda.empty_cache()
                elif quant_config.quant_mode == "weight_only":
                    data_fp = param.cuda()
                    _qweight, _scales, _qzeros = quantize_tensor(
                        data_fp, n_bits=4, group_size=128)
                    qweight, scales_zeros = convert_s4(_qweight, _qzeros,
                                                       _scales)
                    torch.cuda.synchronize()
                    param_qweight = Parameter(qweight, requires_grad=False)
                    param_scales_zeros = Parameter(scales_zeros,
                                                   requires_grad=False)
                    module.register_parameter("qweight", param_qweight)
                    module.register_parameter("scales_zeros",
                                              param_scales_zeros)
                    set_weight_attrs(
                        param_qweight, {
                            "input_dim": 0,
                            "output_dim": 1,
                            "packed_dim": 1,
                            "pack_factor": quant_config.pack_factor,
                        })
                    set_weight_attrs(param_scales_zeros, {
                        "input_dim": 0,
                        "output_dim": 1,
                    })
                    module.linear_weights["qweight"] = param_qweight
                    module.linear_weights["scales_zeros"] = param_scales_zeros
                    del _qzeros
                    del _scales
                    del param
                    delattr(module, "weight")
                    torch.cuda.empty_cache()
            else:  # load packed int4 weight
                module.linear_weights.pop("qweight")
                module.linear_weights.pop("qzeros")
                module.linear_weights.pop("scales")
                _qweight = module._parameters["qweight"]
                _qzeros = module._parameters["qzeros"]
                _scales = module._parameters["scales"]
                qweight, scales_zeros = convert_s4(_qweight.data,
                                                   _qzeros.data, _scales.data)
                param_qweight = Parameter(qweight, requires_grad=False)
                param_scales_zeros = Parameter(scales_zeros,
                                               requires_grad=False)
                del _qweight
                del _qzeros
                del _scales
                delattr(module, "qweight")
                delattr(module, "qzeros")
                delattr(module, "scales")
                module.register_parameter("qweight", param_qweight)
                module.register_parameter("scales_zeros", param_scales_zeros)
                set_weight_attrs(
                    param_qweight, {
                        "input_dim": 0,
                        "output_dim": 1,
                        "packed_dim": 1,
                        "pack_factor": quant_config.pack_factor,
                    })
                set_weight_attrs(param_scales_zeros, {
                    "input_dim": 0,
                    "output_dim": 1,
                })
                module.linear_weights["qweight"] = param_qweight
                module.linear_weights["scales_zeros"] = param_scales_zeros
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
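

# Usage sketch: after loading an unquantized checkpoint with from_float=True,
# the model's parallel linear layers can be converted in place with, e.g.,
#     replace_quant_params(model, quant_config, modules_to_not_convert="lm_head")
# where `model` is the loaded model and `quant_config` its AutoQuantConfig;
# `lm_head` stays in full precision as noted in the docstring above.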