# experts_int8.py
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from aphrodite.distributed import get_tensor_model_parallel_rank, get_tp_group
  4. from aphrodite.modeling.layers.fused_moe import FusedMoE, FusedMoEMethodBase
  5. from aphrodite.modeling.layers.linear import (LinearBase,
  6. UnquantizedLinearMethod)
  7. from aphrodite.modeling.utils import set_weight_attrs
  8. from aphrodite.quantization.base_config import (QuantizationConfig,
  9. QuantizeMethodBase)
  10. class ExpertsInt8Config(QuantizationConfig):
  11. """Config class for Int8 experts quantization."""
  12. def __init__(self) -> None:
  13. pass
  14. @classmethod
  15. def get_name(cls) -> str:
  16. return "experts_int8"
  17. @classmethod
  18. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  19. return [torch.bfloat16, torch.half]
  20. @classmethod
  21. def get_min_capability(cls) -> int:
  22. return 80
  23. @classmethod
  24. def get_config_filenames(cls) -> List[str]:
  25. return []
  26. @classmethod
  27. def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config":
  28. return cls()
  29. def get_quant_method(self, layer: torch.nn.Module,
  30. prefix: str) -> Optional["QuantizeMethodBase"]:
  31. if isinstance(layer, LinearBase):
  32. return UnquantizedLinearMethod()
  33. elif isinstance(layer, FusedMoE):
  34. return ExpertsInt8MoEMethod(self)
  35. return None
  36. def get_scaled_act_names(self) -> List[str]:
  37. return []
  38. class ExpertsInt8MoEMethod(FusedMoEMethodBase):
  39. def __init__(self, quant_config: ExpertsInt8Config):
  40. self.quant_config = quant_config
  41. def create_weights(self, layer: torch.nn.Module, num_experts: int,
  42. hidden_size: int, intermediate_size: int,
  43. params_dtype: torch.dtype, **extra_weight_attrs):
  44. int8_dtype = torch.int8
  45. assert 'weight_loader' in extra_weight_attrs
  46. weight_loader = extra_weight_attrs['weight_loader']
  47. wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
  48. layer, weight_loader)
  49. extra_weight_attrs['weight_loader'] = wrapped_weight_loader
  50. # Fused gate_up_proj (column parallel)
  51. w13_weight = torch.nn.Parameter(torch.empty(num_experts,
  52. 2 * intermediate_size,
  53. hidden_size,
  54. dtype=int8_dtype),
  55. requires_grad=False)
  56. layer.register_parameter("w13_weight", w13_weight)
  57. set_weight_attrs(w13_weight, extra_weight_attrs)
  58. # down_proj (row parallel)
  59. w2_weight = torch.nn.Parameter(torch.empty(num_experts,
  60. hidden_size,
  61. intermediate_size,
  62. dtype=int8_dtype),
  63. requires_grad=False)
  64. layer.register_parameter("w2_weight", w2_weight)
  65. set_weight_attrs(w2_weight, extra_weight_attrs)
  66. w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
  67. 2 * intermediate_size,
  68. dtype=torch.float32),
  69. requires_grad=False)
  70. layer.register_parameter("w13_scale", w13_scale)
  71. w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
  72. hidden_size,
  73. dtype=torch.float32),
  74. requires_grad=False)
  75. layer.register_parameter("w2_scale", w2_scale)
  76. def apply(self,
  77. layer: torch.nn.Module,
  78. x: torch.Tensor,
  79. router_logits: torch.Tensor,
  80. top_k: int,
  81. renormalize: bool = True,
  82. use_grouped_topk: bool = False,
  83. num_expert_group: Optional[int] = None,
  84. topk_group: Optional[int] = None) -> torch.Tensor:
  85. from aphrodite.modeling.layers.fused_moe import fused_experts
  86. topk_weights, topk_ids = FusedMoE.select_experts(
  87. hidden_states=x,
  88. router_logits=router_logits,
  89. use_grouped_topk=use_grouped_topk,
  90. top_k=top_k,
  91. renormalize=renormalize,
  92. topk_group=topk_group,
  93. num_expert_group=num_expert_group)
  94. return fused_experts(x,
  95. layer.w13_weight,
  96. layer.w2_weight,
  97. topk_weights=topk_weights,
  98. topk_ids=topk_ids,
  99. inplace=True,
  100. use_int8_w8a16=True,
  101. w1_scale=layer.w13_scale,
  102. w2_scale=layer.w2_scale)
  103. @staticmethod
  104. def quantizing_weight_loader(layer, weight_loader):
  105. def quantize_and_call_weight_loader(param: torch.nn.Parameter,
  106. loaded_weight: torch.Tensor,
  107. weight_name: str, shard_id: int,
  108. expert_id: int):
  109. tp_rank = get_tensor_model_parallel_rank()
  110. shard_size = layer.intermediate_size_per_partition
  111. shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
  112. device = get_tp_group().device
  113. loaded_weight = loaded_weight.to(device)
  114. # w1, gate_proj case: Load into first shard of w13.
  115. if shard_id == "w1":
  116. scales = quantize_in_place_and_get_scales(
  117. loaded_weight[shard, :])
  118. layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:,
  119. 0])
  120. # w3, up_proj case: Load into second shard of w13.
  121. elif shard_id == "w3":
  122. scales = quantize_in_place_and_get_scales(
  123. loaded_weight[shard, :])
  124. layer.w13_scale.data[expert_id, shard_size:2 *
  125. shard_size].copy_(scales[:, 0])
  126. # w2, down_proj case: Load into only shard of w2.
  127. elif shard_id == "w2":
  128. scales = quantize_in_place_and_get_scales(loaded_weight[:,
  129. shard])
  130. layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
  131. else:
  132. raise ValueError(
  133. f"Shard id must be in [0,1,2] but got {shard_id}")
  134. weight_loader(param, loaded_weight, weight_name, shard_id,
  135. expert_id)
  136. return quantize_and_call_weight_loader
  137. def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
  138. vmax = torch.iinfo(torch.int8).max
  139. scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax)
  140. weight.div_(scales)
  141. weight.round_()
  142. weight.clamp_(-vmax, vmax)
  143. return scales