quip.py

from typing import Any, Dict, List, Optional
from contextlib import suppress

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.quip_utils import (
    get_packed_abs_grid,
    get_hadK,
    matmul_hadUt_cuda,
    matmul_hadU_cuda,
)

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class QuipConfig(QuantizationConfig):
    """Config class for Quip.

    Reference: https://cornell-relaxml.github.io/quip-sharp/
    """

    def __init__(self, codebook: str, use_rand: bool) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.codebook = codebook
        self.use_rand = use_rand

        if self.codebook != "E8P12":
            raise ValueError("Currently, only E8P12 is supported for "
                             f"Quip, but got {self.codebook}.")

    def __repr__(self) -> str:
        return (f"QuipConfig(codebook={self.codebook}, "
                f"use_rand={self.use_rand})")

    @classmethod
    def get_name(cls) -> str:
        return "quip"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Requires CUDA compute capability 8.0+ (Ampere or newer).
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantization_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuipConfig":
        codebook = cls.get_from_keys(config, ["codebook"])
        use_rand = cls.get_from_keys(config, ["use_rand"])
        return cls(codebook, use_rand)
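
    # Illustrative example (an assumption, not from the original file): a
    # minimal quantization_config.json that from_config() would accept for a
    # checkpoint quantized with the E8P12 codebook:
    #
    #     {"codebook": "E8P12", "use_rand": true}
    #
    # which yields QuipConfig(codebook=E8P12, use_rand=True).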

    def get_linear_method(self) -> "QuipLinearMethod":
        return QuipLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    def merge_weight(self) -> bool:
        return False

    def quant_vocab(self) -> List[bool]:
        return [False, False]

    def support_fused_moe(self) -> bool:
        return False

    def rope_style(self) -> Optional[bool]:
        return None


class QuipLinearMethod(LinearMethodBase):
    """Linear method for Quip.

    Args:
        quant_config: The Quip quantization config.
    """

    def __init__(self, quant_config: QuipConfig):
        self.quant_config = quant_config
        self.grid_packed_abs = get_packed_abs_grid().to(device="cuda")
        # Each 16-bit codebook index packs 8 weights (E8 lattice codewords
        # are 8-dimensional).
        self.pack = 8
        self.idx_dtype = torch.int16

    def create_weights(
        self,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        output_size_per_partition = sum(output_partition_sizes)
        if (input_size != input_size_per_partition
                or output_size != output_size_per_partition):
            raise ValueError(
                "Currently Quip doesn't support tensor parallel yet")

        # Hadamard transform data for the input (left) and output (right)
        # sides of the weight.
        had_left, K_left, q_in_features = get_hadK(input_size,
                                                   self.quant_config.use_rand)
        had_right, K_right, q_out_features = get_hadK(
            output_size, self.quant_config.use_rand)
        weights = {
            "K_left": K_left,
            "K_right": K_right,
            "q_in_features": q_in_features,
            "q_out_features": q_out_features,
        }
        if had_left is not None:
            weights["had_left"] = Parameter(
                had_left.to(dtype=params_dtype, device="cuda"),
                requires_grad=False,
            )
            set_weight_attrs(weights["had_left"], {"ignore_warning": True})
        if had_right is not None:
            weights["had_right"] = Parameter(
                had_right.to(dtype=params_dtype, device="cuda"),
                requires_grad=False,
            )
            set_weight_attrs(weights["had_right"], {"ignore_warning": True})

        # Packed codebook indices: one int16 entry per `self.pack` weights.
        Qidxs = Parameter(
            torch.empty(q_out_features,
                        q_in_features // self.pack,
                        device="cuda",
                        dtype=self.idx_dtype),
            requires_grad=False,
        )
        set_weight_attrs(Qidxs, {"ignore_warning": True})
        # Global (per-tensor) weight scale.
        Wscale = Parameter(
            torch.ones((), dtype=torch.float, device="cuda"),
            requires_grad=False,
        )
        set_weight_attrs(Wscale, {"ignore_warning": True})
        # Per-input-feature vector applied to the activations before the
        # input Hadamard transform.
        SU = Parameter(
            torch.ones(
                input_size,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(SU, {"ignore_warning": True})
        # Per-output-feature vector applied to the result after the output
        # Hadamard transform.
        SV = Parameter(
            torch.ones(
                output_size,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(SV, {"ignore_warning": True})
        weights.update({
            "Qidxs": Qidxs,
            "Wscale": Wscale,
            "SU": SU,
            "SV": SV,
        })
        return weights
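
    # The forward pass below follows the QuIP# incoherence-processing
    # pipeline as implemented in this file: scale the activations by SU,
    # apply the input Hadamard transform (matmul_hadUt_cuda), multiply by
    # the quantized weight (a fused gemv kernel for small batches, otherwise
    # decompress to fp16 and use a dense matmul), apply the output Hadamard
    # transform (matmul_hadU_cuda), then scale by SV and add the bias.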

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # First run: fold the scalar weight scale into a Python float and
        # drop sign vectors that are trivially all-positive.
        if isinstance(weights["Wscale"], torch.Tensor):
            weights["Wscale"] = weights["Wscale"].item()
            if "SU" in weights and torch.all(weights["SU"] > 0):
                del weights["SU"]
            if "SV" in weights and torch.all(weights["SV"] > 0):
                del weights["SV"]

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_dim = weights["Qidxs"].shape[0]

        if "SU" in weights:
            reshaped_x = reshaped_x * weights["SU"]
        # Input-side Hadamard transform (also applies Wscale).
        reshaped_x = matmul_hadUt_cuda(reshaped_x,
                                       weights.get("had_left", None),
                                       weights["K_left"],
                                       weights["q_in_features"],
                                       weights["Wscale"])

        m, n = weights["Qidxs"].shape
        if reshaped_x.size(0) < 32:
            # Small batch: fused gemv directly on the packed indices.
            out = ops.quip_gemv(reshaped_x, weights["Qidxs"],
                                self.grid_packed_abs)
        else:
            # Larger batch: decompress the weight to fp16 and use a dense
            # matmul.
            W_decompressed = torch.empty(m,
                                         n * 8,
                                         dtype=torch.float16,
                                         device=x.device)
            ops.quip_decompress(weights["Qidxs"], self.grid_packed_abs,
                                W_decompressed)
            out = reshaped_x @ W_decompressed.T

        # Output-side Hadamard transform; trim any padding back to out_dim.
        out = matmul_hadU_cuda(out, weights.get("had_right", None),
                               weights["K_right"],
                               weights["q_out_features"])[..., :out_dim]
        if "SV" in weights:
            out = out * weights["SV"]
        out = out.view(*x.shape[:-1], out.shape[-1])
        out = out + bias if bias is not None else out
        return out

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        raise NotImplementedError
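
# Illustrative usage sketch (not part of the original module; the 4096 sizes
# are assumptions): with a CUDA device and a QuIP#-quantized checkpoint that
# supplies Qidxs, Wscale, SU and SV in the shapes created above,
#
#     config = QuipConfig(codebook="E8P12", use_rand=True)
#     method = config.get_linear_method()
#     weights = method.create_weights(input_size_per_partition=4096,
#                                     output_partition_sizes=[4096],
#                                     input_size=4096,
#                                     output_size=4096,
#                                     params_dtype=torch.half)
#     # ... copy the checkpoint tensors into `weights` ...
#     y = method.apply_weights(weights, x)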