quip.py

from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite._C import ops
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.modeling.layers.quantization.base_config import (
    QuantizationConfig)
from aphrodite.modeling.layers.quantization.quip_utils import (
    get_packed_abs_grid,
    get_hadK,
    matmul_hadUt_cuda,
    matmul_hadU_cuda,
)
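
# QuIP# (see the reference in QuipConfig below) quantizes weights by first
# making them incoherent with Hadamard transforms and per-channel sign flips,
# then rounding them to an E8 lattice ("E8P") codebook. This module only wires
# those pieces into Aphrodite's linear layers; the CUDA kernels live in
# aphrodite._C and aphrodite.modeling.layers.quantization.quip_utils.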


class QuipConfig(QuantizationConfig):
    """Config class for Quip.

    Reference: https://cornell-relaxml.github.io/quip-sharp/
    """

    def __init__(self, codebook: str, use_rand: bool) -> None:
        self.codebook = codebook
        self.use_rand = use_rand

        if self.codebook != "E8P12":
            raise ValueError("Currently, only E8P12 is supported for "
                             f"Quip, but got {self.codebook}.")

    def __repr__(self) -> str:
        return (f"QuipConfig(codebook={self.codebook}, "
                f"use_rand={self.use_rand})")

    @classmethod
    def get_name(cls) -> str:
        return "quip"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Compute capability 8.0 (Ampere) or newer.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantization_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuipConfig":
        codebook = cls.get_from_keys(config, ["codebook"])
        use_rand = cls.get_from_keys(config, ["use_rand"])
        return cls(codebook, use_rand)

    def get_linear_method(self) -> "QuipLinearMethod":
        return QuipLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    def merge_weight(self) -> bool:
        return False

    def quant_vocab(self) -> List[bool]:
        return [False, False]

    def support_fused_moe(self) -> bool:
        return False
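
# A minimal usage sketch (the dict mirrors the keys from_config reads out of
# quantization_config.json; the values here are illustrative):
#
#   config = QuipConfig.from_config({"codebook": "E8P12", "use_rand": True})
#   linear_method = config.get_linear_method()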


class QuipLinearMethod(LinearMethodBase):
    """Linear method for Quip.

    Args:
        quant_config: The Quip quantization config.
    """

    def __init__(self, quant_config: QuipConfig):
        self.quant_config = quant_config
        # Packed E8P codebook grid, kept resident on the GPU for the
        # decompression/GEMV kernels below.
        self.grid_packed_abs = get_packed_abs_grid().to(device="cuda")
        # Each int16 entry of Qidxs indexes a codeword covering 8 weights.
        self.pack = 8
        self.idx_dtype = torch.int16

    def create_weights(
        self,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
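        # Allocates the QuIP# parameter set: packed codebook indices (Qidxs),
        # a scalar weight scale (Wscale), and per-channel sign vectors
        # (SU, SV). When get_hadK returns explicit Hadamard factor matrices,
        # they are stored as had_left / had_right alongside the block counts
        # K_left / K_right.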
        output_size_per_partition = sum(output_partition_sizes)
        if (input_size != input_size_per_partition
                or output_size != output_size_per_partition):
            raise ValueError(
                "Currently Quip doesn't support tensor parallelism yet")

        had_left, K_left, q_in_features = get_hadK(input_size,
                                                   self.quant_config.use_rand)
        had_right, K_right, q_out_features = get_hadK(
            output_size, self.quant_config.use_rand)
        weights = {
            "K_left": K_left,
            "K_right": K_right,
            "q_in_features": q_in_features,
            "q_out_features": q_out_features,
        }
        if had_left is not None:
            weights["had_left"] = Parameter(
                had_left.to(dtype=params_dtype, device="cuda"),
                requires_grad=False,
            )
            set_weight_attrs(weights["had_left"], {"ignore_warning": True})
        if had_right is not None:
            weights["had_right"] = Parameter(
                had_right.to(dtype=params_dtype, device="cuda"),
                requires_grad=False,
            )
            set_weight_attrs(weights["had_right"], {"ignore_warning": True})

        Qidxs = Parameter(
            torch.empty(q_out_features,
                        q_in_features // self.pack,
                        device="cuda",
                        dtype=self.idx_dtype),
            requires_grad=False,
        )
        set_weight_attrs(Qidxs, {"ignore_warning": True})
        Wscale = Parameter(
            torch.ones((), dtype=torch.float, device="cuda"),
            requires_grad=False,
        )
        set_weight_attrs(Wscale, {"ignore_warning": True})
        SU = Parameter(
            torch.ones(
                input_size,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(SU, {"ignore_warning": True})
        SV = Parameter(
            torch.ones(
                output_size,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(SV, {"ignore_warning": True})
        weights.update({
            "Qidxs": Qidxs,
            "Wscale": Wscale,
            "SU": SU,
            "SV": SV,
        })
        return weights

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
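        # Rough shape of the computation below:
        #   y = H_R( H_L^T (x * SU) @ W_hat^T )[..., :out_dim] * SV + bias
        # where H_L / H_R are the (optionally randomized) Hadamard transforms,
        # W_hat is the weight matrix decompressed from the E8P codebook, and
        # the scalar Wscale is folded into the input-side transform.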
        # First run: fold the scalar Wscale into a Python float, and drop the
        # sign vectors when they contain no sign flips.
        if isinstance(weights["Wscale"], torch.Tensor):
            weights["Wscale"] = weights["Wscale"].item()
            if "SU" in weights and torch.all(weights["SU"] > 0):
                del weights["SU"]
            if "SV" in weights and torch.all(weights["SV"] > 0):
                del weights["SV"]

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_dim = weights["Qidxs"].shape[0]

        if "SU" in weights:
            reshaped_x = reshaped_x * weights["SU"]
        reshaped_x = matmul_hadUt_cuda(reshaped_x,
                                       weights.get("had_left", None),
                                       weights["K_left"],
                                       weights["q_in_features"],
                                       weights["Wscale"])

        m, n = weights["Qidxs"].shape
        if reshaped_x.size(0) < 32:
            # Small batches: fused kernel that reads the codebook directly.
            out = ops.quip_gemv(reshaped_x, weights["Qidxs"],
                                self.grid_packed_abs)
        else:
            # Larger batches: decompress the full weight matrix once and use
            # a regular fp16 matmul.
            W_decompressed = torch.empty(m,
                                         n * 8,
                                         dtype=torch.float16,
                                         device=x.device)
            ops.quip_decompress(weights["Qidxs"], self.grid_packed_abs,
                                W_decompressed)
            out = reshaped_x @ W_decompressed.T

        out = matmul_hadU_cuda(out, weights.get("had_right", None),
                               weights["K_right"],
                               weights["q_out_features"])[..., :out_dim]
        if "SV" in weights:
            out = out * weights["SV"]
        out = out.view(*x.shape[:-1], out.shape[-1])
        out = out + bias if bias is not None else out
        return out

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        raise NotImplementedError
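
# A minimal call sketch (hypothetical sizes; requires a CUDA device and the
# aphrodite._C kernels; in practice the checkpoint loader fills Qidxs, Wscale,
# SU and SV in place before the first forward pass):
#
#   method = QuipConfig(codebook="E8P12", use_rand=True).get_linear_method()
#   weights = method.create_weights(4096, [4096], 4096, 4096, torch.half)
#   x = torch.randn(1, 4096, dtype=torch.half, device="cuda")
#   y = method.apply_weights(weights, x)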