# layer.py

from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch

from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling._custom_op import CustomOp
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)

class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"

class FusedMoEMethodBase(QuantizeMethodBase):

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              router_logits: torch.Tensor, top_k: int, renormalize: bool,
              use_grouped_topk: bool) -> torch.Tensor:
        raise NotImplementedError

class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
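        # Resulting parameter shapes: w13_weight is
        # [num_experts, 2 * intermediate_size, hidden_size] with the gate and
        # up projections stacked along dim 1; w2_weight is
        # [num_experts, hidden_size, intermediate_size].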

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
        return self.forward(x=x,
                            layer=layer,
                            router_logits=router_logits,
                            top_k=top_k,
                            renormalize=renormalize,
                            use_grouped_topk=use_grouped_topk,
                            topk_group=topk_group,
                            num_expert_group=num_expert_group,
                            custom_routing_function=custom_routing_function)

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
        from aphrodite.modeling.layers.fused_moe.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_experts(hidden_states=x,
                             w1=layer.w13_weight,
                             w2=layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True)

    def forward_cpu(self, *args, **kwargs):
        raise NotImplementedError(
            "The CPU backend currently does not support MoE.")

    def forward_tpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
        from aphrodite.modeling.layers.fused_moe.moe_pallas import fused_moe
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        return fused_moe(hidden_states=x,
                         w1=layer.w13_weight,
                         w2=layer.w2_weight,
                         topk=top_k,
                         gating_output=router_logits,
                         renormalize=renormalize)

class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model.
        top_k: Number of experts selected for each token.
        hidden_size: Input hidden state size of the transformer.
        intermediate_size: Intermediate size of the experts.
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all-reduce the output of the layer.
        renormalize: Whether to renormalize the logits in the fused_moe kernel.
        quant_config: Quantization configuration.
    """
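
    # Minimal usage sketch (hypothetical sizes; assumes router_logits of shape
    # [num_tokens, num_experts] are produced by a separate gate projection):
    #
    #   moe = FusedMoE(num_experts=8, top_k=2, hidden_size=4096,
    #                  intermediate_size=14336, reduce_results=True)
    #   out = moe(hidden_states, router_logits)  # [num_tokens, hidden_size]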

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        custom_routing_function: Optional[Callable] = None,
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (tp_size if tp_size is not None else
                        get_tensor_model_parallel_world_size())
        self.top_k = top_k
        self.num_experts = num_experts
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader)

    def _load_per_tensor_weight_scale(self, shard_id: str,
                                      param: torch.nn.Parameter,
                                      loaded_weight: torch.Tensor,
                                      expert_id: int):
        param_data = param.data
        # for per tensor weight quantization
        if shard_id in ("w1", "w3"):
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            idx = 0 if shard_id == "w1" else 1
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
                                                 expert_data: torch.Tensor,
                                                 shard_id: str,
                                                 loaded_weight: torch.Tensor,
                                                 tp_rank: int):
        # Load grouped weight scales for group quantization
        # or model weights
        if shard_id == "w2":
            self._load_w2(shard_id=shard_id,
                          shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
                                       shard_dim: int, shard_id: str,
                                       loaded_weight: torch.Tensor,
                                       tp_rank: int):
        # for per channel weight quantization
        if shard_id == "w2":
            expert_data.copy_(loaded_weight)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)
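
    # Sharding sketch (hypothetical sizes, tp_size=2, intermediate_size=14336):
    # each rank's w13 expert slice is [2 * 7168, hidden_size], so a checkpoint
    # gate_proj of shape [14336, hidden_size] is narrowed to rows
    # [7168 * tp_rank : 7168 * (tp_rank + 1)] and copied into the first half,
    # while up_proj fills the second half. The w2 expert slice is
    # [hidden_size, 7168], so down_proj is narrowed along its input dim in the
    # same way.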

    def _load_single_value(self, param: torch.nn.Parameter,
                           loaded_weight: torch.Tensor, expert_id: int):
        param_data = param.data
        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int) -> None:
        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                             f"got {shard_id}.")

        WEIGHT_SCALE_SUPPORTED = [
            e.value for e in FusedMoeWeightScaleSupported
        ]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This is the dimension along which
        # intermediate_size is sharded.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        expert_data = param.data[expert_id]
        tp_rank = get_tensor_model_parallel_rank()

        # is_transposed: whether or not the parameter is transposed on disk.
        # If transposed, the loaded weight will be transposed and the dim
        # to shard the loaded weight will be flipped.
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            loaded_weight = loaded_weight.t().contiguous()
            shard_dim = ~shard_dim
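            # (Note: ~0 == -1 and ~1 == -2, so via negative indexing this
            # flips the shard dim to the other axis of the 2D expert slice.)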

        # Case weight_scales
        if "weight_scale" in weight_name:
            # load the weight scaling based on the quantization scheme
            # supported weight scales can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank)
            elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank)
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                   param=param,
                                                   loaded_weight=loaded_weight,
                                                   expert_id=expert_id)
            else:
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
            return

        if "weight_shape" in weight_name:
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case input scale
        if "input_scale" in weight_name:
            # Note: input_scale loading is only supported for fp8
            if param.data[expert_id] != 1 and (param.data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}")
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank)
            return

    @staticmethod
    def select_experts(hidden_states: torch.Tensor,
                       router_logits: torch.Tensor,
                       top_k: int,
                       use_grouped_topk: bool,
                       renormalize: bool,
                       topk_group: Optional[int] = None,
                       num_expert_group: Optional[int] = None,
                       custom_routing_function: Optional[Callable] = None):
        from aphrodite.modeling.layers.fused_moe.fused_moe import (
            fused_topk, grouped_topk)

        # DeepSeek-V2 uses grouped_top_k
        if use_grouped_topk:
            assert topk_group is not None
            assert num_expert_group is not None
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group)
        elif custom_routing_function is None:
            topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                                gating_output=router_logits,
                                                topk=top_k,
                                                renormalize=renormalize)
        else:
            topk_weights, topk_ids = custom_routing_function(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize)

        return topk_weights, topk_ids
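
    # Routing sketch: select_experts returns (topk_weights, topk_ids), both of
    # shape [num_tokens, top_k]; e.g. with top_k=2 every token is dispatched to
    # two experts, weighted by its (optionally renormalized) routing weights.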

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function)

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
            ckpt_up_proj_name: str,
            num_experts: int) -> List[Tuple[str, str, int, str]]:
        return [
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.w13_" if weight_name
             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
             f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
            for expert_id in range(num_experts) for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]
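
    # Example (hypothetical checkpoint names "gate_proj"/"down_proj"/"up_proj",
    # num_experts=2): the mapping contains tuples such as
    #   ("experts.w13_", "experts.0.gate_proj.", 0, "w1")
    #   ("experts.w2_",  "experts.0.down_proj.", 0, "w2")
    #   ("experts.w13_", "experts.0.up_proj.",   0, "w3")
    # followed by the same three entries for expert_id=1.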

    def _load_fp8_scale(self, param: torch.nn.Parameter,
                        loaded_weight: torch.Tensor, weight_name: str,
                        shard_id: str, expert_id: int) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if param_data[expert_id] != 1 and (param_data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}")
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            if shard_id in ("w1", "w3"):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == "w1" else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
                param_data[expert_id] = loaded_weight