# quip.py

from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.quip_utils import (get_hadK, get_packed_abs_grid,
                                               matmul_hadU_cuda,
                                               matmul_hadUt_cuda)


class QuipConfig(QuantizationConfig):
    """Config class for Quip.

    Reference: https://cornell-relaxml.github.io/quip-sharp/
    """

    def __init__(self, codebook: str, use_rand: bool) -> None:
        self.codebook = codebook
        self.use_rand = use_rand

        if self.codebook != "E8P12":
            raise ValueError("Currently, only E8P12 is supported for "
                             f"Quip, but got {self.codebook}.")

    def __repr__(self) -> str:
        return (f"QuipConfig(codebook={self.codebook}, "
                f"use_rand={self.use_rand})")

    @classmethod
    def get_name(cls) -> str:
        return "quip"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantization_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuipConfig":
        codebook = cls.get_from_keys(config, ["codebook"])
        use_rand = cls.get_from_keys(config, ["use_rand"])
        return cls(codebook, use_rand)
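    # A matching quantization_config.json provides the two keys read above,
    # e.g. (values are illustrative):
    #     {"codebook": "E8P12", "use_rand": true}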

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuipLinearMethod"]:
        if isinstance(layer, LinearBase):
            return QuipLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class QuipLinearMethod(LinearMethodBase):
    """Linear method for Quip.

    Args:
        quant_config: The Quip quantization config.
    """

    def __init__(self, quant_config: QuipConfig):
        self.quant_config = quant_config
        self.grid_packed_abs = get_packed_abs_grid().to(device="cuda")
        self.pack = 8
        self.idx_dtype = torch.int16
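        # Each int16 entry of Qidxs indexes a group of `self.pack` (8) weights
        # in the E8P12 codebook; grid_packed_abs is the packed codebook grid
        # consumed by the quip_gemv / quip_decompress custom ops in apply().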

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        if (input_size != input_size_per_partition
                or output_size != output_size_per_partition):
            raise ValueError(
                "Quip doesn't support tensor parallelism yet")
        had_left, K_left, q_in_features = get_hadK(input_size,
                                                   self.quant_config.use_rand)
        had_right, K_right, q_out_features = get_hadK(
            output_size, self.quant_config.use_rand)
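        # get_hadK (aphrodite.quantization.quip_utils) returns an optional
        # Hadamard transform factor, its block count K, and the feature size
        # actually used on that side of the transform.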
        # apply() needs these non-tensor values later; keep them as plain
        # attributes on the layer.
        layer.K_left = K_left
        layer.K_right = K_right
        layer.q_in_features = q_in_features
        layer.q_out_features = q_out_features

        if had_left is not None:
            layer.register_parameter(
                "had_left",
                Parameter(
                    had_left.to(dtype=params_dtype, device="cuda"),
                    requires_grad=False,
                ))
            set_weight_attrs(layer.had_left, extra_weight_attrs)
        if had_right is not None:
            layer.register_parameter(
                "had_right",
                Parameter(
                    had_right.to(dtype=params_dtype, device="cuda"),
                    requires_grad=False,
                ))
            set_weight_attrs(layer.had_right, extra_weight_attrs)
        layer.register_parameter(
            "Qidxs",
            Parameter(
                torch.empty(q_out_features,
                            q_in_features // self.pack,
                            device="cuda",
                            dtype=self.idx_dtype),
                requires_grad=False,
            ))
        set_weight_attrs(layer.Qidxs, extra_weight_attrs)
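        # Shape sketch (illustrative, assuming input_size = output_size = 4096
        # and no padding from get_hadK): Qidxs is [4096, 512] int16, i.e. one
        # 16-bit index per group of 8 weights.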
        layer.register_parameter(
            "Wscale",
            Parameter(
                torch.ones((), dtype=torch.float, device="cuda"),
                requires_grad=False,
            ))
        set_weight_attrs(layer.Wscale, extra_weight_attrs)
        layer.register_parameter(
            "SU",
            Parameter(
                torch.ones(
                    input_size,
                    device="cuda",
                    dtype=params_dtype,
                ),
                requires_grad=False,
            ))
        set_weight_attrs(layer.SU, extra_weight_attrs)
        layer.register_parameter(
            "SV",
            Parameter(
                torch.ones(
                    output_size,
                    device="cuda",
                    dtype=params_dtype,
                ),
                requires_grad=False,
            ))
        set_weight_attrs(layer.SV, extra_weight_attrs)
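        # Wscale is a scalar weight scale; SU and SV are per-input-feature and
        # per-output-feature vectors multiplied into the activations before
        # and after the quantized matmul in apply() (and dropped there on the
        # first run when every entry is positive).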

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # First run: fold the scalar weight scale into a plain Python float
        # and drop sign vectors that are no-ops (all entries positive).
        if isinstance(layer.Wscale, torch.Tensor):
            wscale = layer.Wscale.item()
            # Unregister the Parameter before rebinding the name to a float;
            # nn.Module refuses to overwrite a registered parameter directly.
            del layer.Wscale
            layer.Wscale = wscale
            if hasattr(layer, "SU") and torch.all(layer.SU > 0):
                del layer.SU
            if hasattr(layer, "SV") and torch.all(layer.SV > 0):
                del layer.SV

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_dim = layer.Qidxs.shape[0]

        if hasattr(layer, "SU"):
            reshaped_x = reshaped_x * layer.SU
        reshaped_x = matmul_hadUt_cuda(reshaped_x,
                                       getattr(layer, "had_left", None),
                                       layer.K_left, layer.q_in_features,
                                       layer.Wscale)

        m, n = layer.Qidxs.shape
        if reshaped_x.size(0) < 32:
            # Small batches: fused GEMV directly on the packed indices.
            out = ops.quip_gemv(reshaped_x, layer.Qidxs, self.grid_packed_abs)
        else:
            # Larger batches: decompress the codebook indices to fp16 weights
            # once and use a regular GEMM.
            W_decompressed = torch.empty(m,
                                         n * 8,
                                         dtype=torch.float16,
                                         device=x.device)
            ops.quip_decompress(layer.Qidxs, self.grid_packed_abs,
                                W_decompressed)
            out = reshaped_x @ W_decompressed.T

        out = matmul_hadU_cuda(out, getattr(layer, "had_right", None),
                               layer.K_right,
                               layer.q_out_features)[..., :out_dim]
        if hasattr(layer, "SV"):
            out = out * layer.SV
        out = out.view(*x.shape[:-1], out.shape[-1])
        out = out.add_(bias) if bias is not None else out
        return out
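

# Rough usage sketch (illustrative; the engine normally drives these calls and
# the layer/sizes below are assumptions):
#
#     config = QuipConfig(codebook="E8P12", use_rand=True)
#     method = config.get_quant_method(linear_layer, prefix="")
#     method.create_weights(linear_layer, 4096, [4096], 4096, 4096, torch.half)
#     # ... checkpoint loading fills Qidxs, Wscale, SU, SV ...
#     y = method.apply(linear_layer, x)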