quip.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from torch.nn.parameter import Parameter
  4. from aphrodite._C import ops
  5. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  6. set_weight_attrs)
  7. from aphrodite.modeling.layers.quantization.base_config import (
  8. QuantizationConfig)
  9. from aphrodite.modeling.layers.quantization.quip_utils import (
  10. get_packed_abs_grid,
  11. get_hadK,
  12. matmul_hadUt_cuda,
  13. matmul_hadU_cuda,
  14. )
  15. class QuipConfig(QuantizationConfig):
  16. """Config class for Quip.
  17. Reference: https://cornell-relaxml.github.io/quip-sharp/
  18. """
  19. def __init__(self, codebook: int, use_rand: bool) -> None:
  20. self.codebook = codebook
  21. self.use_rand = use_rand
  22. if self.codebook != "E8P12":
  23. raise ValueError("Currently, only E8P12 is supported for "
  24. f"Quip, but got {self.codebook}.")
  25. def __repr__(self) -> str:
  26. return (f"QuipConfig(codebook={self.codebook}, "
  27. f"rescale_WH={self.rescale_WH})")
  28. @classmethod
  29. def get_name(cls) -> str:
  30. return "quip"
  31. @classmethod
  32. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  33. return [torch.half]
  34. @classmethod
  35. def get_min_capability(cls) -> int:
  36. return 80
  37. @classmethod
  38. def get_config_filenames(cls) -> List[str]:
  39. return ["quantization_config.json"]
  40. @classmethod
  41. def from_config(cls, config: Dict[str, Any]) -> "QuipConfig":
  42. codebook = cls.get_from_keys(config, ["codebook"])
  43. use_rand = cls.get_from_keys(config, ["use_rand"])
  44. return cls(codebook, use_rand)
  45. def get_linear_method(self) -> "QuipLinearMethod":
  46. return QuipLinearMethod(self)
  47. def get_scaled_act_names(self) -> List[str]:
  48. return []
  49. def merge_weight(self) -> bool:
  50. return False
  51. def rope_style(self) -> Optional[bool]:
  52. return None
  53. class QuipLinearMethod(LinearMethodBase):
  54. """Linear method for Quip.
  55. Args:
  56. quant_config: The Quip quantization config.
  57. """
  58. def __init__(self, quant_config: QuipConfig):
  59. self.quant_config = quant_config
  60. self.grid_packed_abs = get_packed_abs_grid().to(device="cuda")
  61. self.pack = 8
  62. self.idx_dtype = torch.int16
  63. def create_weights(
  64. self,
  65. input_size_per_partition: int,
  66. output_partition_sizes: List[int],
  67. input_size: int,
  68. output_size: int,
  69. params_dtype: torch.dtype,
  70. ) -> Dict[str, Any]:
  71. output_size_per_partition = sum(output_partition_sizes)
  72. if (input_size != input_size_per_partition
  73. or output_size != output_size_per_partition):
  74. raise ValueError(
  75. "Currently Quip doesn't support tensor parallel yet")
  76. had_left, K_left, q_in_features = get_hadK(input_size,
  77. self.quant_config.use_rand)
  78. had_right, K_right, q_out_features = get_hadK(
  79. output_size, self.quant_config.use_rand)
  80. weights = {
  81. "K_left": K_left,
  82. "K_right": K_right,
  83. "q_in_features": q_in_features,
  84. "q_out_features": q_out_features,
  85. }
  86. if had_left is not None:
  87. weights["had_left"] = Parameter(
  88. had_left.to(dtype=params_dtype, device="cuda"),
  89. requires_grad=False,
  90. )
  91. set_weight_attrs(weights["had_left"], {"ignore_warning": True})
  92. if had_right is not None:
  93. weights["had_right"] = Parameter(
  94. had_right.to(dtype=params_dtype, device="cuda"),
  95. requires_grad=False,
  96. )
  97. set_weight_attrs(weights["had_right"], {"ignore_warning": True})
  98. Qidxs = Parameter(
  99. torch.empty(q_out_features,
  100. q_in_features // self.pack,
  101. device="cuda",
  102. dtype=self.idx_dtype),
  103. requires_grad=False,
  104. )
  105. set_weight_attrs(Qidxs, {"ignore_warning": True})
  106. Wscale = Parameter(
  107. torch.ones((), dtype=torch.float, device="cuda"),
  108. requires_grad=False,
  109. )
  110. set_weight_attrs(Wscale, {"ignore_warning": True})
  111. SU = Parameter(
  112. torch.ones(
  113. input_size,
  114. device="cuda",
  115. dtype=params_dtype,
  116. ),
  117. requires_grad=False,
  118. )
  119. set_weight_attrs(SU, {"ignore_warning": True})
  120. SV = Parameter(
  121. torch.ones(
  122. output_size,
  123. device="cuda",
  124. dtype=params_dtype,
  125. ),
  126. requires_grad=False,
  127. )
  128. set_weight_attrs(SV, {"ignore_warning": True})
  129. weights.update({
  130. "Qidxs": Qidxs,
  131. "Wscale": Wscale,
  132. "SU": SU,
  133. "SV": SV,
  134. })
  135. return weights
  136. def apply_weights(self,
  137. weights: Dict[str, Any],
  138. x: torch.Tensor,
  139. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  140. # First run
  141. if isinstance(weights["Wscale"], torch.Tensor):
  142. weights["Wscale"] = weights["Wscale"].item()
  143. if "SU" in weights and torch.all(weights["SU"] > 0):
  144. del weights["SU"]
  145. if "SV" in weights and torch.all(weights["SV"] > 0):
  146. del weights["SV"]
  147. reshaped_x = x.reshape(-1, x.shape[-1])
  148. out_dim = weights["Qidxs"].shape[0]
  149. if "SU" in weights:
  150. reshaped_x = reshaped_x * weights["SU"]
  151. reshaped_x = matmul_hadUt_cuda(reshaped_x,
  152. weights.get("had_left",
  153. None), weights["K_left"],
  154. weights["q_in_features"],
  155. weights["Wscale"])
  156. m, n = weights["Qidxs"].shape
  157. if reshaped_x.size(0) < 32:
  158. out = ops.quip_gemv(reshaped_x, weights["Qidxs"],
  159. self.grid_packed_abs)
  160. else:
  161. W_decompressed = torch.empty(m,
  162. n * 8,
  163. dtype=torch.float16,
  164. device=x.device)
  165. ops.quip_decompress(weights["Qidxs"], self.grid_packed_abs,
  166. W_decompressed)
  167. out = reshaped_x @ W_decompressed.T
  168. out = matmul_hadU_cuda(out, weights.get("had_right",
  169. None), weights["K_right"],
  170. weights["q_out_features"])[..., :out_dim]
  171. if "SV" in weights:
  172. out = out * weights["SV"]
  173. out = out.view(*x.shape[:-1], out.shape[-1])
  174. out = out + bias if bias is not None else out
  175. return out