exl2.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from aphrodite import _custom_ops as ops
  4. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  5. from aphrodite.modeling.utils import set_weight_attrs
  6. from aphrodite.quantization.base_config import QuantizationConfig
  7. def make_group_map(q_groups, num_qrows):
  8. gr = q_groups.tolist()
  9. group_map = []
  10. num_groups = len(gr) // 2
  11. for i in range(num_groups):
  12. bits = gr[i * 2]
  13. if i < num_groups - 1:
  14. qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
  15. else:
  16. qrows = num_qrows - gr[i * 2 + 1]
  17. rows = qrows * 32 // bits
  18. for j in range(rows):
  19. group_map += [i]
  20. group_map += [rows - j]
  21. return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
  22. class Exl2Config(QuantizationConfig):
  23. """Config class for Exl2."""
  24. def __repr__(self) -> str:
  25. return "Exl2Config()"
  26. @classmethod
  27. def get_name(cls) -> str:
  28. return "exl2"
  29. @classmethod
  30. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  31. return [torch.half]
  32. @classmethod
  33. # Need to figure it out
  34. def get_min_capability(cls) -> int:
  35. return 60
  36. @classmethod
  37. def get_config_filenames(cls) -> List[str]:
  38. return []
  39. @classmethod
  40. def from_config(cls, config: Dict[str, Any]) -> "Exl2Config":
  41. return cls()
  42. def get_quant_method(self, layer: torch.nn.Module,
  43. prefix: str) -> Optional["Exl2LinearMethod"]:
  44. if isinstance(layer, LinearBase):
  45. return Exl2LinearMethod(self)
  46. return None
  47. def get_scaled_act_names(self) -> List[str]:
  48. return []
  49. def merge_weight(self) -> bool:
  50. return False
  51. def quant_vocab(self) -> List[bool]:
  52. return [False, True]
  53. def support_fused_moe(self) -> bool:
  54. return False
  55. def rope_style(self) -> Optional[bool]:
  56. return None
  57. class Exl2LinearMethod(LinearMethodBase):
  58. """Linear method for Exl2.
  59. Args:
  60. quant_config: The Exl2 quantization config.
  61. """
  62. def __init__(self, quant_config: Exl2Config):
  63. self.quant_config = quant_config
  64. def create_weights(self, layer: torch.nn.Module,
  65. input_size_per_partition: int,
  66. output_partition_sizes: List[int], input_size: int,
  67. output_size: int, params_dtype: torch.dtype,
  68. **extra_weight_attr):
  69. # The shape of weight is unknown until load state dict
  70. # q_groups, q_invperm, q_scale, q_scale_max, q_weight, q_groups
  71. layer.exllama_state = 0
  72. qweight = torch.nn.parameter.UninitializedParameter(
  73. requires_grad=False)
  74. set_weight_attrs(qweight, {"output_dim": 1, "ignore_warning": True})
  75. layer.register_parameter("q_weight", qweight)
  76. qscale = torch.nn.parameter.UninitializedParameter(requires_grad=False)
  77. set_weight_attrs(
  78. qscale, {
  79. "output_dim": 1,
  80. "packed_dim": 1,
  81. "pack_factor": 8,
  82. "ignore_warning": True
  83. })
  84. layer.register_parameter("q_scale", qscale)
  85. for name in ["q_groups", "q_invperm", "q_scale_max"]:
  86. fake_weight = torch.nn.parameter.UninitializedParameter(
  87. requires_grad=False)
  88. set_weight_attrs(fake_weight, {"ignore_warning": True})
  89. layer.register_parameter(name, fake_weight)
  90. def apply(self,
  91. layer: torch.nn.Module,
  92. x: torch.Tensor,
  93. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  94. out_shape = x.shape[:-1] + (layer.q_weight.shape[-1], )
  95. reshaped_x = x.reshape(-1, x.shape[-1])
  96. if layer.exllama_state == 0:
  97. layer.q_scale_max /= 256
  98. layer.q_invperm = layer.q_invperm.short()
  99. if not hasattr(layer, 'q_perm'):
  100. layer.q_perm = torch.argsort(layer.q_invperm).to(torch.short)
  101. if not hasattr(layer, 'q_group_map'):
  102. layer.q_group_map = make_group_map(layer.q_groups,
  103. layer.q_weight.shape[0])
  104. layer.q_matrix = ops.exl2_make_q_matrix(
  105. layer.q_weight,
  106. layer.q_perm,
  107. layer.q_invperm,
  108. layer.q_scale,
  109. layer.q_scale_max,
  110. layer.q_groups,
  111. layer.q_group_map,
  112. )
  113. layer.exllama_state = 1
  114. output = ops.exl2_gemm(reshaped_x, layer.q_matrix)
  115. if bias is not None:
  116. output.add_(bias)
  117. return output.reshape(out_shape)
  118. def apply_moe_weights(self, w1: Dict[str,
  119. torch.Tensor], w2: Dict[str,
  120. torch.Tensor],
  121. x: torch.Tensor, gating_output: torch.Tensor,
  122. topk: int, renormalize: bool) -> torch.Tensor:
  123. raise NotImplementedError