exl2.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from typing import Any, Dict, List, Optional
  2. from contextlib import suppress
  3. import torch
  4. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  5. set_weight_attrs)
  6. from aphrodite.quantization.base_config import (QuantizationConfig)
  7. HAS_QUANTS = False
  8. with suppress(ImportError):
  9. from aphrodite._quant_C import quant_ops as ops
  10. HAS_QUANTS = True
  11. def make_group_map(q_groups, num_qrows):
  12. gr = q_groups.tolist()
  13. group_map = []
  14. num_groups = len(gr) // 2
  15. for i in range(num_groups):
  16. bits = gr[i * 2]
  17. if i < num_groups - 1:
  18. qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
  19. else:
  20. qrows = num_qrows - gr[i * 2 + 1]
  21. rows = qrows * 32 // bits
  22. for j in range(rows):
  23. group_map += [i]
  24. group_map += [rows - j]
  25. return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
  26. class Exl2Config(QuantizationConfig):
  27. """Config class for Exl2."""
  28. def __repr__(self) -> str:
  29. return "Exl2Config()"
  30. @classmethod
  31. def get_name(cls) -> str:
  32. return "exl2"
  33. @classmethod
  34. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  35. return [torch.half]
  36. @classmethod
  37. # Need to figure it out
  38. def get_min_capability(cls) -> int:
  39. return 60
  40. @classmethod
  41. def get_config_filenames(cls) -> List[str]:
  42. return []
  43. @classmethod
  44. def from_config(cls, config: Dict[str, Any]) -> "Exl2Config":
  45. return cls()
  46. def get_linear_method(self) -> "Exl2LinearMethod":
  47. return Exl2LinearMethod(self)
  48. def get_scaled_act_names(self) -> List[str]:
  49. return []
  50. def merge_weight(self) -> bool:
  51. return False
  52. def quant_vocab(self) -> List[bool]:
  53. return [False, True]
  54. def support_fused_moe(self) -> bool:
  55. return False
  56. def rope_style(self) -> Optional[bool]:
  57. return None
  58. class Exl2LinearMethod(LinearMethodBase):
  59. """Linear method for Exl2.
  60. Args:
  61. quant_config: The Exl2 quantization config.
  62. """
  63. def __init__(self, quant_config: Exl2Config):
  64. if not HAS_QUANTS:
  65. raise ImportError("Could not find the quantization kernels.")
  66. self.quant_config = quant_config
  67. def create_weights(self, layer: torch.nn.Module,
  68. input_size_per_partition: int,
  69. output_partition_sizes: List[int], input_size: int,
  70. output_size: int, params_dtype: torch.dtype,
  71. **extra_weight_attr):
  72. # The shape of weight is unknown until load state dict
  73. # q_groups, q_invperm, q_scale, q_scale_max, q_weight, q_groups
  74. layer.exllama_state = 0
  75. qweight = torch.nn.parameter.UninitializedParameter(
  76. requires_grad=False)
  77. set_weight_attrs(qweight, {"output_dim": 1, "ignore_warning": True})
  78. layer.register_parameter("q_weight", qweight)
  79. qscale = torch.nn.parameter.UninitializedParameter(requires_grad=False)
  80. set_weight_attrs(
  81. qscale, {
  82. "output_dim": 1,
  83. "packed_dim": 1,
  84. "pack_factor": 8,
  85. "ignore_warning": True
  86. })
  87. layer.register_parameter("q_scale", qscale)
  88. for name in ["q_groups", "q_invperm", "q_scale_max"]:
  89. fake_weight = torch.nn.parameter.UninitializedParameter(
  90. requires_grad=False)
  91. set_weight_attrs(fake_weight, {"ignore_warning": True})
  92. layer.register_parameter(name, fake_weight)
  93. def apply_weights(self,
  94. layer: torch.nn.Module,
  95. x: torch.Tensor,
  96. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  97. out_shape = x.shape[:-1] + (layer.q_weight.shape[-1], )
  98. reshaped_x = x.reshape(-1, x.shape[-1])
  99. if layer.exllama_state == 0:
  100. layer.q_scale_max /= 256
  101. layer.q_invperm = layer.q_invperm.short()
  102. if not hasattr(layer, 'q_perm'):
  103. layer.q_perm = torch.argsort(layer.q_invperm).to(torch.short)
  104. if not hasattr(layer, 'q_group_map'):
  105. layer.q_group_map = make_group_map(layer.q_groups,
  106. layer.q_weight.shape[0])
  107. layer.q_matrix = ops.exl2_make_q_matrix(
  108. layer.q_weight,
  109. layer.q_perm,
  110. layer.q_invperm,
  111. layer.q_scale,
  112. layer.q_scale_max,
  113. layer.q_groups,
  114. layer.q_group_map,
  115. )
  116. layer.exllama_state = 1
  117. output = ops.exl2_gemm(reshaped_x, layer.q_matrix)
  118. if bias is not None:
  119. output.add_(bias)
  120. return output.reshape(out_shape)
  121. def apply_moe_weights(self, w1: Dict[str,
  122. torch.Tensor], w2: Dict[str,
  123. torch.Tensor],
  124. x: torch.Tensor, gating_output: torch.Tensor,
  125. topk: int, renormalize: bool) -> torch.Tensor:
  126. raise NotImplementedError