1
0

layer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. from abc import abstractmethod
  2. from typing import List, Optional, Tuple
  3. import torch
  4. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  5. get_tensor_model_parallel_world_size,
  6. tensor_model_parallel_all_reduce)
  7. from aphrodite.modeling._custom_op import CustomOp
  8. from aphrodite.modeling.utils import set_weight_attrs
  9. from aphrodite.quantization.base_config import (QuantizationConfig,
  10. QuantizeMethodBase)
  11. class FusedMoEMethodBase(QuantizeMethodBase):
  12. @abstractmethod
  13. def create_weights(self, layer: torch.nn.Module, num_experts: int,
  14. hidden_size: int, intermediate_size: int,
  15. params_dtype: torch.dtype, **extra_weight_attrs):
  16. raise NotImplementedError
  17. @abstractmethod
  18. def apply(self,
  19. layer: torch.nn.Module,
  20. x: torch.Tensor,
  21. router_logits: torch.Tensor,
  22. top_k: int,
  23. renormalize: bool = True,
  24. use_grouped_topk: bool = False,
  25. num_expert_group: Optional[int] = None,
  26. topk_group: Optional[int] = None) -> torch.Tensor:
  27. raise NotImplementedError
  28. class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  29. """MoE method without quantization."""
  30. def create_weights(self, layer: torch.nn.Module, num_experts: int,
  31. hidden_size: int, intermediate_size: int,
  32. params_dtype: torch.dtype, **extra_weight_attrs):
  33. # Fused gate_up_proj (column parallel)
  34. w13_weight = torch.nn.Parameter(torch.empty(num_experts,
  35. 2 * intermediate_size,
  36. hidden_size,
  37. dtype=params_dtype),
  38. requires_grad=False)
  39. layer.register_parameter("w13_weight", w13_weight)
  40. set_weight_attrs(w13_weight, extra_weight_attrs)
  41. # down_proj (row parallel)
  42. w2_weight = torch.nn.Parameter(torch.empty(num_experts,
  43. hidden_size,
  44. intermediate_size,
  45. dtype=params_dtype),
  46. requires_grad=False)
  47. layer.register_parameter("w2_weight", w2_weight)
  48. set_weight_attrs(w2_weight, extra_weight_attrs)
  49. def apply(
  50. self,
  51. layer: torch.nn.Module,
  52. x: torch.Tensor,
  53. router_logits: torch.Tensor,
  54. top_k: int,
  55. renormalize: bool = True,
  56. use_grouped_topk: bool = False,
  57. num_expert_group: Optional[int] = None,
  58. topk_group: Optional[int] = None,
  59. ) -> torch.Tensor:
  60. return self.forward(x, layer.w13_weight, layer.w2_weight,
  61. router_logits, top_k, renormalize,
  62. use_grouped_topk, num_expert_group, topk_group)
  63. def forward_cuda(
  64. self,
  65. x: torch.Tensor,
  66. w1: torch.Tensor,
  67. w2: torch.Tensor,
  68. router_logits: torch.Tensor,
  69. top_k: int,
  70. renormalize: bool,
  71. use_grouped_topk: bool,
  72. num_expert_group: Optional[int],
  73. topk_group: Optional[int],
  74. ) -> torch.Tensor:
  75. from aphrodite.modeling.layers.fused_moe.fused_moe import fused_moe
  76. return fused_moe(x,
  77. w1,
  78. w2,
  79. router_logits,
  80. top_k,
  81. renormalize=renormalize,
  82. inplace=True,
  83. use_grouped_topk=use_grouped_topk,
  84. num_expert_group=num_expert_group,
  85. topk_group=topk_group)
  86. def forward_cpu(self, *args, **kwargs):
  87. raise NotImplementedError(
  88. "The CPU backend currently does not support MoE.")
  89. def forward_tpu(
  90. self,
  91. x: torch.Tensor,
  92. w1: torch.Tensor,
  93. w2: torch.Tensor,
  94. router_logits: torch.Tensor,
  95. top_k: int,
  96. renormalize: bool,
  97. use_grouped_topk: bool,
  98. num_expert_group: Optional[int],
  99. topk_group: Optional[int],
  100. ) -> torch.Tensor:
  101. from aphrodite.modeling.layers.fused_moe.moe_pallas import fused_moe
  102. assert not use_grouped_topk
  103. assert num_expert_group is None
  104. assert topk_group is None
  105. return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
  106. class FusedMoE(torch.nn.Module):
  107. """FusedMoE layer for MoE models.
  108. This layer contains both MergedColumnParallel weights (gate_up_proj /
  109. w13) and RowParallelLinear weights (down_proj/ w2).
  110. Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
  111. copy that naming convention here and handle any remapping in the
  112. load_weights function in each model implementation.
  113. Args:
  114. num_experts: Number of experts in the model
  115. top_k: Number of experts selected for each token
  116. hidden_size: Input hidden state size of the transformer
  117. intermediate_size: Intermediate size of the experts
  118. params_dtype: Data type for the parameters.
  119. reduce_results: Whether to all all_reduce on the output of the layer
  120. renomalize: Whether to renormalize the logits in the fused_moe kernel
  121. quant_config: Quantization configure.
  122. """
  123. def __init__(
  124. self,
  125. num_experts: int,
  126. top_k: int,
  127. hidden_size: int,
  128. intermediate_size: int,
  129. params_dtype: Optional[torch.dtype] = None,
  130. reduce_results: bool = False,
  131. renormalize: bool = True,
  132. use_grouped_topk: bool = False,
  133. num_expert_group: Optional[int] = None,
  134. topk_group: Optional[int] = None,
  135. quant_config: Optional[QuantizationConfig] = None,
  136. tp_size: Optional[int] = None,
  137. prefix: str = "",
  138. ):
  139. super().__init__()
  140. if params_dtype is None:
  141. params_dtype = torch.get_default_dtype()
  142. self.tp_size = (tp_size if tp_size is not None else
  143. get_tensor_model_parallel_world_size())
  144. self.top_k = top_k
  145. self.num_experts = num_experts
  146. self.intermediate_size_per_partition = intermediate_size // self.tp_size
  147. self.reduce_results = reduce_results
  148. self.renormalize = renormalize
  149. self.use_grouped_topk = use_grouped_topk
  150. if self.use_grouped_topk:
  151. assert num_expert_group is not None and topk_group is not None
  152. self.num_expert_group = num_expert_group
  153. self.topk_group = topk_group
  154. if quant_config is None:
  155. self.quant_method: Optional[QuantizeMethodBase] = (
  156. UnquantizedFusedMoEMethod())
  157. else:
  158. self.quant_method = quant_config.get_quant_method(self, prefix)
  159. assert self.quant_method is not None
  160. self.quant_method.create_weights(
  161. layer=self,
  162. num_experts=num_experts,
  163. hidden_size=hidden_size,
  164. intermediate_size=self.intermediate_size_per_partition,
  165. params_dtype=params_dtype,
  166. weight_loader=self.weight_loader)
  167. def weight_loader(self, param: torch.nn.Parameter,
  168. loaded_weight: torch.Tensor, weight_name: str,
  169. shard_id: int, expert_id: int):
  170. param_data = param.data
  171. # Input scales can be loaded directly and should be equal.
  172. if "input_scale" in weight_name:
  173. if param_data[expert_id] != 1 and (param_data[expert_id] -
  174. loaded_weight).abs() > 1e-5:
  175. raise ValueError(
  176. "input_scales of w1 and w3 of a layer "
  177. f"must be equal. But got {param_data[expert_id]} "
  178. f"vs. {loaded_weight}")
  179. param_data[expert_id] = loaded_weight
  180. # Weight scales
  181. elif "weight_scale" in weight_name:
  182. # If we are in merged column case (gate_up_proj)
  183. # shard_id 0 == gate_proj / w1
  184. # shard_id 2 == up_proj / w3
  185. if shard_id == 0 or shard_id == 2:
  186. # We have to keep the weight scales of w1 and w3 because
  187. # we need to re-quantize w1/w3 weights after weight loading.
  188. idx = 0 if shard_id == 0 else 1
  189. param_data[expert_id][idx] = loaded_weight
  190. # If we are in the row parallel case (down_proj)
  191. # shard_id 1 == down_proj / w2
  192. else:
  193. param_data[expert_id] = loaded_weight
  194. # Weights
  195. else:
  196. tp_rank = get_tensor_model_parallel_rank()
  197. shard_size = self.intermediate_size_per_partition
  198. shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
  199. # w1, gate_proj case: Load into first shard of w13.
  200. if shard_id == 0:
  201. param_data[expert_id,
  202. 0:shard_size, :] = loaded_weight[shard, :]
  203. # w3, up_proj case: Load into second shard of w13.
  204. elif shard_id == 2:
  205. param_data[expert_id, shard_size:2 *
  206. shard_size, :] = loaded_weight[shard, :]
  207. # w2, down_proj case: Load into only shard of w2.
  208. elif shard_id == 1:
  209. param_data[expert_id, :, :] = loaded_weight[:, shard]
  210. else:
  211. raise ValueError(
  212. f"Shard id must be in [0,1,2] but got {shard_id}")
  213. def forward(self, hidden_states: torch.Tensor,
  214. router_logits: torch.Tensor):
  215. assert self.quant_method is not None
  216. # Matrix multiply.
  217. final_hidden_states = self.quant_method.apply(
  218. self,
  219. x=hidden_states,
  220. router_logits=router_logits,
  221. top_k=self.top_k,
  222. renormalize=self.renormalize,
  223. use_grouped_topk=self.use_grouped_topk,
  224. num_expert_group=self.num_expert_group,
  225. topk_group=self.topk_group)
  226. if self.reduce_results and self.tp_size > 1:
  227. final_hidden_states = tensor_model_parallel_all_reduce(
  228. final_hidden_states)
  229. return final_hidden_states
  230. @classmethod
  231. def make_expert_params_mapping(
  232. cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
  233. ckpt_up_proj_name: str,
  234. num_experts: int) -> List[Tuple[str, str, int, int]]:
  235. gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
  236. gate_down_up = [
  237. ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
  238. ]
  239. return [
  240. # These are the weight scales for the experts
  241. # (param_name, weight_name, expert_id, shard_id)
  242. ("experts.w13_scale"
  243. if weight_name in gate_up else "experts.w2_scale",
  244. f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
  245. shard_id) for expert_id in range(num_experts)
  246. for shard_id, weight_name in enumerate(gate_down_up)
  247. ] + [
  248. # These are the weights for the experts
  249. # (param_name, weight_name, expert_id, shard_id)
  250. ("experts.w13_weight"
  251. if weight_name in gate_up else "experts.w2_weight",
  252. f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
  253. for expert_id in range(num_experts)
  254. for shard_id, weight_name in enumerate(gate_down_up)
  255. ] + [
  256. # These are the weight scales for the experts
  257. # (param_name, weight_name, expert_id, shard_id)
  258. ("experts.a13_scale"
  259. if weight_name in gate_up else "experts.a2_scale",
  260. f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
  261. shard_id) for expert_id in range(num_experts)
  262. for shard_id, weight_name in enumerate(gate_down_up)
  263. ]