# Copyright (c) 2024 Neural Magic, All Rights Reserved.
# AQLM quantization technique, as described in the paper:
# https://arxiv.org/pdf/2401.06118.pdf

import math
from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


def get_int_dtype(nbits: int) -> torch.dtype:
    if nbits <= 8:
        return torch.int8
    if nbits <= 16:
        return torch.int16
    if nbits <= 32:
        return torch.int32
    if nbits <= 64:
        return torch.int64
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")


@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    return data.to(torch.int64) % (2**nbits)
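
# The packed codes are stored in the narrowest signed integer dtype that fits
# nbits_per_codebook (see get_int_dtype above); unpack_int_data() reinterprets
# them as unsigned codebook indices, e.g. an int8 code of -3 with nbits=8 maps
# to index 253.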


def dequantize_weight(codes: torch.Tensor,
                      codebooks: torch.Tensor,
                      scales: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Decode float weights from quantization codes. Differentiable.

    :param codes: tensor of integer quantization codes, shape [*dims,
        num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: tensor of vectors for each quantization code,
        [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: weight will be multiplied by this factor, must be
        broadcastable with [*dims, num_out_groups, num_in_groups,
        out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape [*dims,
        num_out_groups * out_group_size, num_in_groups * in_group_size]
    """
    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
    (num_codebooks, codebook_size, out_group_size,
     in_group_size) = codebooks.shape
    out_features = num_out_groups * out_group_size
    in_features = num_in_groups * in_group_size
    codebook_offsets = torch.arange(
        0, num_codebooks * codebook_size, codebook_size,
        device=codes.device)  # shape: [num_codebooks]
    reconstructed_weight_flat = F.embedding_bag(
        codes.flatten(0, -2) + codebook_offsets,
        codebooks.flatten(0, 1).flatten(-2, -1),
        mode="sum",
    )  # [prod(dims) * num_out_groups * num_in_groups,
    #    out_group_size * in_group_size]
    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
        list(codes.shape[:-3]) +
        [num_out_groups, num_in_groups, out_group_size, in_group_size])
    if scales is not None:
        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
            scales)
    return reconstructed_weight_groupwise.swapaxes(
        -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
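
# Worked shape example (illustrative numbers, not from a real checkpoint):
# with num_codebooks=1, codebook_size=2**16, out_group_size=1 and
# in_group_size=8, a [4096, 4096] weight is stored as `codes` of shape
# [4096, 512, 1] plus `codebooks` of shape [1, 65536, 1, 8];
# dequantize_weight() rebuilds the dense [4096, 4096] matrix from them.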


def dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    # codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    dequantized_weight = dequantize_weight(
        unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
        codebooks,
        scales,
    )
    return F.linear(input, dequantized_weight, bias)


def dequantize_partioned_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    # codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    output_shape = input.shape[:-1] + (scales.shape[0], )
    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
    num_outputs = len(output_partition_sizes)

    # Break the inputs and codebooks apart, then combine the outputs.
    # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
    # multiply at the end.
    num_codebooks = codebooks.shape[0] // num_outputs
    assert (scales.shape[0] == codes.shape[0])
    assert (sum(output_partition_sizes) == scales.shape[0])

    output_offset = 0
    codebooks_offset = 0
    for output_size in output_partition_sizes:
        shard_output = dequantize_gemm(
            input, codes.narrow(0, output_offset, output_size),
            codebooks.narrow(0, codebooks_offset, num_codebooks),
            scales.narrow(0, output_offset, output_size),
            None if bias is None else bias.narrow(0, output_offset,
                                                  output_size))

        output_slice = output.narrow(-1, output_offset, output_size)
        assert (output_slice.shape == shard_output.shape)
        output_slice.copy_(shard_output)
        output_offset += output_size
        codebooks_offset += num_codebooks
    return output
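
# Example of how the partitioning is used (illustrative): for a fused QKV
# projection, output_partition_sizes would be [q_size, k_size, v_size]; each
# shard is dequantized from its own contiguous block of `num_codebooks`
# codebooks along dim 0 and written into the matching slice of the output.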


class AQLMConfig(QuantizationConfig):
    """Config class for AQLM.

    Reference: https://github.com/Vahe1994/AQLM
    """

    def __init__(
        self,
        in_group_size: int,
        nbits_per_codebook: int,
        num_codebooks: int,
        out_group_size: int,
    ) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.in_group_size = in_group_size
        self.nbits_per_codebook = nbits_per_codebook
        self.num_codebooks = num_codebooks
        self.out_group_size = out_group_size

        # out_group_size > 1 is untested, and probably won't work as-is.
        assert self.out_group_size == 1
        self.pack_factor = (self.in_group_size * self.out_group_size)

    def __repr__(self) -> str:
        return (f"AQLMConfig(in_group_size={self.in_group_size}, "
                f"nbits_per_codebook={self.nbits_per_codebook}, "
                f"num_codebooks={self.num_codebooks}, "
                f"out_group_size={self.out_group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "aqlm"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
        in_group_size = cls.get_from_keys(config, ["in_group_size"])
        nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
        num_codebooks = cls.get_from_keys(config, ["num_codebooks"])
        out_group_size = cls.get_from_keys(config, ["out_group_size"])
        return cls(in_group_size, nbits_per_codebook, num_codebooks,
                   out_group_size)

    def get_linear_method(self) -> "AQLMLinearMethod":
        return AQLMLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    def merge_weight(self) -> bool:
        return True

    def rope_style(self) -> Optional[bool]:
        return None

    def quant_vocab(self) -> List[bool]:
        return [False, False]

    def support_fused_moe(self) -> bool:
        return False
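
# Example of the checkpoint quantization config consumed by
# AQLMConfig.from_config(). The values shown correspond to the common 2-bit
# "1x16" AQLM setting and are illustrative; actual checkpoints may differ:
#
#   {
#       "in_group_size": 8,
#       "nbits_per_codebook": 16,
#       "num_codebooks": 1,
#       "out_group_size": 1,
#   }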


class AQLMLinearMethod(LinearMethodBase):
    """Linear method for AQLM.

    Args:
        quant_config: The AQLM quantization config.
    """

    def __init__(self, quant_config: AQLMConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        del input_size  # Unused.

        if params_dtype != torch.half:
            raise ValueError("Only half is currently supported by aqlm")
        if input_size_per_partition % self.quant_config.in_group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.out_group_size != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        codes = Parameter(
            torch.empty(
                # There could actually be two pack factors, one along the
                # input dim and one along the output dim, but we don't
                # currently support out_group_size, and only the input one
                # needs to be marked with "packed_dim" for QKVLinear to work.
                output_size_per_partition,
                input_size_per_partition // self.quant_config.pack_factor,
                self.quant_config.num_codebooks,
                dtype=get_int_dtype(self.quant_config.nbits_per_codebook),
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            codes,
            {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            },
        )

        codebooks = Parameter(
            torch.empty(
                self.quant_config.num_codebooks * len(output_partition_sizes),
                2**self.quant_config.nbits_per_codebook,
                self.quant_config.out_group_size,
                self.quant_config.in_group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            codebooks,
            {
                # metadata indicates fixed-size tensors concatenated along
                # dim 0.
                "is_metadata": True,
                "output_partition_sizes":
                torch.tensor(output_partition_sizes, device='cpu'),
            },
        )

        scales = Parameter(
            torch.empty(
                (
                    output_size_per_partition //
                    self.quant_config.out_group_size,
                    1,
                    1,
                    1,
                ),
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "output_dim": 0,
                "packed_dim": 0,
                "pack_factor": self.quant_config.out_group_size,
            },
        )

        layer.register_parameter("codes", codes)
        set_weight_attrs(codes, extra_weight_attrs)
        layer.register_parameter("codebooks", codebooks)
        set_weight_attrs(codebooks, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        codebooks = layer.codebooks
        codes = layer.codes
        scales = layer.scales
        output_partition_sizes = getattr(codebooks, "output_partition_sizes",
                                         None)

        use_gemv = (math.prod(x.shape[:-1]) <= 32
                    or output_partition_sizes is None)

        if use_gemv:
            output = ops.aqlm_gemm(
                x,
                codes,
                codebooks,
                scales,
                output_partition_sizes,
                bias,
            )
        else:
            output = dequantize_partioned_gemm(
                x,
                codes,
                codebooks,
                scales,
                output_partition_sizes,
                bias,
            )
        return output
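
    # Dispatch note (illustrative): when the flattened token count
    # math.prod(x.shape[:-1]) is at most 32 (e.g. small-batch decode), the
    # custom aqlm_gemm kernel above is used; larger inputs such as long
    # prefills fall back to dequantize_partioned_gemm(), i.e. full weight
    # dequantization followed by F.linear.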

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        raise NotImplementedError
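

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the original module: the
    # shapes, dtypes, and values below are illustrative assumptions only.
    # It exercises the pure-PyTorch dequantization path (no custom kernels).
    with torch.inference_mode():
        num_codebooks, codebook_size = 2, 256  # i.e. nbits_per_codebook = 8
        out_group_size, in_group_size = 1, 8
        out_features, in_features = 64, 128
        # Codes are stored as signed int8; unpack_int_data() reinterprets
        # them as unsigned codebook indices.
        codes = torch.randint(
            -128, 128,
            (out_features // out_group_size, in_features // in_group_size,
             num_codebooks),
            dtype=get_int_dtype(8),
        )
        codebooks = torch.randn(num_codebooks, codebook_size, out_group_size,
                                in_group_size)
        scales = torch.ones(out_features // out_group_size, 1, 1, 1)
        x = torch.randn(4, in_features)
        y = dequantize_gemm(x, codes, codebooks, scales, bias=None)
        assert y.shape == (4, out_features)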