1
0

qwen2_moe.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
  25. from typing import Any, Dict, Iterable, List, Optional, Tuple
  26. import torch
  27. import torch.nn.functional as F
  28. from torch import nn
  29. from transformers import PretrainedConfig
  30. from aphrodite.attention import Attention, AttentionMetadata
  31. from aphrodite.common.config import CacheConfig
  32. from aphrodite.common.sequence import SamplerOutput
  33. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  34. get_tensor_model_parallel_world_size,
  35. tensor_model_parallel_all_reduce)
  36. from aphrodite.modeling.layers.activation import SiluAndMul
  37. from aphrodite.modeling.layers.fused_moe import fused_moe
  38. from aphrodite.modeling.layers.layernorm import RMSNorm
  39. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  40. QKVParallelLinear,
  41. ReplicatedLinear,
  42. RowParallelLinear)
  43. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  44. from aphrodite.modeling.layers.rotary_embedding import get_rope
  45. from aphrodite.modeling.layers.sampler import Sampler
  46. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  47. ParallelLMHead, VocabParallelEmbedding)
  48. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  49. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  50. from aphrodite.quantization.base_config import QuantizationConfig
  51. class Qwen2MoeMLP(nn.Module):
  52. def __init__(
  53. self,
  54. hidden_size: int,
  55. intermediate_size: int,
  56. hidden_act: str,
  57. quant_config: Optional[QuantizationConfig] = None,
  58. reduce_results: bool = True,
  59. ) -> None:
  60. super().__init__()
  61. self.gate_up_proj = MergedColumnParallelLinear(
  62. hidden_size, [intermediate_size] * 2,
  63. bias=False,
  64. quant_config=quant_config)
  65. self.down_proj = RowParallelLinear(intermediate_size,
  66. hidden_size,
  67. bias=False,
  68. quant_config=quant_config,
  69. reduce_results=reduce_results)
  70. if hidden_act != "silu":
  71. raise ValueError(f"Unsupported activation: {hidden_act}. "
  72. "Only silu is supported for now.")
  73. self.act_fn = SiluAndMul()
  74. def forward(self, x):
  75. gate_up, _ = self.gate_up_proj(x)
  76. x = self.act_fn(gate_up)
  77. x, _ = self.down_proj(x)
  78. return x
  79. class Qwen2MoeSparseMoeBlock(nn.Module):
  80. def __init__(
  81. self,
  82. config: PretrainedConfig,
  83. quant_config: Optional[QuantizationConfig] = None,
  84. ):
  85. super().__init__()
  86. self.config = config
  87. self.rank = get_tensor_model_parallel_rank()
  88. self.tp_size = get_tensor_model_parallel_world_size()
  89. self.n_routed_experts = config.num_experts
  90. self.top_k = config.num_experts_per_tok
  91. if self.tp_size > self.n_routed_experts:
  92. raise ValueError(
  93. f"Tensor parallel size {self.tp_size} is greater than "
  94. f"the number of experts {self.n_routed_experts}.")
  95. self.experts = nn.ModuleList([
  96. Qwen2MoeMLP(hidden_size=config.hidden_size,
  97. intermediate_size=config.moe_intermediate_size,
  98. hidden_act=config.hidden_act,
  99. quant_config=quant_config,
  100. reduce_results=False)
  101. for idx in range(self.n_routed_experts)
  102. ])
  103. self.pack_params()
  104. self.gate = ReplicatedLinear(config.hidden_size,
  105. self.n_routed_experts,
  106. bias=False,
  107. quant_config=None)
  108. if config.shared_expert_intermediate_size > 0:
  109. self.shared_expert = Qwen2MoeMLP(
  110. hidden_size=config.hidden_size,
  111. intermediate_size=config.shared_expert_intermediate_size,
  112. hidden_act=config.hidden_act,
  113. quant_config=quant_config,
  114. reduce_results=False,
  115. )
  116. else:
  117. self.shared_expert = None
  118. self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
  119. 1,
  120. bias=False)
  121. def pack_params(self):
  122. w1 = []
  123. w2 = []
  124. for expert in self.experts:
  125. w1.append(expert.gate_up_proj.weight)
  126. w2.append(expert.down_proj.weight)
  127. self.w1 = torch._utils._flatten_dense_tensors(w1)
  128. w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
  129. for data, param in zip(w1s, w1):
  130. param.data = data
  131. self.w1 = self.w1.view(len(w1), *w1s[0].shape)
  132. self.w2 = torch._utils._flatten_dense_tensors(w2)
  133. w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
  134. for data, param in zip(w2s, w2):
  135. param.data = data
  136. self.w2 = self.w2.view(len(w2), *w2s[0].shape)
  137. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  138. num_tokens, hidden_dim = hidden_states.shape
  139. hidden_states = hidden_states.view(-1, hidden_dim)
  140. shared_output = None
  141. if self.shared_expert is not None:
  142. shared_output = self.shared_expert(hidden_states)
  143. if self.shared_expert_gate is not None:
  144. shared_output = F.sigmoid(
  145. self.shared_expert_gate(hidden_states)) * shared_output
  146. # router_logits: (num_tokens, n_experts)
  147. router_logits, _ = self.gate(hidden_states)
  148. final_hidden_states = fused_moe(hidden_states,
  149. self.w1,
  150. self.w2,
  151. router_logits,
  152. self.top_k,
  153. renormalize=self.config.norm_topk_prob,
  154. inplace=True)
  155. if shared_output is not None:
  156. final_hidden_states = final_hidden_states + shared_output
  157. final_hidden_states = tensor_model_parallel_all_reduce(
  158. final_hidden_states)
  159. return final_hidden_states.view(num_tokens, hidden_dim)
  160. class Qwen2MoeAttention(nn.Module):
  161. def __init__(
  162. self,
  163. hidden_size: int,
  164. num_heads: int,
  165. num_kv_heads: int,
  166. rope_theta: float = 10000,
  167. rope_scaling: Optional[Dict[str, Any]] = None,
  168. max_position_embeddings: int = 8192,
  169. cache_config: Optional[CacheConfig] = None,
  170. quant_config: Optional[QuantizationConfig] = None,
  171. ) -> None:
  172. super().__init__()
  173. self.hidden_size = hidden_size
  174. tp_size = get_tensor_model_parallel_world_size()
  175. self.total_num_heads = num_heads
  176. assert self.total_num_heads % tp_size == 0
  177. self.num_heads = self.total_num_heads // tp_size
  178. self.total_num_kv_heads = num_kv_heads
  179. if self.total_num_kv_heads >= tp_size:
  180. # Number of KV heads is greater than TP size, so we partition
  181. # the KV heads across multiple tensor parallel GPUs.
  182. assert self.total_num_kv_heads % tp_size == 0
  183. else:
  184. # Number of KV heads is less than TP size, so we replicate
  185. # the KV heads across multiple tensor parallel GPUs.
  186. assert tp_size % self.total_num_kv_heads == 0
  187. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  188. self.head_dim = hidden_size // self.total_num_heads
  189. self.q_size = self.num_heads * self.head_dim
  190. self.kv_size = self.num_kv_heads * self.head_dim
  191. self.scaling = self.head_dim**-0.5
  192. self.rope_theta = rope_theta
  193. self.max_position_embeddings = max_position_embeddings
  194. self.qkv_proj = QKVParallelLinear(
  195. hidden_size,
  196. self.head_dim,
  197. self.total_num_heads,
  198. self.total_num_kv_heads,
  199. bias=True,
  200. quant_config=quant_config,
  201. )
  202. self.o_proj = RowParallelLinear(
  203. self.total_num_heads * self.head_dim,
  204. hidden_size,
  205. bias=False,
  206. quant_config=quant_config,
  207. )
  208. self.rotary_emb = get_rope(
  209. self.head_dim,
  210. rotary_dim=self.head_dim,
  211. max_position=max_position_embeddings,
  212. base=rope_theta,
  213. rope_scaling=rope_scaling,
  214. )
  215. self.attn = Attention(self.num_heads,
  216. self.head_dim,
  217. self.scaling,
  218. num_kv_heads=self.num_kv_heads,
  219. cache_config=cache_config,
  220. quant_config=quant_config)
  221. def forward(
  222. self,
  223. positions: torch.Tensor,
  224. hidden_states: torch.Tensor,
  225. kv_cache: torch.Tensor,
  226. attn_metadata: AttentionMetadata,
  227. ) -> torch.Tensor:
  228. qkv, _ = self.qkv_proj(hidden_states)
  229. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  230. q, k = self.rotary_emb(positions, q, k)
  231. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  232. output, _ = self.o_proj(attn_output)
  233. return output
  234. class Qwen2MoeDecoderLayer(nn.Module):
  235. def __init__(
  236. self,
  237. config: PretrainedConfig,
  238. layer_idx: int,
  239. cache_config: Optional[CacheConfig] = None,
  240. quant_config: Optional[QuantizationConfig] = None,
  241. ) -> None:
  242. super().__init__()
  243. self.hidden_size = config.hidden_size
  244. rope_theta = getattr(config, "rope_theta", 10000)
  245. rope_scaling = getattr(config, "rope_scaling", None)
  246. max_position_embeddings = getattr(config, "max_position_embeddings",
  247. 8192)
  248. self.self_attn = Qwen2MoeAttention(
  249. hidden_size=self.hidden_size,
  250. num_heads=config.num_attention_heads,
  251. num_kv_heads=config.num_key_value_heads,
  252. rope_theta=rope_theta,
  253. rope_scaling=rope_scaling,
  254. max_position_embeddings=max_position_embeddings,
  255. cache_config=cache_config,
  256. quant_config=quant_config,
  257. )
  258. if (layer_idx not in config.mlp_only_layers) and (
  259. config.num_experts > 0 and
  260. (layer_idx + 1) % config.decoder_sparse_step == 0):
  261. self.mlp = Qwen2MoeSparseMoeBlock(config=config,
  262. quant_config=quant_config)
  263. else:
  264. self.mlp = Qwen2MoeMLP(
  265. hidden_size=config.hidden_size,
  266. intermediate_size=config.intermediate_size,
  267. hidden_act=config.hidden_act,
  268. quant_config=quant_config,
  269. )
  270. self.input_layernorm = RMSNorm(config.hidden_size,
  271. eps=config.rms_norm_eps)
  272. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  273. eps=config.rms_norm_eps)
  274. def forward(
  275. self,
  276. positions: torch.Tensor,
  277. hidden_states: torch.Tensor,
  278. kv_cache: torch.Tensor,
  279. attn_metadata: AttentionMetadata,
  280. residual: Optional[torch.Tensor],
  281. ) -> torch.Tensor:
  282. # Self Attention
  283. if residual is None:
  284. residual = hidden_states
  285. hidden_states = self.input_layernorm(hidden_states)
  286. else:
  287. hidden_states, residual = self.input_layernorm(
  288. hidden_states, residual)
  289. hidden_states = self.self_attn(
  290. positions=positions,
  291. hidden_states=hidden_states,
  292. kv_cache=kv_cache,
  293. attn_metadata=attn_metadata,
  294. )
  295. # Fully Connected
  296. hidden_states, residual = self.post_attention_layernorm(
  297. hidden_states, residual)
  298. hidden_states = self.mlp(hidden_states)
  299. return hidden_states, residual
  300. class Qwen2MoeModel(nn.Module):
  301. def __init__(
  302. self,
  303. config: PretrainedConfig,
  304. cache_config: Optional[CacheConfig] = None,
  305. quant_config: Optional[QuantizationConfig] = None,
  306. ) -> None:
  307. super().__init__()
  308. self.padding_idx = config.pad_token_id
  309. self.vocab_size = config.vocab_size
  310. self.embed_tokens = VocabParallelEmbedding(
  311. config.vocab_size,
  312. config.hidden_size,
  313. )
  314. self.layers = nn.ModuleList([
  315. Qwen2MoeDecoderLayer(config,
  316. layer_idx,
  317. cache_config,
  318. quant_config=quant_config)
  319. for layer_idx in range(config.num_hidden_layers)
  320. ])
  321. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  322. def forward(
  323. self,
  324. input_ids: torch.Tensor,
  325. positions: torch.Tensor,
  326. kv_caches: List[torch.Tensor],
  327. attn_metadata: AttentionMetadata,
  328. ) -> torch.Tensor:
  329. hidden_states = self.embed_tokens(input_ids)
  330. residual = None
  331. for i in range(len(self.layers)):
  332. layer = self.layers[i]
  333. hidden_states, residual = layer(positions, hidden_states,
  334. kv_caches[i], attn_metadata,
  335. residual)
  336. hidden_states, _ = self.norm(hidden_states, residual)
  337. return hidden_states
  338. class Qwen2MoeForCausalLM(nn.Module):
  339. fall_back_to_pt_during_load = False
  340. def __init__(
  341. self,
  342. config: PretrainedConfig,
  343. cache_config: Optional[CacheConfig] = None,
  344. quant_config: Optional[QuantizationConfig] = None,
  345. ) -> None:
  346. super().__init__()
  347. self.config = config
  348. self.quant_config = quant_config
  349. self.model = Qwen2MoeModel(config, cache_config, quant_config)
  350. self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  351. self.logits_processor = LogitsProcessor(config.vocab_size)
  352. self.sampler = Sampler()
  353. def forward(
  354. self,
  355. input_ids: torch.Tensor,
  356. positions: torch.Tensor,
  357. kv_caches: List[torch.Tensor],
  358. attn_metadata: AttentionMetadata,
  359. ) -> torch.Tensor:
  360. hidden_states = self.model(input_ids, positions, kv_caches,
  361. attn_metadata)
  362. return hidden_states
  363. def compute_logits(self, hidden_states: torch.Tensor,
  364. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  365. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  366. sampling_metadata)
  367. return logits
  368. def sample(
  369. self,
  370. logits: Optional[torch.Tensor],
  371. sampling_metadata: SamplingMetadata,
  372. ) -> Optional[SamplerOutput]:
  373. next_tokens = self.sampler(logits, sampling_metadata)
  374. return next_tokens
  375. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  376. stacked_params_mapping = [
  377. # (param_name, shard_name, shard_id)
  378. ("qkv_proj", "q_proj", "q"),
  379. ("qkv_proj", "k_proj", "k"),
  380. ("qkv_proj", "v_proj", "v"),
  381. ("gate_up_proj", "gate_proj", 0),
  382. ("gate_up_proj", "up_proj", 1),
  383. ]
  384. params_dict = dict(self.named_parameters())
  385. for name, loaded_weight in weights:
  386. if "rotary_emb.inv_freq" in name:
  387. continue
  388. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  389. if weight_name not in name:
  390. continue
  391. name = name.replace(weight_name, param_name)
  392. # Skip loading extra bias for GPTQ models.
  393. if name.endswith(".bias") and name not in params_dict:
  394. continue
  395. # Skip experts that are not assigned to this worker.
  396. if (("mlp.experts." in name or "mlp.shared_expert." in name)
  397. and name not in params_dict):
  398. continue
  399. if name not in params_dict:
  400. continue
  401. param = params_dict[name]
  402. weight_loader = param.weight_loader
  403. weight_loader(param, loaded_weight, shard_id)
  404. break
  405. else:
  406. # Skip loading extra bias for GPTQ models.
  407. if name.endswith(".bias") and name not in params_dict:
  408. continue
  409. # Skip experts that are not assigned to this worker.
  410. if (("mlp.experts." in name or "mlp.shared_expert." in name)
  411. and name not in params_dict):
  412. continue
  413. if name not in params_dict:
  414. continue
  415. param = params_dict[name]
  416. weight_loader = getattr(param, "weight_loader",
  417. default_weight_loader)
  418. weight_loader(param, loaded_weight)