# quip.py

from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.quip_utils import (get_hadK, get_packed_abs_grid,
                                               matmul_hadU_cuda,
                                               matmul_hadUt_cuda)

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True
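
# QuipConfig carries the QuIP# settings shipped with a quantized checkpoint;
# constructing it without the compiled kernels above raises ImportError.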

class QuipConfig(QuantizationConfig):
    """Config class for Quip.

    Reference: https://cornell-relaxml.github.io/quip-sharp/
    """

    def __init__(self, codebook: str, use_rand: bool) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.codebook = codebook
        self.use_rand = use_rand

        if self.codebook != "E8P12":
            raise ValueError("Currently, only E8P12 is supported for "
                             f"Quip, but got {self.codebook}.")

    def __repr__(self) -> str:
        return (f"QuipConfig(codebook={self.codebook}, "
                f"use_rand={self.use_rand})")

    @classmethod
    def get_name(cls) -> str:
        return "quip"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantization_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuipConfig":
        codebook = cls.get_from_keys(config, ["codebook"])
        use_rand = cls.get_from_keys(config, ["use_rand"])
        return cls(codebook, use_rand)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuipLinearMethod"]:
        if isinstance(layer, LinearBase):
            return QuipLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
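
# For illustration (values assumed, not taken from this file): a checkpoint
# whose quantization_config.json contains {"codebook": "E8P12", "use_rand": true}
# would be parsed by QuipConfig.from_config into
# QuipConfig(codebook="E8P12", use_rand=True).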


class QuipLinearMethod(LinearMethodBase):
    """Linear method for Quip.

    Args:
        quant_config: The Quip quantization config.
    """

    def __init__(self, quant_config: QuipConfig):
        self.quant_config = quant_config
        self.grid_packed_abs = get_packed_abs_grid().to(device="cuda")
        self.pack = 8
        self.idx_dtype = torch.int16
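
    # Each int16 entry of Qidxs selects one 8-dimensional E8P codeword from
    # grid_packed_abs, so a row of in_features weights is stored as
    # in_features // self.pack packed indices.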

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        if (input_size != input_size_per_partition
                or output_size != output_size_per_partition):
            raise ValueError(
                "Quip does not support tensor parallelism yet.")

        had_left, K_left, q_in_features = get_hadK(input_size,
                                                   self.quant_config.use_rand)
        had_right, K_right, q_out_features = get_hadK(
            output_size, self.quant_config.use_rand)

        # Hadamard bookkeeping consumed by apply(); stored as plain
        # attributes rather than parameters.
        layer.K_left = K_left
        layer.K_right = K_right
        layer.q_in_features = q_in_features
        layer.q_out_features = q_out_features

        if had_left is not None:
            layer.register_parameter(
                "had_left",
                Parameter(
                    had_left.to(dtype=params_dtype, device="cuda"),
                    requires_grad=False,
                ))
            set_weight_attrs(layer.had_left, extra_weight_attrs)
        if had_right is not None:
            layer.register_parameter(
                "had_right",
                Parameter(
                    had_right.to(dtype=params_dtype, device="cuda"),
                    requires_grad=False,
                ))
            set_weight_attrs(layer.had_right, extra_weight_attrs)
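
        # Quantized weight storage: Qidxs holds the packed E8P codebook
        # indices, Wscale is a single global weight scale, and SU / SV are
        # per-channel vectors applied to the input and output around the
        # Hadamard transforms (QuIP#'s incoherence processing).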
        layer.register_parameter(
            "Qidxs",
            Parameter(
                torch.empty(q_out_features,
                            q_in_features // self.pack,
                            device="cuda",
                            dtype=self.idx_dtype),
                requires_grad=False,
            ))
        set_weight_attrs(layer.Qidxs, extra_weight_attrs)
        layer.register_parameter(
            "Wscale",
            Parameter(
                torch.ones((), dtype=torch.float, device="cuda"),
                requires_grad=False,
            ))
        set_weight_attrs(layer.Wscale, extra_weight_attrs)
        layer.register_parameter(
            "SU",
            Parameter(
                torch.ones(
                    input_size,
                    device="cuda",
                    dtype=params_dtype,
                ),
                requires_grad=False,
            ))
        set_weight_attrs(layer.SU, extra_weight_attrs)
        layer.register_parameter(
            "SV",
            Parameter(
                torch.ones(
                    output_size,
                    device="cuda",
                    dtype=params_dtype,
                ),
                requires_grad=False,
            ))
        set_weight_attrs(layer.SV, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
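        # Roughly: the input is scaled by SU, pushed through the (optionally
        # randomized) input Hadamard transform together with the global
        # Wscale, multiplied against the E8P-decoded weights, pushed through
        # the output Hadamard transform, truncated to out_dim, and scaled
        # by SV.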
        # First run: cache Wscale as a Python float and drop the SU / SV
        # vectors when they are trivially all ones.
        if isinstance(layer.Wscale, torch.Tensor):
            wscale = layer.Wscale.item()
            del layer.Wscale
            layer.Wscale = wscale
            if hasattr(layer, "SU") and torch.all(layer.SU > 0):
                del layer.SU
            if hasattr(layer, "SV") and torch.all(layer.SV > 0):
                del layer.SV

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_dim = layer.Qidxs.shape[0]

        if hasattr(layer, "SU"):
            reshaped_x = reshaped_x * layer.SU
        reshaped_x = matmul_hadUt_cuda(reshaped_x,
                                       getattr(layer, "had_left", None),
                                       layer.K_left, layer.q_in_features,
                                       layer.Wscale)

        m, n = layer.Qidxs.shape
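        # Small batches use the fused gemv kernel that reads the packed
        # indices directly; larger batches decompress the full weight matrix
        # to fp16 and fall back to a dense matmul.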
        if reshaped_x.size(0) < 32:
            out = ops.quip_gemv(reshaped_x, layer.Qidxs, self.grid_packed_abs)
        else:
            W_decompressed = torch.empty(m,
                                         n * 8,
                                         dtype=torch.float16,
                                         device=x.device)
            ops.quip_decompress(layer.Qidxs, self.grid_packed_abs,
                                W_decompressed)
            out = reshaped_x @ W_decompressed.T

        out = matmul_hadU_cuda(out,
                               getattr(layer, "had_right", None),
                               layer.K_right,
                               layer.q_out_features)[..., :out_dim]
        if hasattr(layer, "SV"):
            out = out * layer.SV
        out = out.view(*x.shape[:-1], out.shape[-1])
        out = out.add_(bias) if bias is not None else out
        return out