aqlm.py

# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
# and https://arxiv.org/pdf/2401.06118.pdf

import math
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig


def get_int_dtype(nbits: int) -> torch.dtype:
    if nbits <= 8:
        return torch.int8
    if nbits <= 16:
        return torch.int16
    if nbits <= 32:
        return torch.int32
    if nbits <= 64:
        return torch.int64
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")


@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
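    # Codes are stored in the narrowest signed integer dtype that fits
    # (see get_int_dtype above); taking the value modulo 2**nbits maps the
    # signed storage back to unsigned codebook indices in [0, 2**nbits).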
    return data.to(torch.int64) % (2**nbits)


def dequantize_weight(codes: torch.Tensor,
                      codebooks: torch.Tensor,
                      scales: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Decode float weights from quantization codes. Differentiable.

    :param codes: tensor of integer quantization codes, shape
        [*dims, num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: tensor of vectors for each quantization code,
        [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: weight will be multiplied by this factor, must be
        broadcastable with
        [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape
        [*dims, num_out_groups*out_group_size, num_in_groups*in_group_size]
    """
    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
    num_codebooks, codebook_size, out_group_size, in_group_size = \
        codebooks.shape
    out_features = num_out_groups * out_group_size
    in_features = num_in_groups * in_group_size
    codebook_offsets = torch.arange(
        0, num_codebooks * codebook_size, codebook_size,
        device=codes.device)  # shape: [num_codebooks]
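    # Treat the flattened codebooks as one big embedding table of
    # num_codebooks * codebook_size rows. Each (out_group, in_group) position
    # looks up one row per codebook (codebook_offsets shifts the code of
    # codebook k into that codebook's slice of the table) and embedding_bag
    # sums the looked-up vectors.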
    reconstructed_weight_flat = F.embedding_bag(
        codes.flatten(0, -2) + codebook_offsets,
        codebooks.flatten(0, 1).flatten(-2, -1),
        mode="sum"
    )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size
    # * in_group_size]
    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
        list(codes.shape[:-3]) +
        [num_out_groups, num_in_groups, out_group_size, in_group_size])
    if scales is not None:
        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
            scales)
    return reconstructed_weight_groupwise.swapaxes(
        -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])


def dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    # codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
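    # codebook_size is 2**nbits, so bit_length() - 1 recovers the number of
    # bits per codebook from the codebook tensor itself.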
    dequantized_weight = dequantize_weight(
        unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
        codebooks,
        scales,
    )
    return F.linear(input, dequantized_weight, bias)


# Generic dequantization, slow but flexible.
def generic_dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    # codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: List[int],
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    output_shape = input.shape[:-1] + (scales.shape[0], )
    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
    num_outputs = len(output_partition_sizes)

    # break the inputs and codebooks apart then combine the outputs.
    # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
    # multiply at the end.
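    # Each entry of output_partition_sizes is one fused output shard (e.g.
    # the Q, K and V partitions of a fused QKV projection); its codes and
    # scales occupy consecutive slices along dim 0, with its own set of
    # num_codebooks codebooks.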
    num_codebooks = codebooks.shape[0] // num_outputs
    assert (scales.shape[0] == codes.shape[0])
    assert (sum(output_partition_sizes) == scales.shape[0])
    output_offset = 0
    codebooks_offset = 0
    for output_size in output_partition_sizes:
        shard_output = dequantize_gemm(
            input, codes.narrow(0, output_offset, output_size),
            codebooks.narrow(0, codebooks_offset, num_codebooks),
            scales.narrow(0, output_offset, output_size), None
            if bias is None else bias.narrow(0, output_offset, output_size))

        output_slice = output.narrow(-1, output_offset, output_size)
        assert (output_slice.shape == shard_output.shape)
        output_slice.copy_(shard_output)
        output_offset += output_size
        codebooks_offset += num_codebooks
    return output


# Optimized dequantize/decompression kernels; they support the 1x16 and 2x8
# formats and run 6 and 9 times faster than the generic version above,
# respectively.
def optimized_dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    # codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: List[int],
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
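    # When there is a bias, the weights must be scaled before the matmul so
    # that the bias itself is not multiplied by the scales; without a bias we
    # can scale the output rows instead (see below).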
    if bias is None:
        # scaling the output is fastest, so we do that when possible.
        output = F.linear(input, weights, bias)
        orig_shape = output.shape
        flattened_output = output.view(-1, output.size(-1))
        f_scales = scales.view(-1, scales.shape[0])
        b_scales = f_scales.expand(flattened_output.shape[0], -1)
        flattened_output *= b_scales
        return output.view(orig_shape)
    else:
        b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
            -1, weights.shape[1])
        weights *= b_scales
        return F.linear(input, weights, bias)


class AQLMConfig(QuantizationConfig):
    """Config class for AQLM.

    Reference: https://github.com/Vahe1994/AQLM
    """

    def __init__(
        self,
        in_group_size: int,
        nbits_per_codebook: int,
        num_codebooks: int,
        out_group_size: int,
    ) -> None:
        self.in_group_size = in_group_size
        self.nbits_per_codebook = nbits_per_codebook
        self.num_codebooks = num_codebooks
        self.out_group_size = out_group_size

        # out_group_size > 1 is untested, and probably won't work as-is.
        assert (self.out_group_size == 1)
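        # Each group of out_group_size x in_group_size weights is encoded by
        # one code per codebook, so the codes tensor is pack_factor times
        # smaller than the weight matrix along the packed dimension.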
        self.pack_factor = (self.in_group_size * self.out_group_size)

    def __repr__(self) -> str:
        return (f"AQLMConfig(in_group_size={self.in_group_size}, "
                f"nbits_per_codebook={self.nbits_per_codebook}, "
                f"num_codebooks={self.num_codebooks}, "
                f"out_group_size={self.out_group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "aqlm"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []  # no extra configs.
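
    # For reference, an AQLM checkpoint's quantization config typically looks
    # something like the following (illustrative 1x16 values, i.e. one 16-bit
    # codebook per group of 8 input weights):
    #   {"in_group_size": 8, "nbits_per_codebook": 16,
    #    "num_codebooks": 1, "out_group_size": 1}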
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
        in_group_size = cls.get_from_keys(config, ["in_group_size"])
        nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
        num_code_books = cls.get_from_keys(config, ["num_codebooks"])
        out_group_size = cls.get_from_keys(config, ["out_group_size"])
        return cls(in_group_size, nbits_per_codebook, num_code_books,
                   out_group_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AQLMLinearMethod"]:
        if isinstance(layer, LinearBase):
            return AQLMLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class AQLMLinearMethod(LinearMethodBase):
    """Linear method for AQLM.

    Args:
        quant_config: The AQLM quantization config.
    """

    def __init__(self, quant_config: AQLMConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        del output_size  # Unused.
        del input_size  # Unused.

        if params_dtype != torch.half:
            raise ValueError("Only half is currently supported by aqlm")
        if input_size_per_partition % self.quant_config.in_group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.out_group_size != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        codes = Parameter(
            torch.empty(
                # There could actually be two pack factors, one along input
                # and one along output, but we don't currently support
                # out_group_size, and only the one along output needs to be
                # marked with "packed_dim" in order for QKVLinear to work.
                output_size_per_partition,
                input_size_per_partition // self.quant_config.pack_factor,
                self.quant_config.num_codebooks,
                dtype=get_int_dtype(self.quant_config.nbits_per_codebook),
            ),
            requires_grad=False,
        )

        set_weight_attrs(
            codes,
            {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            },
        )
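
        # One set of num_codebooks codebooks is stacked along dim 0 for every
        # output partition of the (possibly fused) linear layer.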
        codebooks = Parameter(
            torch.empty(
                self.quant_config.num_codebooks * len(output_partition_sizes),
                2**self.quant_config.nbits_per_codebook,
                self.quant_config.out_group_size,
                self.quant_config.in_group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            codebooks,
            {
                # metadata indicates fixed size concatenated along dim 0
                "is_metadata": True,
                "output_partition_sizes": output_partition_sizes
            },
        )
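
        # With out_group_size == 1 this is one fp16 scale per output row,
        # stored with trailing singleton dims so it broadcasts against the
        # group-wise weight layout used in dequantize_weight.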
        scales = Parameter(
            torch.empty(
                (
                    output_size_per_partition //
                    self.quant_config.out_group_size,
                    1,
                    1,
                    1,
                ),
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "output_dim": 0,
                "packed_dim": 0,
                "pack_factor": self.quant_config.out_group_size
            },
        )

        layer.register_parameter("codes", codes)
        set_weight_attrs(codes, extra_weight_attrs)
        layer.register_parameter("codebooks", codebooks)
        set_weight_attrs(codebooks, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        codebooks = layer.codebooks
        codes = layer.codes
        scales = layer.scales
        output_partition_sizes = getattr(codebooks, "output_partition_sizes",
                                         [])

        nbooks = codes.shape[2]
        ingroups = codebooks.shape[3]
        outgroups = codebooks.shape[2]
        bits = codebooks.shape[1]
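        # Note: `bits` is really the codebook size (2**nbits_per_codebook),
        # so 65536 corresponds to the 1x16 format and 256 to 2x8.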

        # We support these formats with dedicated gemm and decompression
        # kernels.
        if ingroups == 8 and outgroups == 1 and (
            (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)):

            # thresholds determined by timings on an A6000, one GPU
            use_gemv = math.prod(x.shape[:-1]) <= 6

            return ops.aqlm_gemm(
                x,
                codes,
                codebooks,
                scales,
                output_partition_sizes,
                bias,
            ) if use_gemv else optimized_dequantize_gemm(
                x,
                codes,
                codebooks,
                scales,
                output_partition_sizes,
                bias,
            )

        # Fall back to the generic path for all other (unoptimized) formats.
        return generic_dequantize_gemm(
            x,
            codes,
            codebooks,
            scales,
            output_partition_sizes,
            bias,
        )
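

if __name__ == "__main__":
    # Minimal CPU sketch (not part of the quantization API): reconstruct a
    # small weight matrix from random codes and codebooks to illustrate the
    # tensor shapes used above. The sizes below are illustrative assumptions
    # roughly matching the 2x8 format (two 8-bit codebooks, in_group_size 8).
    num_codebooks, nbits, out_group_size, in_group_size = 2, 8, 1, 8
    out_features, in_features = 16, 32
    num_out_groups = out_features // out_group_size
    num_in_groups = in_features // in_group_size

    # Codebooks: one table of 2**nbits vectors per codebook.
    codebooks = torch.randn(num_codebooks, 2**nbits, out_group_size,
                            in_group_size)
    # Codes are stored in the narrowest signed dtype (int8 for 8 bits).
    codes = torch.randint(-2**(nbits - 1),
                          2**(nbits - 1),
                          (num_out_groups, num_in_groups, num_codebooks),
                          dtype=get_int_dtype(nbits))
    # One scale per output group (here: per output row).
    scales = torch.rand(num_out_groups, 1, 1, 1)

    weight = dequantize_weight(unpack_int_data(codes, nbits), codebooks,
                               scales)
    assert weight.shape == (out_features, in_features)
    print("reconstructed weight:", tuple(weight.shape), weight.dtype)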