# exl2.py
  1. """Modeling utilities for Exllamav2 Quantization"""
from typing import Any, Dict, List, Optional, Tuple

import torch

from aphrodite._C import ops
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.modeling.layers.quantization.base_config import (
    QuantizationConfig)
  9. def make_group_map(q_groups, num_qrows):
  10. gr = q_groups.tolist()
  11. group_map = []
  12. num_groups = len(gr) // 2
  13. for i in range(num_groups):
  14. bits = gr[i * 2]
  15. if i < num_groups - 1:
  16. qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
  17. else:
  18. qrows = num_qrows - gr[i * 2 + 1]
  19. rows = qrows * 32 // bits
  20. for j in range(rows):
  21. group_map += [i]
  22. group_map += [rows - j]
  23. return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
  24. class Exl2Config(QuantizationConfig):
  25. """Config class for Exl2.
  26. Reference: https://github.com/turboderp/exllamav2
  27. """
  28. def __repr__(self) -> str:
  29. return "Exl2Config()"
  30. @classmethod
  31. def get_name(cls) -> str:
  32. return "exl2"
  33. @classmethod
  34. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  35. return [torch.half]
  36. @classmethod
  37. # Need to figure it out
  38. def get_min_capability(cls) -> int:
  39. return 60
  40. @classmethod
  41. def get_config_filenames(cls) -> List[str]:
  42. return []
  43. @classmethod
  44. def from_config(cls, config: Dict[str, Any]) -> "Exl2Config":
  45. return cls()
  46. def get_linear_method(self) -> "Exl2LinearMethod":
  47. return Exl2LinearMethod(self)
  48. def get_scaled_act_names(self) -> List[str]:
  49. return []
  50. def merge_weight(self) -> bool:
  51. return False
  52. def quant_vocab(self) -> Optional[bool]:
  53. return (False, True)
  54. def rope_style(self) -> Optional[bool]:
  55. return None
  56. class Exl2LinearMethod(LinearMethodBase):
  57. """Linear method for Exl2.
  58. Args:
  59. quant_config: The Exl2 quantization config.
  60. """
  61. def __init__(self, quant_config: Exl2Config):
  62. self.quant_config = quant_config
  63. def create_weights(self, input_size_per_partition: int,
  64. output_partition_sizes: List[int], input_size: int,
  65. output_size: int,
  66. params_dtype: torch.dtype) -> Dict[str, Any]:
  67. output_size_per_partition = sum(output_partition_sizes)
  68. if (input_size != input_size_per_partition
  69. or output_size != output_size_per_partition):
  70. raise ValueError(
  71. "Currently exl2 doesn't support tensor parallel yet")
  72. # The shape of weight is unknown until load state dict
  73. # q_groups, q_invperm, q_scale, q_scale_max, q_weight, q_groups
  74. state_dict = {"exllama_state": 0}
  75. qweight = torch.nn.parameter.UninitializedParameter(
  76. requires_grad=False)
  77. set_weight_attrs(qweight, {
  78. "input_dim": 0,
  79. "output_dim": 1,
  80. })
  81. state_dict["q_weight"] = qweight
  82. for name in ["q_groups", "q_invperm", "q_scale", "q_scale_max"]:
  83. fake_weight = torch.nn.parameter.UninitializedParameter(
  84. requires_grad=False)
  85. set_weight_attrs(fake_weight, {"ignore_warning": True})
  86. state_dict[name] = fake_weight
  87. return state_dict
  88. def apply_weights(self,
  89. weights: Dict[str, Any],
  90. x: torch.Tensor,
  91. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  92. out_shape = x.shape[:-1] + (weights["q_weight"].shape[-1], )
  93. reshaped_x = x.reshape(-1, x.shape[-1])
  94. if weights["exllama_state"] == 0:
  95. weights["q_scale_max"] /= 256
  96. weights["q_invperm"] = weights["q_invperm"].short()
  97. weights["q_perm"] = torch.argsort(weights["q_invperm"]).to(
  98. torch.short)
  99. if "q_group_map" not in weights:
  100. weights["q_group_map"] = make_group_map(
  101. weights["q_groups"], weights["q_weight"].shape[0])
  102. weights["q_matrix"] = ops.exl2_make_q_matrix(
  103. weights["q_weight"],
  104. weights["q_perm"],
  105. weights["q_invperm"],
  106. weights["q_scale"],
  107. weights["q_scale_max"],
  108. weights["q_groups"],
  109. weights["q_group_map"],
  110. )
  111. weights["exllama_state"] = 1
  112. output = ops.exl2_gemm(reshaped_x, weights["q_matrix"])
  113. if bias is not None:
  114. output = output + bias
  115. return output.reshape(out_shape)