qwen2_moe.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
  25. from typing import Any, Dict, Iterable, List, Optional, Tuple
  26. import torch
  27. import torch.nn.functional as F
  28. from torch import nn
  29. from transformers import PretrainedConfig
  30. from aphrodite.attention import Attention, AttentionMetadata
  31. from aphrodite.common.config import CacheConfig
  32. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  33. from aphrodite.distributed import (get_tensor_model_parallel_world_size,
  34. tensor_model_parallel_all_reduce)
  35. from aphrodite.modeling.layers.activation import SiluAndMul
  36. from aphrodite.modeling.layers.fused_moe import FusedMoE
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  39. QKVParallelLinear,
  40. ReplicatedLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  43. from aphrodite.modeling.layers.rotary_embedding import get_rope
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  46. ParallelLMHead, VocabParallelEmbedding)
  47. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. from aphrodite.quantization.base_config import QuantizationConfig
  50. class Qwen2MoeMLP(nn.Module):
  51. def __init__(
  52. self,
  53. hidden_size: int,
  54. intermediate_size: int,
  55. hidden_act: str,
  56. quant_config: Optional[QuantizationConfig] = None,
  57. reduce_results: bool = True,
  58. ) -> None:
  59. super().__init__()
  60. self.gate_up_proj = MergedColumnParallelLinear(
  61. hidden_size, [intermediate_size] * 2,
  62. bias=False,
  63. quant_config=quant_config)
  64. self.down_proj = RowParallelLinear(intermediate_size,
  65. hidden_size,
  66. bias=False,
  67. quant_config=quant_config,
  68. reduce_results=reduce_results)
  69. if hidden_act != "silu":
  70. raise ValueError(f"Unsupported activation: {hidden_act}. "
  71. "Only silu is supported for now.")
  72. self.act_fn = SiluAndMul()
  73. def forward(self, x):
  74. gate_up, _ = self.gate_up_proj(x)
  75. x = self.act_fn(gate_up)
  76. x, _ = self.down_proj(x)
  77. return x
  78. class Qwen2MoeSparseMoeBlock(nn.Module):
  79. def __init__(
  80. self,
  81. config: PretrainedConfig,
  82. quant_config: Optional[QuantizationConfig] = None,
  83. ):
  84. super().__init__()
  85. self.tp_size = get_tensor_model_parallel_world_size()
  86. if self.tp_size > config.num_experts:
  87. raise ValueError(
  88. f"Tensor parallel size {self.tp_size} is greater than "
  89. f"the number of experts {config.num_experts}.")
  90. self.experts = FusedMoE(num_experts=config.num_experts,
  91. top_k=config.num_experts_per_tok,
  92. hidden_size=config.hidden_size,
  93. intermediate_size=config.moe_intermediate_size,
  94. reduce_results=False,
  95. renormalize=config.norm_topk_prob,
  96. quant_config=quant_config)
  97. self.gate = ReplicatedLinear(config.hidden_size,
  98. config.num_experts,
  99. bias=False,
  100. quant_config=None)
  101. if config.shared_expert_intermediate_size > 0:
  102. self.shared_expert = Qwen2MoeMLP(
  103. hidden_size=config.hidden_size,
  104. intermediate_size=config.shared_expert_intermediate_size,
  105. hidden_act=config.hidden_act,
  106. quant_config=quant_config,
  107. reduce_results=False,
  108. )
  109. else:
  110. self.shared_expert = None
  111. self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
  112. 1,
  113. bias=False)
  114. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  115. num_tokens, hidden_dim = hidden_states.shape
  116. hidden_states = hidden_states.view(-1, hidden_dim)
  117. shared_output = None
  118. if self.shared_expert is not None:
  119. shared_output = self.shared_expert(hidden_states)
  120. if self.shared_expert_gate is not None:
  121. shared_output = F.sigmoid(
  122. self.shared_expert_gate(hidden_states)) * shared_output
  123. # router_logits: (num_tokens, n_experts)
  124. router_logits, _ = self.gate(hidden_states)
  125. final_hidden_states = self.experts(hidden_states=hidden_states,
  126. router_logits=router_logits)
  127. if shared_output is not None:
  128. final_hidden_states = final_hidden_states + shared_output
  129. if self.tp_size > 1:
  130. final_hidden_states = tensor_model_parallel_all_reduce(
  131. final_hidden_states)
  132. return final_hidden_states.view(num_tokens, hidden_dim)
  133. class Qwen2MoeAttention(nn.Module):
  134. def __init__(
  135. self,
  136. hidden_size: int,
  137. num_heads: int,
  138. num_kv_heads: int,
  139. rope_theta: float = 10000,
  140. rope_scaling: Optional[Dict[str, Any]] = None,
  141. max_position_embeddings: int = 8192,
  142. cache_config: Optional[CacheConfig] = None,
  143. quant_config: Optional[QuantizationConfig] = None,
  144. ) -> None:
  145. super().__init__()
  146. self.hidden_size = hidden_size
  147. tp_size = get_tensor_model_parallel_world_size()
  148. self.total_num_heads = num_heads
  149. assert self.total_num_heads % tp_size == 0
  150. self.num_heads = self.total_num_heads // tp_size
  151. self.total_num_kv_heads = num_kv_heads
  152. if self.total_num_kv_heads >= tp_size:
  153. # Number of KV heads is greater than TP size, so we partition
  154. # the KV heads across multiple tensor parallel GPUs.
  155. assert self.total_num_kv_heads % tp_size == 0
  156. else:
  157. # Number of KV heads is less than TP size, so we replicate
  158. # the KV heads across multiple tensor parallel GPUs.
  159. assert tp_size % self.total_num_kv_heads == 0
  160. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  161. self.head_dim = hidden_size // self.total_num_heads
  162. self.q_size = self.num_heads * self.head_dim
  163. self.kv_size = self.num_kv_heads * self.head_dim
  164. self.scaling = self.head_dim**-0.5
  165. self.rope_theta = rope_theta
  166. self.max_position_embeddings = max_position_embeddings
  167. self.qkv_proj = QKVParallelLinear(
  168. hidden_size,
  169. self.head_dim,
  170. self.total_num_heads,
  171. self.total_num_kv_heads,
  172. bias=True,
  173. quant_config=quant_config,
  174. )
  175. self.o_proj = RowParallelLinear(
  176. self.total_num_heads * self.head_dim,
  177. hidden_size,
  178. bias=False,
  179. quant_config=quant_config,
  180. )
  181. self.rotary_emb = get_rope(
  182. self.head_dim,
  183. rotary_dim=self.head_dim,
  184. max_position=max_position_embeddings,
  185. base=rope_theta,
  186. rope_scaling=rope_scaling,
  187. )
  188. self.attn = Attention(self.num_heads,
  189. self.head_dim,
  190. self.scaling,
  191. num_kv_heads=self.num_kv_heads,
  192. cache_config=cache_config,
  193. quant_config=quant_config)
  194. def forward(
  195. self,
  196. positions: torch.Tensor,
  197. hidden_states: torch.Tensor,
  198. kv_cache: torch.Tensor,
  199. attn_metadata: AttentionMetadata,
  200. ) -> torch.Tensor:
  201. qkv, _ = self.qkv_proj(hidden_states)
  202. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  203. q, k = self.rotary_emb(positions, q, k)
  204. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  205. output, _ = self.o_proj(attn_output)
  206. return output
  207. class Qwen2MoeDecoderLayer(nn.Module):
  208. def __init__(
  209. self,
  210. config: PretrainedConfig,
  211. layer_idx: int,
  212. cache_config: Optional[CacheConfig] = None,
  213. quant_config: Optional[QuantizationConfig] = None,
  214. ) -> None:
  215. super().__init__()
  216. self.hidden_size = config.hidden_size
  217. rope_theta = getattr(config, "rope_theta", 10000)
  218. rope_scaling = getattr(config, "rope_scaling", None)
  219. max_position_embeddings = getattr(config, "max_position_embeddings",
  220. 8192)
  221. self.self_attn = Qwen2MoeAttention(
  222. hidden_size=self.hidden_size,
  223. num_heads=config.num_attention_heads,
  224. num_kv_heads=config.num_key_value_heads,
  225. rope_theta=rope_theta,
  226. rope_scaling=rope_scaling,
  227. max_position_embeddings=max_position_embeddings,
  228. cache_config=cache_config,
  229. quant_config=quant_config,
  230. )
  231. # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
  232. # `mlp_only_layers` in the config.
  233. mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
  234. config.mlp_only_layers)
  235. if (layer_idx not in mlp_only_layers) and (
  236. config.num_experts > 0 and
  237. (layer_idx + 1) % config.decoder_sparse_step == 0):
  238. self.mlp = Qwen2MoeSparseMoeBlock(config=config,
  239. quant_config=quant_config)
  240. else:
  241. self.mlp = Qwen2MoeMLP(
  242. hidden_size=config.hidden_size,
  243. intermediate_size=config.intermediate_size,
  244. hidden_act=config.hidden_act,
  245. quant_config=quant_config,
  246. )
  247. self.input_layernorm = RMSNorm(config.hidden_size,
  248. eps=config.rms_norm_eps)
  249. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  250. eps=config.rms_norm_eps)
  251. def forward(
  252. self,
  253. positions: torch.Tensor,
  254. hidden_states: torch.Tensor,
  255. kv_cache: torch.Tensor,
  256. attn_metadata: AttentionMetadata,
  257. residual: Optional[torch.Tensor],
  258. ) -> torch.Tensor:
  259. # Self Attention
  260. if residual is None:
  261. residual = hidden_states
  262. hidden_states = self.input_layernorm(hidden_states)
  263. else:
  264. hidden_states, residual = self.input_layernorm(
  265. hidden_states, residual)
  266. hidden_states = self.self_attn(
  267. positions=positions,
  268. hidden_states=hidden_states,
  269. kv_cache=kv_cache,
  270. attn_metadata=attn_metadata,
  271. )
  272. # Fully Connected
  273. hidden_states, residual = self.post_attention_layernorm(
  274. hidden_states, residual)
  275. hidden_states = self.mlp(hidden_states)
  276. return hidden_states, residual
  277. class Qwen2MoeModel(nn.Module):
  278. def __init__(
  279. self,
  280. config: PretrainedConfig,
  281. cache_config: Optional[CacheConfig] = None,
  282. quant_config: Optional[QuantizationConfig] = None,
  283. ) -> None:
  284. super().__init__()
  285. self.padding_idx = config.pad_token_id
  286. self.vocab_size = config.vocab_size
  287. self.embed_tokens = VocabParallelEmbedding(
  288. config.vocab_size,
  289. config.hidden_size,
  290. )
  291. self.layers = nn.ModuleList([
  292. Qwen2MoeDecoderLayer(config,
  293. layer_idx,
  294. cache_config,
  295. quant_config=quant_config)
  296. for layer_idx in range(config.num_hidden_layers)
  297. ])
  298. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  299. def forward(
  300. self,
  301. input_ids: torch.Tensor,
  302. positions: torch.Tensor,
  303. kv_caches: List[torch.Tensor],
  304. attn_metadata: AttentionMetadata,
  305. ) -> torch.Tensor:
  306. hidden_states = self.embed_tokens(input_ids)
  307. residual = None
  308. for i in range(len(self.layers)):
  309. layer = self.layers[i]
  310. hidden_states, residual = layer(positions, hidden_states,
  311. kv_caches[i], attn_metadata,
  312. residual)
  313. hidden_states, _ = self.norm(hidden_states, residual)
  314. return hidden_states
  315. class Qwen2MoeForCausalLM(nn.Module):
  316. fall_back_to_pt_during_load = False
  317. def __init__(
  318. self,
  319. config: PretrainedConfig,
  320. cache_config: Optional[CacheConfig] = None,
  321. quant_config: Optional[QuantizationConfig] = None,
  322. ) -> None:
  323. super().__init__()
  324. self.config = config
  325. self.quant_config = quant_config
  326. self.model = Qwen2MoeModel(config, cache_config, quant_config)
  327. self.lm_head = ParallelLMHead(config.vocab_size,
  328. config.hidden_size,
  329. quant_config=quant_config)
  330. self.logits_processor = LogitsProcessor(config.vocab_size)
  331. self.sampler = Sampler()
  332. def forward(
  333. self,
  334. input_ids: torch.Tensor,
  335. positions: torch.Tensor,
  336. kv_caches: List[torch.Tensor],
  337. attn_metadata: AttentionMetadata,
  338. intermediate_tensors: Optional[IntermediateTensors] = None,
  339. ) -> torch.Tensor:
  340. hidden_states = self.model(input_ids, positions, kv_caches,
  341. attn_metadata)
  342. return hidden_states
  343. def compute_logits(self, hidden_states: torch.Tensor,
  344. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  345. logits = self.logits_processor(self.lm_head, hidden_states,
  346. sampling_metadata)
  347. return logits
  348. def sample(
  349. self,
  350. logits: Optional[torch.Tensor],
  351. sampling_metadata: SamplingMetadata,
  352. ) -> Optional[SamplerOutput]:
  353. next_tokens = self.sampler(logits, sampling_metadata)
  354. return next_tokens
  355. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  356. stacked_params_mapping = [
  357. # (param_name, shard_name, shard_id)
  358. ("qkv_proj", "q_proj", "q"),
  359. ("qkv_proj", "k_proj", "k"),
  360. ("qkv_proj", "v_proj", "v"),
  361. ("gate_up_proj", "gate_proj", 0),
  362. ("gate_up_proj", "up_proj", 1),
  363. ]
  364. expert_params_mapping = [
  365. # These are the weights for the experts
  366. # (param_name, weight_name, expert_id, shard_id)
  367. ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
  368. else "experts.w2_weight",
  369. f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
  370. for expert_id in range(self.config.num_experts) for shard_id,
  371. weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
  372. ]
  373. params_dict = dict(self.named_parameters())
  374. for name, loaded_weight in weights:
  375. if "rotary_emb.inv_freq" in name:
  376. continue
  377. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  378. # Skip non-stacked layers and experts (experts handled below).
  379. if weight_name not in name:
  380. continue
  381. # We have mlp.experts[0].gate_proj in the checkpoint.
  382. # Since we handle the experts below in expert_params_mapping,
  383. # we need to skip here BEFORE we update the name, otherwise
  384. # name will be updated to mlp.experts[0].gate_up_proj, which
  385. # will then be updated below in expert_params_mapping
  386. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
  387. if "mlp.experts" in name:
  388. continue
  389. name = name.replace(weight_name, param_name)
  390. # Skip loading extra bias for GPTQ models.
  391. if name.endswith(".bias") and name not in params_dict:
  392. continue
  393. if name not in params_dict:
  394. continue
  395. param = params_dict[name]
  396. weight_loader = param.weight_loader
  397. weight_loader(param, loaded_weight, shard_id)
  398. break
  399. else:
  400. for mapping in expert_params_mapping:
  401. param_name, weight_name, expert_id, shard_id = mapping
  402. if weight_name not in name:
  403. continue
  404. name = name.replace(weight_name, param_name)
  405. param = params_dict[name]
  406. weight_loader = param.weight_loader
  407. weight_loader(param,
  408. loaded_weight,
  409. weight_name,
  410. shard_id=shard_id,
  411. expert_id=expert_id)
  412. break
  413. else:
  414. # Skip loading extra bias for GPTQ models.
  415. if name.endswith(".bias") and name not in params_dict:
  416. continue
  417. if name not in params_dict:
  418. continue
  419. param = params_dict[name]
  420. weight_loader = getattr(param, "weight_loader",
  421. default_weight_loader)
  422. weight_loader(param, loaded_weight)