aqlm.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. # Supports AQLM compression, see https://github.com/Vahe1994/AQLM
  2. # and https://arxiv.org/pdf/2401.06118.pdf
  3. import math
  4. from contextlib import suppress
  5. from typing import Any, Dict, List, Optional
  6. import torch
  7. import torch.nn.functional as F
  8. from torch.nn.parameter import Parameter
  9. from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
  10. from aphrodite.quantization.base_config import \
  11. QuantizationConfig
  12. HAS_QUANTS = False
  13. with suppress(ImportError):
  14. from aphrodite._quant_C import quant_ops as ops
  15. HAS_QUANTS = True
  16. def get_int_dtype(nbits: int) -> torch.dtype:
  17. if nbits <= 8:
  18. return torch.int8
  19. if nbits <= 16:
  20. return torch.int16
  21. if nbits <= 32:
  22. return torch.int32
  23. if nbits <= 64:
  24. return torch.int64
  25. raise ValueError(f"No dtype available for {nbits}-bit codebooks")
  26. @torch.inference_mode()
  27. def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
  28. return data.to(torch.int64) % (2**nbits)
  29. def dequantize_weight(codes: torch.Tensor,
  30. codebooks: torch.Tensor,
  31. scales: Optional[torch.Tensor] = None) -> torch.Tensor:
  32. """
  33. Decode float weights from quantization codes. Differentiable.
  34. :param codes: tensor of integer quantization codes, shape
  35. [*dims, num_out_groups, num_in_groups, num_codebooks]
  36. :param codebooks: tensor of vectors for each quantization code,
  37. [num_codebooks, codebook_size, out_group_size, in_group_size]
  38. :param scales: weight will be multiplied by this factor, must be
  39. broadcastble with
  40. [*dims, out_groups, num_in_groups, out_group_size, in_group_size]
  41. :return: reconstructed weight tensor of shape
  42. [*dims, num_in_groups*group_size]
  43. """
  44. num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
  45. num_codebooks, codebook_size, out_group_size, in_group_size = \
  46. codebooks.shape
  47. out_features = num_out_groups * out_group_size
  48. in_features = num_in_groups * in_group_size
  49. codebook_offsets = torch.arange(
  50. 0, num_codebooks * codebook_size, codebook_size,
  51. device=codes.device) # shape: [num_codebooks]
  52. reconstructed_weight_flat = F.embedding_bag(
  53. codes.flatten(0, -2) + codebook_offsets,
  54. codebooks.flatten(0, 1).flatten(-2, -1),
  55. mode="sum"
  56. ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size
  57. # * in_group_size]
  58. reconstructed_weight_groupwise = reconstructed_weight_flat.view(
  59. list(codes.shape[:-3]) +
  60. [num_out_groups, num_in_groups, out_group_size, in_group_size])
  61. if scales is not None:
  62. reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
  63. scales)
  64. return reconstructed_weight_groupwise.swapaxes(
  65. -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
  66. def dequantize_gemm(
  67. input: torch.Tensor, # [..., in_features]
  68. codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
  69. codebooks: torch.
  70. Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
  71. scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
  72. bias: Optional[torch.Tensor],
  73. ) -> torch.Tensor:
  74. dequantized_weight = dequantize_weight(
  75. unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
  76. codebooks,
  77. scales,
  78. )
  79. return F.linear(input, dequantized_weight, bias)
  80. # Generic dequantization, slow but flexible.
  81. def generic_dequantize_gemm(
  82. input: torch.Tensor, # [..., in_features]
  83. codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
  84. codebooks: torch.
  85. Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
  86. scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
  87. output_partition_sizes: torch.IntTensor,
  88. bias: Optional[torch.Tensor],
  89. ) -> torch.Tensor:
  90. output_shape = input.shape[:-1] + (scales.shape[0], )
  91. output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
  92. num_outputs = len(output_partition_sizes)
  93. # break the inputs and codebooks apart then combine the outputs.
  94. # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
  95. # multiply at the end.
  96. num_codebooks = codebooks.shape[0] // num_outputs
  97. assert (scales.shape[0] == codes.shape[0])
  98. assert (sum(output_partition_sizes) == scales.shape[0])
  99. output_offset = 0
  100. codebooks_offset = 0
  101. for output_size in output_partition_sizes:
  102. shard_output = dequantize_gemm(
  103. input, codes.narrow(0, output_offset, output_size),
  104. codebooks.narrow(0, codebooks_offset, num_codebooks),
  105. scales.narrow(0, output_offset, output_size), None
  106. if bias is None else bias.narrow(0, output_offset, output_size))
  107. output_slice = output.narrow(-1, output_offset, output_size)
  108. assert (output_slice.shape == shard_output.shape)
  109. output_slice.copy_(shard_output)
  110. output_offset += output_size
  111. codebooks_offset += num_codebooks
  112. return output
  113. # Optimized dequnantize/decompression kernels, supports 1x16 and 2x8
  114. # at 6 and 9 times faster than the generic version above, respectively.
  115. def optimized_dequantize_gemm(
  116. input: torch.Tensor, # [..., in_features]
  117. codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
  118. codebooks: torch.
  119. Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
  120. scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
  121. output_partition_sizes: torch.IntTensor,
  122. bias: Optional[torch.Tensor],
  123. ) -> torch.Tensor:
  124. weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
  125. if bias is None:
  126. # scaling the output is fastest, so we do that when possible.
  127. output = F.linear(input, weights, bias)
  128. orig_shape = output.shape
  129. flattened_output = output.view(-1, output.size(-1))
  130. f_scales = scales.view(-1, scales.shape[0])
  131. b_scales = f_scales.expand(flattened_output.shape[0], -1)
  132. flattened_output *= b_scales
  133. return output.view(orig_shape)
  134. else:
  135. b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
  136. -1, weights.shape[1])
  137. weights *= b_scales
  138. return F.linear(input, weights, bias)
  139. class AQLMConfig(QuantizationConfig):
  140. """Config class for AQLM.
  141. Reference: https://github.com/Vahe1994/AQLM
  142. """
  143. def __init__(
  144. self,
  145. in_group_size: int,
  146. nbits_per_codebook: int,
  147. num_codebooks: int,
  148. out_group_size: int,
  149. ) -> None:
  150. self.in_group_size = in_group_size
  151. self.nbits_per_codebook = nbits_per_codebook
  152. self.num_codebooks = num_codebooks
  153. self.out_group_size = out_group_size
  154. # out_group_size > 1 is untested, and probably won't work as-is.
  155. assert (self.out_group_size == 1)
  156. self.pack_factor = (self.in_group_size * self.out_group_size)
  157. def __repr__(self) -> str:
  158. return (f"AQLMConfig(in_group_size={self.in_group_size}, "
  159. f"nbits_per_codebook={self.nbits_per_codebook}, "
  160. f"num_codebooks={self.num_codebooks}, "
  161. f"out_group_size={self.out_group_size})")
  162. @classmethod
  163. def get_name(cls) -> str:
  164. return "aqlm"
  165. @classmethod
  166. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  167. return [torch.half]
  168. @classmethod
  169. def get_min_capability(cls) -> int:
  170. return 70
  171. @classmethod
  172. def get_config_filenames(cls) -> List[str]:
  173. return [] # no extra configs.
  174. @classmethod
  175. def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
  176. in_group_size = cls.get_from_keys(config, ["in_group_size"])
  177. nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
  178. num_code_books = cls.get_from_keys(config, ["num_codebooks"])
  179. out_group_size = cls.get_from_keys(config, ["out_group_size"])
  180. return cls(in_group_size, nbits_per_codebook, num_code_books,
  181. out_group_size)
  182. def get_linear_method(self) -> "AQLMLinearMethod":
  183. return AQLMLinearMethod(self)
  184. def get_scaled_act_names(self) -> List[str]:
  185. return []
  186. class AQLMLinearMethod(LinearMethodBase):
  187. """Linear method for AQLM.
  188. Args:
  189. quant_config: The AQLM quantization config.
  190. """
  191. def __init__(self, quant_config: AQLMConfig):
  192. self.quant_config = quant_config
  193. def create_weights(self, layer: torch.nn.Module,
  194. input_size_per_partition: int,
  195. output_partition_sizes: List[int], input_size: int,
  196. output_size: int, params_dtype: torch.dtype,
  197. **extra_weight_attrs):
  198. del output_size # Unused.
  199. del input_size # Unused.
  200. if params_dtype != torch.half:
  201. raise ValueError("Only half is currently supported by aqlm")
  202. if input_size_per_partition % self.quant_config.in_group_size != 0:
  203. raise ValueError(
  204. "The input size is not aligned with the quantized "
  205. "weight shape. This can be caused by too large "
  206. "tensor parallel size.")
  207. output_size_per_partition = sum(output_partition_sizes)
  208. if output_size_per_partition % self.quant_config.out_group_size != 0:
  209. raise ValueError(
  210. "The output size is not aligned with the quantized "
  211. "weight shape. This can be caused by too large "
  212. "tensor parallel size.")
  213. codes = Parameter(
  214. torch.empty(
  215. # There could actually be two pack factors, one along input and
  216. # one along output, but we don't currently support
  217. # out_group_size, and only the one along output needs to be
  218. # marked with "packed_dim" in order for QKVLinear to work.
  219. output_size_per_partition,
  220. input_size_per_partition // self.quant_config.pack_factor,
  221. self.quant_config.num_codebooks,
  222. dtype=get_int_dtype(self.quant_config.nbits_per_codebook),
  223. ),
  224. requires_grad=False,
  225. )
  226. set_weight_attrs(
  227. codes,
  228. {
  229. "input_dim": 1,
  230. "output_dim": 0,
  231. "packed_dim": 1,
  232. "pack_factor": self.quant_config.pack_factor,
  233. },
  234. )
  235. codebooks = Parameter(
  236. torch.empty(
  237. self.quant_config.num_codebooks * len(output_partition_sizes),
  238. 2**self.quant_config.nbits_per_codebook,
  239. self.quant_config.out_group_size,
  240. self.quant_config.in_group_size,
  241. dtype=params_dtype,
  242. ),
  243. requires_grad=False,
  244. )
  245. set_weight_attrs(
  246. codebooks,
  247. {
  248. # metadata indicates fixed size concatenated along dim 0
  249. "is_metadata":
  250. True,
  251. "output_partition_sizes":
  252. torch.tensor(output_partition_sizes, device='cpu'),
  253. },
  254. )
  255. scales = Parameter(
  256. torch.empty(
  257. (
  258. output_size_per_partition //
  259. self.quant_config.out_group_size,
  260. 1,
  261. 1,
  262. 1,
  263. ),
  264. dtype=params_dtype,
  265. ),
  266. requires_grad=False,
  267. )
  268. set_weight_attrs(
  269. scales,
  270. {
  271. "output_dim": 0,
  272. "packed_dim": 0,
  273. "pack_factor": self.quant_config.out_group_size
  274. },
  275. )
  276. layer.register_parameter("codes", codes)
  277. set_weight_attrs(codes, extra_weight_attrs)
  278. layer.register_parameter("codebooks", codebooks)
  279. set_weight_attrs(codebooks, extra_weight_attrs)
  280. layer.register_parameter("scales", scales)
  281. set_weight_attrs(scales, extra_weight_attrs)
  282. def apply_weights(
  283. self,
  284. layer: torch.nn.Module,
  285. x: torch.Tensor,
  286. bias: Optional[torch.Tensor] = None,
  287. ) -> torch.Tensor:
  288. codebooks = layer.codebooks
  289. codes = layer.codes
  290. scales = layer.scales
  291. output_partition_sizes = getattr(codebooks, "output_partition_sizes",
  292. None)
  293. nbooks = codes.shape[2]
  294. ingroups = codebooks.shape[3]
  295. outgroups = codebooks.shape[2]
  296. bits = codebooks.shape[1]
  297. # We support these formats with dedicated gemm and decompression
  298. # kernels.
  299. if ingroups == 8 and outgroups == 1 and (
  300. (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)):
  301. # thresholds determined by timings on an A6000, one GPU
  302. use_gemv = math.prod(x.shape[:-1]) <= 6
  303. return ops.aqlm_gemm(
  304. x,
  305. codes,
  306. codebooks,
  307. scales,
  308. output_partition_sizes,
  309. bias,
  310. ) if use_gemv else optimized_dequantize_gemm(
  311. x,
  312. codes,
  313. codebooks,
  314. scales,
  315. output_partition_sizes,
  316. bias,
  317. )
  318. # fall back all unoptimized formats
  319. return generic_dequantize_gemm(
  320. x,
  321. codes,
  322. codebooks,
  323. scales,
  324. output_partition_sizes,
  325. bias,
  326. )