exl2.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from aphrodite import _custom_ops as ops
  4. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  5. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  6. from aphrodite.modeling.utils import set_weight_attrs
  7. from aphrodite.quantization.base_config import QuantizationConfig
  8. def make_group_map(q_groups, num_qrows):
  9. gr = q_groups.tolist()
  10. group_map = []
  11. num_groups = len(gr) // 2
  12. for i in range(num_groups):
  13. bits = gr[i * 2]
  14. if i < num_groups - 1:
  15. qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
  16. else:
  17. qrows = num_qrows - gr[i * 2 + 1]
  18. rows = qrows * 32 // bits
  19. for j in range(rows):
  20. group_map += [i]
  21. group_map += [rows - j]
  22. return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
  23. class Exl2Config(QuantizationConfig):
  24. """Config class for Exl2."""
  25. def __init__(self, lm_head_quantized: bool):
  26. self.lm_head_quantized = lm_head_quantized
  27. def __repr__(self) -> str:
  28. return f"Exl2Config(lm_head_quantized={self.lm_head_quantized})"
  29. @classmethod
  30. def get_name(cls) -> str:
  31. return "exl2"
  32. @classmethod
  33. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  34. return [torch.half]
  35. @classmethod
  36. # Need to figure it out
  37. def get_min_capability(cls) -> int:
  38. return 60
  39. @classmethod
  40. def get_config_filenames(cls) -> List[str]:
  41. return []
  42. @classmethod
  43. def from_config(cls, config: Dict[str, Any]) -> "Exl2Config":
  44. lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
  45. return cls(lm_head_quantized)
  46. def get_quant_method(self, layer: torch.nn.Module,
  47. prefix: str) -> Optional["Exl2LinearMethod"]:
  48. if (isinstance(layer, LinearBase) or
  49. (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
  50. return Exl2LinearMethod(self)
  51. return None
  52. def get_scaled_act_names(self) -> List[str]:
  53. return []
  54. def merge_weight(self) -> bool:
  55. return False
  56. def quant_vocab(self) -> List[bool]:
  57. return [False, True]
  58. def support_fused_moe(self) -> bool:
  59. return False
  60. def rope_style(self) -> Optional[bool]:
  61. return None
  62. class Exl2LinearMethod(LinearMethodBase):
  63. """Linear method for Exl2.
  64. Args:
  65. quant_config: The Exl2 quantization config.
  66. """
  67. def __init__(self, quant_config: Exl2Config):
  68. self.quant_config = quant_config
  69. def create_weights(self, layer: torch.nn.Module,
  70. input_size_per_partition: int,
  71. output_partition_sizes: List[int], input_size: int,
  72. output_size: int, params_dtype: torch.dtype,
  73. **extra_weight_attr):
  74. # The shape of weight is unknown until load state dict
  75. # q_groups, q_invperm, q_scale, q_scale_max, q_weight, q_groups
  76. layer.exllama_state = 0
  77. # qweight = torch.nn.parameter.UninitializedParameter(
  78. # requires_grad=False)
  79. qweight = Parameter(
  80. torch.empty(
  81. input_size_per_partition // self.quant_config.pack_factor,
  82. output_size_per_partition,
  83. dtype=torch.int32,
  84. ),
  85. requires_grad=False,
  86. )
  87. set_weight_attrs(qweight, {"output_dim": 1, "ignore_warning": True})
  88. layer.register_parameter("q_weight", qweight)
  89. # qscale = torch.nn.parameter.UninitializedParameter(requires_grad=False)
  90. qscale = Parameter(
  91. torch.empty(
  92. input_size_per_partition // self.quant_config.pack_factor,
  93. output_size_per_partition,
  94. dtype=torch.int32,
  95. ),
  96. requires_grad=False,
  97. )
  98. set_weight_attrs(
  99. qscale, {
  100. "output_dim": 1,
  101. "packed_dim": 1,
  102. "pack_factor": 8,
  103. "ignore_warning": True
  104. })
  105. layer.register_parameter("q_scale", qscale)
  106. for name in ["q_groups", "q_invperm", "q_scale_max"]:
  107. fake_weight = torch.nn.parameter.UninitializedParameter(
  108. requires_grad=False)
  109. set_weight_attrs(fake_weight, {"ignore_warning": True})
  110. layer.register_parameter(name, fake_weight)
  111. def apply(self,
  112. layer: torch.nn.Module,
  113. x: torch.Tensor,
  114. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  115. out_shape = x.shape[:-1] + (layer.q_weight.shape[-1], )
  116. reshaped_x = x.reshape(-1, x.shape[-1])
  117. if layer.exllama_state == 0:
  118. layer.q_scale_max /= 256
  119. layer.q_invperm = layer.q_invperm.short()
  120. if not hasattr(layer, 'q_perm'):
  121. layer.q_perm = torch.argsort(layer.q_invperm).to(torch.short)
  122. if not hasattr(layer, 'q_group_map'):
  123. layer.q_group_map = make_group_map(layer.q_groups,
  124. layer.q_weight.shape[0])
  125. layer.q_matrix = ops.make_q_matrix(
  126. layer.q_weight,
  127. layer.q_perm,
  128. layer.q_invperm,
  129. layer.q_scale,
  130. layer.q_scale_max,
  131. layer.q_groups,
  132. layer.q_group_map,
  133. )
  134. layer.exllama_state = 1
  135. output = ops.exl2_gemm(reshaped_x, layer.q_matrix)
  136. if bias is not None:
  137. output.add_(bias)
  138. return output.reshape(out_shape)
  139. def apply_moe_weights(self, w1: Dict[str,
  140. torch.Tensor], w2: Dict[str,
  141. torch.Tensor],
  142. x: torch.Tensor, gating_output: torch.Tensor,
  143. topk: int, renormalize: bool) -> torch.Tensor:
  144. raise NotImplementedError