# layer.py

from abc import abstractmethod
from typing import List, Optional, Tuple

import torch

from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling._custom_op import CustomOp
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)


class FusedMoEMethodBase(QuantizeMethodBase):

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              router_logits: torch.Tensor, top_k: int, renormalize: bool,
              use_grouped_topk: bool) -> torch.Tensor:
        raise NotImplementedError


class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool,
              use_grouped_topk: bool,
              topk_group: Optional[int] = None,
              num_expert_group: Optional[int] = None) -> torch.Tensor:
        return self.forward(x=x,
                            layer=layer,
                            router_logits=router_logits,
                            top_k=top_k,
                            renormalize=renormalize,
                            use_grouped_topk=use_grouped_topk,
                            topk_group=topk_group,
                            num_expert_group=num_expert_group)
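
    # Note: apply() delegates to self.forward(), which the CustomOp base class
    # resolves to the platform-specific implementation below (forward_cuda /
    # forward_cpu / forward_tpu), so apply() itself stays backend-agnostic.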

    def forward_cuda(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     use_grouped_topk: bool,
                     top_k: int,
                     router_logits: torch.Tensor,
                     renormalize: bool,
                     topk_group: Optional[int] = None,
                     num_expert_group: Optional[int] = None) -> torch.Tensor:
        from aphrodite.modeling.layers.fused_moe.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group)

        return fused_experts(hidden_states=x,
                             w1=layer.w13_weight,
                             w2=layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True)

    def forward_cpu(self, *args, **kwargs):
        raise NotImplementedError(
            "The CPU backend currently does not support MoE.")

    def forward_tpu(self,
                    layer: torch.nn.Module,
                    x: torch.Tensor,
                    use_grouped_topk: bool,
                    top_k: int,
                    router_logits: torch.Tensor,
                    renormalize: bool,
                    topk_group: Optional[int] = None,
                    num_expert_group: Optional[int] = None) -> torch.Tensor:
        from aphrodite.modeling.layers.fused_moe.moe_pallas import fused_moe

        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        return fused_moe(hidden_states=x,
                         w1=layer.w13_weight,
                         w2=layer.w2_weight,
                         topk=top_k,
                         gating_output=router_logits,
                         renormalize=renormalize)


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model.
        top_k: Number of experts selected for each token.
        hidden_size: Input hidden state size of the transformer.
        intermediate_size: Intermediate size of the experts.
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all-reduce the output of the layer.
        renormalize: Whether to renormalize the logits in the fused_moe kernel.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (tp_size if tp_size is not None else
                        get_tensor_model_parallel_world_size())
        self.top_k = top_k
        self.num_experts = num_experts
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader)

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int) -> None:
        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                             f"got {shard_id}.")

        # Special case for fp8 scales.
        if getattr(param, "is_fp8_scale", False):
            self._load_fp8_scale(param.data, loaded_weight, weight_name,
                                 shard_id, expert_id)
            return

        expert_data = param.data[expert_id]
        tp_rank = get_tensor_model_parallel_rank()

        # If transposed, weight is saved as [input_dim, output_dim]
        # Otherwise, weight is saved as [output_dim, input_dim]
        # Default is not transposed/input dim is dim 1
        input_dim = getattr(param, "input_dim", 1)
        output_dim = getattr(param, "output_dim", 0)

        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        if shard_id == "w2":
            shard_dim = input_dim
            shard_size = expert_data.shape[shard_dim]
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        elif shard_id in ("w1", "w3"):
            shard_dim = output_dim
            shard_size = expert_data.shape[output_dim] // 2
        offset = shard_size * tp_rank
        loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)

        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
            expert_data.copy_(loaded_weight)
        # w3, up_proj: Load into second logical weight of w13.
        elif shard_id == "w3":
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
            expert_data.copy_(loaded_weight)
        # w2, down_proj: Load into only logical weight of w2.
        elif shard_id == "w2":
            expert_data.copy_(loaded_weight)
        else:
            raise ValueError(
                f"Expected shard_id w1,w2 or w3 but got {shard_id}")
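
    # Illustration of the resulting per-rank layout (writing I_p for
    # intermediate_size // tp_size, the per-partition intermediate size):
    #   w13_weight[e] : [2 * I_p, hidden_size] -- rows [0, I_p) hold w1
    #                   (gate_proj), rows [I_p, 2 * I_p) hold w3 (up_proj)
    #   w2_weight[e]  : [hidden_size, I_p]     -- down_proj, sharded on its
    #                   input dimension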

    @staticmethod
    def select_experts(hidden_states: torch.Tensor,
                       router_logits: torch.Tensor,
                       top_k: int,
                       use_grouped_topk: bool,
                       renormalize: bool,
                       topk_group: Optional[int] = None,
                       num_expert_group: Optional[int] = None):
        from aphrodite.modeling.layers.fused_moe.fused_moe import (
            fused_topk, grouped_topk)

        # DeepSeek-V2 uses grouped_top_k
        if use_grouped_topk:
            assert topk_group is not None
            assert num_expert_group is not None
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group)
        else:
            topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                                gating_output=router_logits,
                                                topk=top_k,
                                                renormalize=renormalize)

        return topk_weights, topk_ids
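
    # Both tensors returned by select_experts have shape [num_tokens, top_k]:
    # topk_weights holds the routing weights (renormalized to sum to 1 per
    # token when renormalize=True) and topk_ids the selected expert indices.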

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group)

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
            ckpt_up_proj_name: str,
            num_experts: int) -> List[Tuple[str, str, int, str]]:
        return [
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.w13_" if weight_name
             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
             f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
            for expert_id in range(num_experts) for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]
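
    # For illustration: with Mixtral-style checkpoint names
    # ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3"
    # and num_experts=2, the mapping contains entries such as
    #   ("experts.w13_", "experts.0.w1.", 0, "w1")
    #   ("experts.w2_",  "experts.0.w2.", 0, "w2")
    #   ("experts.w13_", "experts.0.w3.", 0, "w3")
    # plus the same three entries for expert_id=1.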

    def _load_fp8_scale(self, param: torch.nn.Parameter,
                        loaded_weight: torch.Tensor, weight_name: str,
                        shard_id: str, expert_id: int) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if param_data[expert_id] != 1 and (param_data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}")
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            if shard_id in ("w1", "w3"):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == "w1" else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
                param_data[expert_id] = loaded_weight
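

# Example (illustrative sketch): constructing a FusedMoE block with
# Mixtral-like sizes and top-2 routing, then running it on a batch of token
# hidden states. The sizes, the single-GPU assumption, and tp_size=1 are
# assumptions for illustration, not values taken from this file; weights would
# still need to be loaded via weight_loader before the output is meaningful.
#
#     moe = FusedMoE(num_experts=8,
#                    top_k=2,
#                    hidden_size=4096,
#                    intermediate_size=14336,
#                    params_dtype=torch.bfloat16,
#                    reduce_results=True,
#                    renormalize=True,
#                    tp_size=1).cuda()
#     # hidden_states: [num_tokens, hidden_size]
#     # router_logits: [num_tokens, num_experts], produced by the model's gate
#     out = moe(hidden_states, router_logits)  # [num_tokens, hidden_size]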