aqlm.py

# Copyright (c) 2024 Neural Magic, All Rights Reserved.
# AQLM quantization technique, as described in the paper:
# https://arxiv.org/pdf/2401.06118.pdf

import math
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from aphrodite._C import ops
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.modeling.layers.quantization.base_config import (
    QuantizationConfig)


def get_int_dtype(nbits: int) -> torch.dtype:
    if nbits <= 8:
        return torch.int8
    if nbits <= 16:
        return torch.int16
    if nbits <= 32:
        return torch.int32
    if nbits <= 64:
        return torch.int64
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")


@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    return data.to(torch.int64) % (2**nbits)
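

# Illustrative note, not part of the original module: packed codes are stored
# in the narrowest signed integer dtype, so an 8-bit code of 255 is held as
# -1; the modulo above recovers the unsigned codebook index, e.g.
#   unpack_int_data(torch.tensor([-1, 0, 1], dtype=torch.int8), 8)
#   -> tensor([255, 0, 1])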


def dequantize_weight(codes: torch.Tensor,
                      codebooks: torch.Tensor,
                      scales: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Decode float weights from quantization codes. Differentiable.

    :param codes: tensor of integer quantization codes, shape
        [*dims, num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: tensor of vectors for each quantization code, shape
        [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: weight will be multiplied by this factor, must be
        broadcastable with [*dims, num_out_groups, num_in_groups,
        out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape [*dims,
        num_out_groups * out_group_size, num_in_groups * in_group_size]
    """
    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
    (num_codebooks, codebook_size, out_group_size,
     in_group_size) = codebooks.shape
    out_features = num_out_groups * out_group_size
    in_features = num_in_groups * in_group_size
    codebook_offsets = torch.arange(
        0, num_codebooks * codebook_size, codebook_size,
        device=codes.device)  # shape: [num_codebooks]
    reconstructed_weight_flat = F.embedding_bag(
        codes.flatten(0, -2) + codebook_offsets,
        codebooks.flatten(0, 1).flatten(-2, -1),
        mode="sum"
    )  # [prod(dims) * num_out_groups * num_in_groups,
       #  out_group_size * in_group_size]
    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
        list(codes.shape[:-3]) +
        [num_out_groups, num_in_groups, out_group_size, in_group_size])
    if scales is not None:
        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
            scales)
    return reconstructed_weight_groupwise.swapaxes(
        -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
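

# Illustrative sketch, not part of the original module: a shape walk-through
# of `dequantize_weight` using made-up sizes (2 codebooks of 256 entries,
# 1x8 groups). Nothing here reflects a real checkpoint.
def _example_dequantize_weight_shapes() -> None:
    num_codebooks, codebook_size = 2, 256
    out_group_size, in_group_size = 1, 8
    num_out_groups, num_in_groups = 16, 4
    # torch.randint yields int64 codes, which F.embedding_bag expects.
    codes = torch.randint(codebook_size,
                          (num_out_groups, num_in_groups, num_codebooks))
    codebooks = torch.randn(num_codebooks, codebook_size, out_group_size,
                            in_group_size)
    scales = torch.ones(num_out_groups, 1, 1, 1)
    weight = dequantize_weight(codes, codebooks, scales)
    # [num_out_groups * out_group_size, num_in_groups * in_group_size]
    assert weight.shape == (num_out_groups * out_group_size,
                            num_in_groups * in_group_size)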


def dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # codebook_size == 2**nbits, so the code width is bit_length() - 1.
    dequantized_weight = dequantize_weight(
        unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
        codebooks,
        scales,
    )
    return F.linear(input, dequantized_weight, bias)


def dequantize_partioned_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    output_shape = input.shape[:-1] + (scales.shape[0], )
    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
    num_outputs = len(output_partition_sizes)

    # Break the inputs and codebooks apart, then combine the outputs.
    # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
    # multiply at the end.
    num_codebooks = codebooks.shape[0] // num_outputs
    assert scales.shape[0] == codes.shape[0]
    assert sum(output_partition_sizes) == scales.shape[0]

    output_offset = 0
    codebooks_offset = 0
    for output_size in output_partition_sizes:
        shard_output = dequantize_gemm(
            input, codes.narrow(0, output_offset, output_size),
            codebooks.narrow(0, codebooks_offset, num_codebooks),
            scales.narrow(0, output_offset, output_size),
            None if bias is None else bias.narrow(0, output_offset,
                                                  output_size))

        output_slice = output.narrow(-1, output_offset, output_size)
        assert output_slice.shape == shard_output.shape
        output_slice.copy_(shard_output)
        output_offset += output_size
        codebooks_offset += num_codebooks

    return output
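

# Illustrative sketch, not part of the original module: exercises
# dequantize_partioned_gemm with a hypothetical fused projection split into
# three shards. Every size below is invented for illustration only.
def _example_dequantize_partioned_gemm() -> None:
    nbits, num_codebooks, in_group_size = 8, 2, 8
    in_features = 32
    num_in_groups = in_features // in_group_size
    output_partition_sizes = [16, 8, 8]  # e.g. fused Q/K/V shard widths
    total_out = sum(output_partition_sizes)
    # Codes and scales for all shards are stacked along dim 0; each shard
    # also gets its own set of `num_codebooks` codebooks, stacked on dim 0.
    codes = torch.randint(2**nbits,
                          (total_out, num_in_groups, num_codebooks))
    codebooks = torch.randn(num_codebooks * len(output_partition_sizes),
                            2**nbits, 1, in_group_size)
    scales = torch.ones(total_out, 1, 1, 1)
    x = torch.randn(3, in_features)
    out = dequantize_partioned_gemm(x, codes, codebooks, scales,
                                    output_partition_sizes, None)
    assert out.shape == (3, total_out)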


class AQLMConfig(QuantizationConfig):
    """Config class for AQLM.

    Reference: https://github.com/Vahe1994/AQLM
    """

    def __init__(
        self,
        in_group_size: int,
        nbits_per_codebook: int,
        num_codebooks: int,
        out_group_size: int,
    ) -> None:
        self.in_group_size = in_group_size
        self.nbits_per_codebook = nbits_per_codebook
        self.num_codebooks = num_codebooks
        self.out_group_size = out_group_size

        # out_group_size > 1 is untested, and probably won't work as-is.
        assert self.out_group_size == 1
        self.pack_factor = (self.in_group_size * self.out_group_size)

    def __repr__(self) -> str:
        return (f"AQLMConfig(in_group_size={self.in_group_size}, "
                f"nbits_per_codebook={self.nbits_per_codebook}, "
                f"num_codebooks={self.num_codebooks}, "
                f"out_group_size={self.out_group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "aqlm"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
        in_group_size = cls.get_from_keys(config, ["in_group_size"])
        nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
        num_code_books = cls.get_from_keys(config, ["num_codebooks"])
        out_group_size = cls.get_from_keys(config, ["out_group_size"])
        return cls(in_group_size, nbits_per_codebook, num_code_books,
                   out_group_size)

    def get_linear_method(self) -> "AQLMLinearMethod":
        return AQLMLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    def merge_weight(self) -> bool:
        return True

    def rope_style(self) -> Optional[bool]:
        return None
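

# Illustrative note, not part of the original module: a hypothetical
# `quantization_config` dict as it might appear in a checkpoint's config.json,
# fed through `from_config`. The values are made up for illustration.
#   AQLMConfig.from_config({
#       "in_group_size": 8,
#       "nbits_per_codebook": 16,
#       "num_codebooks": 1,
#       "out_group_size": 1,
#   })
#   -> AQLMConfig(in_group_size=8, nbits_per_codebook=16, num_codebooks=1,
#                 out_group_size=1)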


class AQLMLinearMethod(LinearMethodBase):
    """Linear method for AQLM.

    Args:
        quant_config: The AQLM quantization config.
    """

    def __init__(self, quant_config: AQLMConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.
        del input_size  # Unused.

        if params_dtype != torch.half:
            raise ValueError("Only half is currently supported by aqlm")
        if input_size_per_partition % self.quant_config.in_group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.out_group_size != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        codes = Parameter(
            torch.empty(
                # There could actually be two pack factors, one along input
                # and one along output, but we don't currently support
                # out_group_size, and only the one along input needs to be
                # marked with "packed_dim" in order for QKVLinear to work.
                output_size_per_partition,
                input_size_per_partition // self.quant_config.pack_factor,
                self.quant_config.num_codebooks,
                dtype=get_int_dtype(self.quant_config.nbits_per_codebook),
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            codes,
            {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            },
        )

        codebooks = Parameter(
            torch.empty(
                self.quant_config.num_codebooks * len(output_partition_sizes),
                2**self.quant_config.nbits_per_codebook,
                self.quant_config.out_group_size,
                self.quant_config.in_group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            codebooks,
            {
                # metadata indicates fixed size concatenated along dim 0
                "is_metadata": True,
                "output_partition_sizes":
                torch.tensor(output_partition_sizes, device='cpu'),
            },
        )

        scales = Parameter(
            torch.empty(
                (
                    output_size_per_partition //
                    self.quant_config.out_group_size,
                    1,
                    1,
                    1,
                ),
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "output_dim": 0,
                "packed_dim": 0,
                "pack_factor": self.quant_config.out_group_size,
            },
        )

        return {
            "codes": codes,
            "codebooks": codebooks,
            "scales": scales,
        }
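
    # Illustrative note, not part of the original module: for a hypothetical
    # "1x16" AQLM config (num_codebooks=1, nbits_per_codebook=16,
    # in_group_size=8, out_group_size=1) and a fused QKV layer with
    # input_size_per_partition=4096 and
    # output_partition_sizes=[4096, 4096, 4096], create_weights allocates:
    #   codes:     [12288, 512, 1]   torch.int16 (4096 // 8 packed inputs)
    #   codebooks: [3, 65536, 1, 8]  one codebook set per shard
    #   scales:    [12288, 1, 1, 1]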

    def apply_weights(
        self,
        weights: Dict[str, Any],
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        codebooks = weights["codebooks"]
        codes = weights["codes"]
        scales = weights["scales"]
        output_partition_sizes = getattr(codebooks, "output_partition_sizes",
                                         None)

        # Use the custom AQLM kernel for small batches (GEMV-like workloads)
        # or when no partition metadata is available; otherwise dequantize
        # each shard and fall back to a regular matmul.
        use_gemv = math.prod(
            x.shape[:-1]) <= 32 or output_partition_sizes is None

        output = ops.aqlm_gemm(
            x,
            codes,
            codebooks,
            scales,
            output_partition_sizes,
            bias,
        ) if use_gemv else dequantize_partioned_gemm(
            x,
            codes,
            codebooks,
            scales,
            output_partition_sizes,
            bias,
        )

        return output