qwen2_moe.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
  25. from typing import Any, Dict, Iterable, List, Optional, Tuple
  26. import torch
  27. import torch.nn.functional as F
  28. from torch import nn
  29. from transformers import PretrainedConfig
  30. from aphrodite.attention import Attention, AttentionMetadata
  31. from aphrodite.common.sequence import SamplerOutput
  32. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  33. get_tensor_model_parallel_world_size,
  34. tensor_model_parallel_all_reduce)
  35. from aphrodite.modeling.layers.activation import SiluAndMul
  36. from aphrodite.modeling.layers.fused_moe import fused_moe
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  39. QKVParallelLinear,
  40. ReplicatedLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  43. from aphrodite.modeling.layers.rotary_embedding import get_rope
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  46. ParallelLMHead, VocabParallelEmbedding)
  47. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. from aphrodite.quantization.base_config import QuantizationConfig
  50. class Qwen2MoeMLP(nn.Module):
  51. def __init__(
  52. self,
  53. hidden_size: int,
  54. intermediate_size: int,
  55. hidden_act: str,
  56. quant_config: Optional[QuantizationConfig] = None,
  57. reduce_results: bool = True,
  58. ) -> None:
  59. super().__init__()
  60. self.gate_up_proj = MergedColumnParallelLinear(
  61. hidden_size, [intermediate_size] * 2,
  62. bias=False,
  63. quant_config=quant_config)
  64. self.down_proj = RowParallelLinear(intermediate_size,
  65. hidden_size,
  66. bias=False,
  67. quant_config=quant_config,
  68. reduce_results=reduce_results)
  69. if hidden_act != "silu":
  70. raise ValueError(f"Unsupported activation: {hidden_act}. "
  71. "Only silu is supported for now.")
  72. self.act_fn = SiluAndMul()
  73. def forward(self, x):
  74. gate_up, _ = self.gate_up_proj(x)
  75. x = self.act_fn(gate_up)
  76. x, _ = self.down_proj(x)
  77. return x
  78. class Qwen2MoeSparseMoeBlock(nn.Module):
  79. def __init__(
  80. self,
  81. config: PretrainedConfig,
  82. quant_config: Optional[QuantizationConfig] = None,
  83. ):
  84. super().__init__()
  85. self.config = config
  86. self.rank = get_tensor_model_parallel_rank()
  87. self.tp_size = get_tensor_model_parallel_world_size()
  88. self.n_routed_experts = config.num_experts
  89. self.top_k = config.num_experts_per_tok
  90. if self.tp_size > self.n_routed_experts:
  91. raise ValueError(
  92. f"Tensor parallel size {self.tp_size} is greater than "
  93. f"the number of experts {self.n_routed_experts}.")
  94. self.experts = nn.ModuleList([
  95. Qwen2MoeMLP(hidden_size=config.hidden_size,
  96. intermediate_size=config.moe_intermediate_size,
  97. hidden_act=config.hidden_act,
  98. quant_config=quant_config,
  99. reduce_results=False)
  100. for idx in range(self.n_routed_experts)
  101. ])
  102. self.pack_params()
  103. self.gate = ReplicatedLinear(config.hidden_size,
  104. self.n_routed_experts,
  105. bias=False,
  106. quant_config=None)
  107. if config.shared_expert_intermediate_size > 0:
  108. self.shared_expert = Qwen2MoeMLP(
  109. hidden_size=config.hidden_size,
  110. intermediate_size=config.shared_expert_intermediate_size,
  111. hidden_act=config.hidden_act,
  112. quant_config=quant_config,
  113. reduce_results=False,
  114. )
  115. else:
  116. self.shared_expert = None
  117. self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
  118. 1,
  119. bias=False)
  120. def pack_params(self):
  121. w1 = []
  122. w2 = []
  123. for expert in self.experts:
  124. w1.append(expert.gate_up_proj.weight)
  125. w2.append(expert.down_proj.weight)
  126. self.w1 = torch._utils._flatten_dense_tensors(w1)
  127. w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
  128. for data, param in zip(w1s, w1):
  129. param.data = data
  130. self.w1 = self.w1.view(len(w1), *w1s[0].shape)
  131. self.w2 = torch._utils._flatten_dense_tensors(w2)
  132. w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
  133. for data, param in zip(w2s, w2):
  134. param.data = data
  135. self.w2 = self.w2.view(len(w2), *w2s[0].shape)
  136. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  137. num_tokens, hidden_dim = hidden_states.shape
  138. hidden_states = hidden_states.view(-1, hidden_dim)
  139. shared_output = None
  140. if self.shared_expert is not None:
  141. shared_output = self.shared_expert(hidden_states)
  142. if self.shared_expert_gate is not None:
  143. shared_output = F.sigmoid(
  144. self.shared_expert_gate(hidden_states)) * shared_output
  145. # router_logits: (num_tokens, n_experts)
  146. router_logits, _ = self.gate(hidden_states)
  147. final_hidden_states = fused_moe(hidden_states,
  148. self.w1,
  149. self.w2,
  150. router_logits,
  151. self.top_k,
  152. renormalize=self.config.norm_topk_prob,
  153. inplace=True)
  154. if shared_output is not None:
  155. final_hidden_states = final_hidden_states + shared_output
  156. final_hidden_states = tensor_model_parallel_all_reduce(
  157. final_hidden_states)
  158. return final_hidden_states.view(num_tokens, hidden_dim)
  159. class Qwen2MoeAttention(nn.Module):
  160. def __init__(
  161. self,
  162. hidden_size: int,
  163. num_heads: int,
  164. num_kv_heads: int,
  165. rope_theta: float = 10000,
  166. rope_scaling: Optional[Dict[str, Any]] = None,
  167. max_position_embeddings: int = 8192,
  168. quant_config: Optional[QuantizationConfig] = None,
  169. ) -> None:
  170. super().__init__()
  171. self.hidden_size = hidden_size
  172. tp_size = get_tensor_model_parallel_world_size()
  173. self.total_num_heads = num_heads
  174. assert self.total_num_heads % tp_size == 0
  175. self.num_heads = self.total_num_heads // tp_size
  176. self.total_num_kv_heads = num_kv_heads
  177. if self.total_num_kv_heads >= tp_size:
  178. # Number of KV heads is greater than TP size, so we partition
  179. # the KV heads across multiple tensor parallel GPUs.
  180. assert self.total_num_kv_heads % tp_size == 0
  181. else:
  182. # Number of KV heads is less than TP size, so we replicate
  183. # the KV heads across multiple tensor parallel GPUs.
  184. assert tp_size % self.total_num_kv_heads == 0
  185. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  186. self.head_dim = hidden_size // self.total_num_heads
  187. self.q_size = self.num_heads * self.head_dim
  188. self.kv_size = self.num_kv_heads * self.head_dim
  189. self.scaling = self.head_dim**-0.5
  190. self.rope_theta = rope_theta
  191. self.max_position_embeddings = max_position_embeddings
  192. self.qkv_proj = QKVParallelLinear(
  193. hidden_size,
  194. self.head_dim,
  195. self.total_num_heads,
  196. self.total_num_kv_heads,
  197. bias=True,
  198. quant_config=quant_config,
  199. )
  200. self.o_proj = RowParallelLinear(
  201. self.total_num_heads * self.head_dim,
  202. hidden_size,
  203. bias=False,
  204. quant_config=quant_config,
  205. )
  206. self.rotary_emb = get_rope(
  207. self.head_dim,
  208. rotary_dim=self.head_dim,
  209. max_position=max_position_embeddings,
  210. base=rope_theta,
  211. rope_scaling=rope_scaling,
  212. )
  213. self.attn = Attention(self.num_heads,
  214. self.head_dim,
  215. self.scaling,
  216. num_kv_heads=self.num_kv_heads)
  217. def forward(
  218. self,
  219. positions: torch.Tensor,
  220. hidden_states: torch.Tensor,
  221. kv_cache: torch.Tensor,
  222. attn_metadata: AttentionMetadata,
  223. ) -> torch.Tensor:
  224. qkv, _ = self.qkv_proj(hidden_states)
  225. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  226. q, k = self.rotary_emb(positions, q, k)
  227. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  228. output, _ = self.o_proj(attn_output)
  229. return output
  230. class Qwen2MoeDecoderLayer(nn.Module):
  231. def __init__(
  232. self,
  233. config: PretrainedConfig,
  234. layer_idx: int,
  235. quant_config: Optional[QuantizationConfig] = None,
  236. ) -> None:
  237. super().__init__()
  238. self.hidden_size = config.hidden_size
  239. rope_theta = getattr(config, "rope_theta", 10000)
  240. rope_scaling = getattr(config, "rope_scaling", None)
  241. max_position_embeddings = getattr(config, "max_position_embeddings",
  242. 8192)
  243. self.self_attn = Qwen2MoeAttention(
  244. hidden_size=self.hidden_size,
  245. num_heads=config.num_attention_heads,
  246. num_kv_heads=config.num_key_value_heads,
  247. rope_theta=rope_theta,
  248. rope_scaling=rope_scaling,
  249. max_position_embeddings=max_position_embeddings,
  250. quant_config=quant_config,
  251. )
  252. if (config.num_experts is not None
  253. and (layer_idx + 1) % config.decoder_sparse_step == 0):
  254. self.mlp = Qwen2MoeSparseMoeBlock(config=config,
  255. quant_config=quant_config)
  256. else:
  257. self.mlp = Qwen2MoeMLP(
  258. hidden_size=config.hidden_size,
  259. intermediate_size=config.intermediate_size,
  260. hidden_act=config.hidden_act,
  261. quant_config=quant_config,
  262. )
  263. self.input_layernorm = RMSNorm(config.hidden_size,
  264. eps=config.rms_norm_eps)
  265. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  266. eps=config.rms_norm_eps)
  267. def forward(
  268. self,
  269. positions: torch.Tensor,
  270. hidden_states: torch.Tensor,
  271. kv_cache: torch.Tensor,
  272. attn_metadata: AttentionMetadata,
  273. residual: Optional[torch.Tensor],
  274. ) -> torch.Tensor:
  275. # Self Attention
  276. if residual is None:
  277. residual = hidden_states
  278. hidden_states = self.input_layernorm(hidden_states)
  279. else:
  280. hidden_states, residual = self.input_layernorm(
  281. hidden_states, residual)
  282. hidden_states = self.self_attn(
  283. positions=positions,
  284. hidden_states=hidden_states,
  285. kv_cache=kv_cache,
  286. attn_metadata=attn_metadata,
  287. )
  288. # Fully Connected
  289. hidden_states, residual = self.post_attention_layernorm(
  290. hidden_states, residual)
  291. hidden_states = self.mlp(hidden_states)
  292. return hidden_states, residual
  293. class Qwen2MoeModel(nn.Module):
  294. def __init__(
  295. self,
  296. config: PretrainedConfig,
  297. quant_config: Optional[QuantizationConfig] = None,
  298. ) -> None:
  299. super().__init__()
  300. self.padding_idx = config.pad_token_id
  301. self.vocab_size = config.vocab_size
  302. self.embed_tokens = VocabParallelEmbedding(
  303. config.vocab_size,
  304. config.hidden_size,
  305. )
  306. self.layers = nn.ModuleList([
  307. Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
  308. for layer_idx in range(config.num_hidden_layers)
  309. ])
  310. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  311. def forward(
  312. self,
  313. input_ids: torch.Tensor,
  314. positions: torch.Tensor,
  315. kv_caches: List[torch.Tensor],
  316. attn_metadata: AttentionMetadata,
  317. ) -> torch.Tensor:
  318. hidden_states = self.embed_tokens(input_ids)
  319. residual = None
  320. for i in range(len(self.layers)):
  321. layer = self.layers[i]
  322. hidden_states, residual = layer(positions, hidden_states,
  323. kv_caches[i], attn_metadata,
  324. residual)
  325. hidden_states, _ = self.norm(hidden_states, residual)
  326. return hidden_states
  327. class Qwen2MoeForCausalLM(nn.Module):
  328. fall_back_to_pt_during_load = False
  329. def __init__(
  330. self,
  331. config: PretrainedConfig,
  332. quant_config: Optional[QuantizationConfig] = None,
  333. ) -> None:
  334. super().__init__()
  335. self.config = config
  336. self.quant_config = quant_config
  337. self.model = Qwen2MoeModel(config, quant_config)
  338. self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  339. self.logits_processor = LogitsProcessor(config.vocab_size)
  340. self.sampler = Sampler()
  341. def forward(
  342. self,
  343. input_ids: torch.Tensor,
  344. positions: torch.Tensor,
  345. kv_caches: List[torch.Tensor],
  346. attn_metadata: AttentionMetadata,
  347. ) -> torch.Tensor:
  348. hidden_states = self.model(input_ids, positions, kv_caches,
  349. attn_metadata)
  350. return hidden_states
  351. def compute_logits(self, hidden_states: torch.Tensor,
  352. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  353. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  354. sampling_metadata)
  355. return logits
  356. def sample(
  357. self,
  358. logits: Optional[torch.Tensor],
  359. sampling_metadata: SamplingMetadata,
  360. ) -> Optional[SamplerOutput]:
  361. next_tokens = self.sampler(logits, sampling_metadata)
  362. return next_tokens
  363. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  364. stacked_params_mapping = [
  365. # (param_name, shard_name, shard_id)
  366. ("qkv_proj", "q_proj", "q"),
  367. ("qkv_proj", "k_proj", "k"),
  368. ("qkv_proj", "v_proj", "v"),
  369. ("gate_up_proj", "gate_proj", 0),
  370. ("gate_up_proj", "up_proj", 1),
  371. ]
  372. params_dict = dict(self.named_parameters())
  373. for name, loaded_weight in weights:
  374. if "rotary_emb.inv_freq" in name:
  375. continue
  376. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  377. if weight_name not in name:
  378. continue
  379. name = name.replace(weight_name, param_name)
  380. # Skip loading extra bias for GPTQ models.
  381. if name.endswith(".bias") and name not in params_dict:
  382. continue
  383. # Skip experts that are not assigned to this worker.
  384. if (("mlp.experts." in name or "mlp.shared_expert." in name)
  385. and name not in params_dict):
  386. continue
  387. param = params_dict[name]
  388. weight_loader = param.weight_loader
  389. weight_loader(param, loaded_weight, shard_id)
  390. break
  391. else:
  392. # Skip loading extra bias for GPTQ models.
  393. if name.endswith(".bias") and name not in params_dict:
  394. continue
  395. # Skip experts that are not assigned to this worker.
  396. if (("mlp.experts." in name or "mlp.shared_expert." in name)
  397. and name not in params_dict):
  398. continue
  399. param = params_dict[name]
  400. weight_loader = getattr(param, "weight_loader",
  401. default_weight_loader)
  402. weight_loader(param, loaded_weight)