quip.py

from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.quip_utils import (get_hadK, get_packed_abs_grid,
                                               matmul_hadU_cuda,
                                               matmul_hadUt_cuda)

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class QuipConfig(QuantizationConfig):
    """Config class for Quip.

    Reference: https://cornell-relaxml.github.io/quip-sharp/
    """

    def __init__(self, codebook: str, use_rand: bool) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.codebook = codebook
        self.use_rand = use_rand

        if self.codebook != "E8P12":
            raise ValueError("Currently, only E8P12 is supported for "
                             f"Quip, but got {self.codebook}.")

    def __repr__(self) -> str:
        return (f"QuipConfig(codebook={self.codebook}, "
                f"use_rand={self.use_rand})")

    @classmethod
    def get_name(cls) -> str:
        return "quip"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantization_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuipConfig":
        codebook = cls.get_from_keys(config, ["codebook"])
        use_rand = cls.get_from_keys(config, ["use_rand"])
        return cls(codebook, use_rand)

    def get_linear_method(self) -> "QuipLinearMethod":
        return QuipLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []
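
# Example quantization_config.json accepted by QuipConfig.from_config
# (illustrative values only; the actual file ships with the quantized
# checkpoint):
#
#     {"codebook": "E8P12", "use_rand": true}
#
# which is parsed into QuipConfig(codebook="E8P12", use_rand=True).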


class QuipLinearMethod(LinearMethodBase):
    """Linear method for Quip.

    Args:
        quant_config: The Quip quantization config.
    """

    def __init__(self, quant_config: QuipConfig):
        self.quant_config = quant_config
        self.grid_packed_abs = get_packed_abs_grid().to(device="cuda")
        self.pack = 8
        self.idx_dtype = torch.int16

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        if (input_size != input_size_per_partition
                or output_size != output_size_per_partition):
            raise ValueError(
                "Currently Quip doesn't support tensor parallel yet")

        had_left, K_left, q_in_features = get_hadK(input_size,
                                                   self.quant_config.use_rand)
        had_right, K_right, q_out_features = get_hadK(
            output_size, self.quant_config.use_rand)
        # Cache the Hadamard metadata on the layer; apply_weights reads it back.
        layer.K_left = K_left
        layer.K_right = K_right
        layer.q_in_features = q_in_features
        layer.q_out_features = q_out_features
        if had_left is not None:
            layer.register_parameter(
                "had_left",
                Parameter(
                    had_left.to(dtype=params_dtype, device="cuda"),
                    requires_grad=False,
                ))
            set_weight_attrs(layer.had_left, extra_weight_attrs)
        if had_right is not None:
            layer.register_parameter(
                "had_right",
                Parameter(
                    had_right.to(dtype=params_dtype, device="cuda"),
                    requires_grad=False,
                ))
            set_weight_attrs(layer.had_right, extra_weight_attrs)
        layer.register_parameter(
            "Qidxs",
            Parameter(
                torch.empty(q_out_features,
                            q_in_features // self.pack,
                            device="cuda",
                            dtype=self.idx_dtype),
                requires_grad=False,
            ))
        set_weight_attrs(layer.Qidxs, extra_weight_attrs)
        layer.register_parameter(
            "Wscale",
            Parameter(
                torch.ones((), dtype=torch.float, device="cuda"),
                requires_grad=False,
            ))
        set_weight_attrs(layer.Wscale, extra_weight_attrs)
        layer.register_parameter(
            "SU",
            Parameter(
                torch.ones(
                    input_size,
                    device="cuda",
                    dtype=params_dtype,
                ),
                requires_grad=False,
            ))
        set_weight_attrs(layer.SU, extra_weight_attrs)
        layer.register_parameter(
            "SV",
            Parameter(
                torch.ones(
                    output_size,
                    device="cuda",
                    dtype=params_dtype,
                ),
                requires_grad=False,
            ))
        set_weight_attrs(layer.SV, extra_weight_attrs)
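
    # Weight layout created above (as stored on `layer`):
    #   Qidxs:  int16 tensor of shape (q_out_features, q_in_features // 8);
    #           each packed index expands to 8 fp16 weights on decompression.
    #   Wscale: scalar weight scale, passed to matmul_hadUt_cuda in
    #           apply_weights.
    #   SU, SV: per-input / per-output multipliers applied before and after
    #           the quantized matmul (presumably the QuIP# sign vectors,
    #           assuming the usual QuIP# checkpoint format).
    #   had_left, had_right: optional Hadamard factors returned by get_hadK.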

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # First run: replace the Wscale parameter with a plain Python float and
        # drop SU / SV when every entry is positive (the multiply is a no-op).
        if isinstance(layer.Wscale, torch.Tensor):
            # nn.Module refuses to overwrite a Parameter with a float, so drop
            # the parameter first and re-assign the value as a plain attribute.
            wscale = layer.Wscale.item()
            del layer.Wscale
            layer.Wscale = wscale
            if hasattr(layer, "SU") and torch.all(layer.SU > 0):
                del layer.SU
            if hasattr(layer, "SV") and torch.all(layer.SV > 0):
                del layer.SV

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_dim = layer.Qidxs.shape[0]

        # Input side: multiply by SU (if present), then the scaled Hadamard
        # transform.
        if hasattr(layer, "SU"):
            reshaped_x = reshaped_x * layer.SU
        reshaped_x = matmul_hadUt_cuda(reshaped_x,
                                       getattr(layer, "had_left", None),
                                       layer.K_left, layer.q_in_features,
                                       layer.Wscale)

        m, n = layer.Qidxs.shape
        if reshaped_x.size(0) < 32:
            # Small batches: fused GEMV on the packed indices.
            out = ops.quip_gemv(reshaped_x, layer.Qidxs, self.grid_packed_abs)
        else:
            # Larger batches: decompress to fp16 and use a regular matmul.
            W_decompressed = torch.empty(m,
                                         n * 8,
                                         dtype=torch.float16,
                                         device=x.device)
            ops.quip_decompress(layer.Qidxs, self.grid_packed_abs,
                                W_decompressed)
            out = reshaped_x @ W_decompressed.T

        # Output side: Hadamard transform, trim to out_dim, multiply by SV
        # (if present).
        out = matmul_hadU_cuda(out, getattr(layer, "had_right", None),
                               layer.K_right,
                               layer.q_out_features)[..., :out_dim]
        if hasattr(layer, "SV"):
            out = out * layer.SV
        out = out.view(*x.shape[:-1], out.shape[-1])
        out = out.add_(bias) if bias is not None else out
        return out
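
# Minimal usage sketch (illustrative only). It assumes a CUDA device, the
# compiled aphrodite._quant_C kernels, and a layer whose parameters get filled
# from a QuIP#-style checkpoint; the 4096 sizes and `hidden` are made-up values.
#
#     quant_config = QuipConfig(codebook="E8P12", use_rand=True)
#     method = quant_config.get_linear_method()
#     layer = torch.nn.Module()
#     method.create_weights(layer,
#                           input_size_per_partition=4096,
#                           output_partition_sizes=[4096],
#                           input_size=4096,
#                           output_size=4096,
#                           params_dtype=torch.half)
#     # ... load Qidxs, Wscale, SU and SV from the checkpoint ...
#     out = method.apply_weights(layer, hidden)  # hidden: (*, 4096) fp16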