qwen2moe.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The PygmalionAI team.
  6. # Copyright 2023 The vLLM team.
  7. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
  26. from typing import Any, Dict, List, Optional
  27. import torch
  28. import torch.nn.functional as F
  29. from torch import nn
  30. from transformers import PretrainedConfig
  31. from aphrodite.attention import Attention, AttentionMetadata
  32. from aphrodite.modeling.layers.activation import SiluAndMul
  33. from aphrodite.modeling.layers.fused_moe import fused_moe
  34. from aphrodite.modeling.layers.layernorm import RMSNorm
  35. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  36. MergedColumnParallelLinear,
  37. QKVParallelLinear,
  38. ReplicatedLinear,
  39. RowParallelLinear)
  40. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  41. from aphrodite.modeling.layers.rotary_embedding import get_rope
  42. from aphrodite.modeling.layers.sampler import Sampler
  43. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  44. ParallelLMHead, VocabParallelEmbedding)
  45. from aphrodite.distributed import (tensor_model_parallel_all_reduce)
  46. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  47. get_tensor_model_parallel_world_size)
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  50. hf_model_weights_iterator)
  51. from aphrodite.common.sequence import SamplerOutput
  52. class Qwen2MoeMLP(nn.Module):
  53. def __init__(
  54. self,
  55. hidden_size: int,
  56. intermediate_size: int,
  57. hidden_act: str,
  58. linear_method: Optional[LinearMethodBase] = None,
  59. reduce_results: bool = True,
  60. ) -> None:
  61. super().__init__()
  62. self.gate_up_proj = MergedColumnParallelLinear(
  63. hidden_size, [intermediate_size] * 2,
  64. bias=False,
  65. linear_method=linear_method)
  66. self.down_proj = RowParallelLinear(intermediate_size,
  67. hidden_size,
  68. bias=False,
  69. linear_method=linear_method,
  70. reduce_results=reduce_results)
  71. if hidden_act != "silu":
  72. raise ValueError(f"Unsupported activation: {hidden_act}. "
  73. "Only silu is supported for now.")
  74. self.act_fn = SiluAndMul()
  75. def forward(self, x):
  76. gate_up, _ = self.gate_up_proj(x)
  77. x = self.act_fn(gate_up)
  78. x, _ = self.down_proj(x)
  79. return x
  80. class Qwen2MoeSparseMoeBlock(nn.Module):
  81. def __init__(
  82. self,
  83. config: PretrainedConfig,
  84. linear_method: Optional[LinearMethodBase] = None,
  85. ):
  86. super().__init__()
  87. self.config = config
  88. self.rank = get_tensor_model_parallel_rank()
  89. self.tp_size = get_tensor_model_parallel_world_size()
  90. self.n_routed_experts = config.num_experts
  91. self.top_k = config.num_experts_per_tok
  92. if self.tp_size > self.n_routed_experts:
  93. raise ValueError(
  94. f"Tensor parallel size {self.tp_size} is greater than "
  95. f"the number of experts {self.n_routed_experts}.")
  96. self.experts = nn.ModuleList([
  97. Qwen2MoeMLP(hidden_size=config.hidden_size,
  98. intermediate_size=config.moe_intermediate_size,
  99. hidden_act=config.hidden_act,
  100. linear_method=linear_method,
  101. reduce_results=False)
  102. for idx in range(self.n_routed_experts)
  103. ])
  104. self.pack_params()
  105. self.gate = ReplicatedLinear(config.hidden_size,
  106. self.n_routed_experts,
  107. bias=False,
  108. linear_method=None)
  109. if config.shared_expert_intermediate_size > 0:
  110. self.shared_expert = Qwen2MoeMLP(
  111. hidden_size=config.hidden_size,
  112. intermediate_size=config.shared_expert_intermediate_size,
  113. hidden_act=config.hidden_act,
  114. linear_method=linear_method,
  115. reduce_results=False,
  116. )
  117. else:
  118. self.shared_expert = None
  119. self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
  120. 1,
  121. bias=False)
  122. def pack_params(self):
  123. w1 = []
  124. w2 = []
  125. for expert in self.experts:
  126. w1.append(expert.gate_up_proj.weight)
  127. w2.append(expert.down_proj.weight)
  128. self.w1 = torch._utils._flatten_dense_tensors(w1)
  129. w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
  130. for data, param in zip(w1s, w1):
  131. param.data = data
  132. self.w1 = self.w1.view(len(w1), *w1s[0].shape)
  133. self.w2 = torch._utils._flatten_dense_tensors(w2)
  134. w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
  135. for data, param in zip(w2s, w2):
  136. param.data = data
  137. self.w2 = self.w2.view(len(w2), *w2s[0].shape)
  138. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  139. num_tokens, hidden_dim = hidden_states.shape
  140. hidden_states = hidden_states.view(-1, hidden_dim)
  141. shared_output = None
  142. if self.shared_expert is not None:
  143. shared_output = self.shared_expert(hidden_states)
  144. if self.shared_expert_gate is not None:
  145. shared_output = F.sigmoid(
  146. self.shared_expert_gate(hidden_states)) * shared_output
  147. # router_logits: (num_tokens, n_experts)
  148. router_logits, _ = self.gate(hidden_states)
  149. final_hidden_states = fused_moe(hidden_states,
  150. self.w1,
  151. self.w2,
  152. router_logits,
  153. self.top_k,
  154. renormalize=self.config.norm_topk_prob,
  155. inplace=True)
  156. if shared_output is not None:
  157. final_hidden_states = final_hidden_states + shared_output
  158. final_hidden_states = tensor_model_parallel_all_reduce(
  159. final_hidden_states)
  160. return final_hidden_states.view(num_tokens, hidden_dim)
  161. class Qwen2MoeAttention(nn.Module):
  162. def __init__(
  163. self,
  164. hidden_size: int,
  165. num_heads: int,
  166. num_kv_heads: int,
  167. rope_theta: float = 10000,
  168. rope_scaling: Optional[Dict[str, Any]] = None,
  169. max_position_embeddings: int = 8192,
  170. linear_method: Optional[LinearMethodBase] = None,
  171. ) -> None:
  172. super().__init__()
  173. self.hidden_size = hidden_size
  174. tp_size = get_tensor_model_parallel_world_size()
  175. self.total_num_heads = num_heads
  176. assert self.total_num_heads % tp_size == 0
  177. self.num_heads = self.total_num_heads // tp_size
  178. self.total_num_kv_heads = num_kv_heads
  179. if self.total_num_kv_heads >= tp_size:
  180. # Number of KV heads is greater than TP size, so we partition
  181. # the KV heads across multiple tensor parallel GPUs.
  182. assert self.total_num_kv_heads % tp_size == 0
  183. else:
  184. # Number of KV heads is less than TP size, so we replicate
  185. # the KV heads across multiple tensor parallel GPUs.
  186. assert tp_size % self.total_num_kv_heads == 0
  187. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  188. self.head_dim = hidden_size // self.total_num_heads
  189. self.q_size = self.num_heads * self.head_dim
  190. self.kv_size = self.num_kv_heads * self.head_dim
  191. self.scaling = self.head_dim**-0.5
  192. self.rope_theta = rope_theta
  193. self.max_position_embeddings = max_position_embeddings
  194. self.qkv_proj = QKVParallelLinear(
  195. hidden_size,
  196. self.head_dim,
  197. self.total_num_heads,
  198. self.total_num_kv_heads,
  199. bias=True,
  200. linear_method=linear_method,
  201. )
  202. self.o_proj = RowParallelLinear(
  203. self.total_num_heads * self.head_dim,
  204. hidden_size,
  205. bias=False,
  206. linear_method=linear_method,
  207. )
  208. self.rotary_emb = get_rope(
  209. self.head_dim,
  210. rotary_dim=self.head_dim,
  211. max_position=max_position_embeddings,
  212. base=rope_theta,
  213. rope_scaling=rope_scaling,
  214. )
  215. self.attn = Attention(self.num_heads,
  216. self.head_dim,
  217. self.scaling,
  218. num_kv_heads=self.num_kv_heads)
  219. def forward(
  220. self,
  221. positions: torch.Tensor,
  222. hidden_states: torch.Tensor,
  223. kv_cache: torch.Tensor,
  224. attn_metadata: AttentionMetadata,
  225. ) -> torch.Tensor:
  226. qkv, _ = self.qkv_proj(hidden_states)
  227. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  228. q, k = self.rotary_emb(positions, q, k)
  229. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  230. output, _ = self.o_proj(attn_output)
  231. return output
  232. class Qwen2MoeDecoderLayer(nn.Module):
  233. def __init__(
  234. self,
  235. config: PretrainedConfig,
  236. layer_idx: int,
  237. linear_method: Optional[LinearMethodBase] = None,
  238. ) -> None:
  239. super().__init__()
  240. self.hidden_size = config.hidden_size
  241. rope_theta = getattr(config, "rope_theta", 10000)
  242. rope_scaling = getattr(config, "rope_scaling", None)
  243. max_position_embeddings = getattr(config, "max_position_embeddings",
  244. 8192)
  245. self.self_attn = Qwen2MoeAttention(
  246. hidden_size=self.hidden_size,
  247. num_heads=config.num_attention_heads,
  248. num_kv_heads=config.num_key_value_heads,
  249. rope_theta=rope_theta,
  250. rope_scaling=rope_scaling,
  251. max_position_embeddings=max_position_embeddings,
  252. linear_method=linear_method,
  253. )
  254. if (config.num_experts is not None
  255. and (layer_idx + 1) % config.decoder_sparse_step == 0):
  256. self.mlp = Qwen2MoeSparseMoeBlock(config=config,
  257. linear_method=linear_method)
  258. else:
  259. self.mlp = Qwen2MoeMLP(
  260. hidden_size=config.hidden_size,
  261. intermediate_size=config.intermediate_size,
  262. hidden_act=config.hidden_act,
  263. linear_method=linear_method,
  264. )
  265. self.input_layernorm = RMSNorm(config.hidden_size,
  266. eps=config.rms_norm_eps)
  267. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  268. eps=config.rms_norm_eps)
  269. def forward(
  270. self,
  271. positions: torch.Tensor,
  272. hidden_states: torch.Tensor,
  273. kv_cache: torch.Tensor,
  274. attn_metadata: AttentionMetadata,
  275. residual: Optional[torch.Tensor],
  276. ) -> torch.Tensor:
  277. # Self Attention
  278. if residual is None:
  279. residual = hidden_states
  280. hidden_states = self.input_layernorm(hidden_states)
  281. else:
  282. hidden_states, residual = self.input_layernorm(
  283. hidden_states, residual)
  284. hidden_states = self.self_attn(
  285. positions=positions,
  286. hidden_states=hidden_states,
  287. kv_cache=kv_cache,
  288. attn_metadata=attn_metadata,
  289. )
  290. # Fully Connected
  291. hidden_states, residual = self.post_attention_layernorm(
  292. hidden_states, residual)
  293. hidden_states = self.mlp(hidden_states)
  294. return hidden_states, residual
  295. class Qwen2MoeModel(nn.Module):
  296. def __init__(
  297. self,
  298. config: PretrainedConfig,
  299. linear_method: Optional[LinearMethodBase] = None,
  300. ) -> None:
  301. super().__init__()
  302. self.padding_idx = config.pad_token_id
  303. self.vocab_size = config.vocab_size
  304. self.embed_tokens = VocabParallelEmbedding(
  305. config.vocab_size,
  306. config.hidden_size,
  307. )
  308. self.layers = nn.ModuleList([
  309. Qwen2MoeDecoderLayer(config,
  310. layer_idx,
  311. linear_method=linear_method)
  312. for layer_idx in range(config.num_hidden_layers)
  313. ])
  314. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  315. def forward(
  316. self,
  317. input_ids: torch.Tensor,
  318. positions: torch.Tensor,
  319. kv_caches: List[torch.Tensor],
  320. attn_metadata: AttentionMetadata,
  321. ) -> torch.Tensor:
  322. hidden_states = self.embed_tokens(input_ids)
  323. residual = None
  324. for i in range(len(self.layers)):
  325. layer = self.layers[i]
  326. hidden_states, residual = layer(positions, hidden_states,
  327. kv_caches[i], attn_metadata,
  328. residual)
  329. hidden_states, _ = self.norm(hidden_states, residual)
  330. return hidden_states
  331. class Qwen2MoeForCausalLM(nn.Module):
  332. def __init__(
  333. self,
  334. config: PretrainedConfig,
  335. linear_method: Optional[LinearMethodBase] = None,
  336. ) -> None:
  337. super().__init__()
  338. self.config = config
  339. self.linear_method = linear_method
  340. self.model = Qwen2MoeModel(config, linear_method)
  341. self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  342. self.logits_processor = LogitsProcessor(config.vocab_size)
  343. self.sampler = Sampler()
  344. def forward(
  345. self,
  346. input_ids: torch.Tensor,
  347. positions: torch.Tensor,
  348. kv_caches: List[torch.Tensor],
  349. attn_metadata: AttentionMetadata,
  350. ) -> torch.Tensor:
  351. hidden_states = self.model(input_ids, positions, kv_caches,
  352. attn_metadata)
  353. return hidden_states
  354. def compute_logits(self, hidden_states: torch.Tensor,
  355. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  356. logits = self.logits_processor(self.lm_head, hidden_states,
  357. sampling_metadata)
  358. return logits
  359. def sample(
  360. self,
  361. logits: Optional[torch.Tensor],
  362. sampling_metadata: SamplingMetadata,
  363. ) -> Optional[SamplerOutput]:
  364. next_tokens = self.sampler(logits, sampling_metadata)
  365. return next_tokens
  366. def load_weights(self,
  367. model_name_or_path: str,
  368. cache_dir: Optional[str] = None,
  369. load_format: str = "auto",
  370. revision: Optional[str] = None):
  371. stacked_params_mapping = [
  372. # (param_name, shard_name, shard_id)
  373. ("qkv_proj", "q_proj", "q"),
  374. ("qkv_proj", "k_proj", "k"),
  375. ("qkv_proj", "v_proj", "v"),
  376. ("gate_up_proj", "gate_proj", 0),
  377. ("gate_up_proj", "up_proj", 1),
  378. ]
  379. params_dict = dict(self.named_parameters())
  380. for name, loaded_weight in hf_model_weights_iterator(
  381. model_name_or_path,
  382. cache_dir,
  383. load_format,
  384. revision,
  385. fall_back_to_pt=False):
  386. if "rotary_emb.inv_freq" in name:
  387. continue
  388. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  389. if weight_name not in name:
  390. continue
  391. name = name.replace(weight_name, param_name)
  392. # Skip loading extra bias for GPTQ models.
  393. if name.endswith(".bias") and name not in params_dict:
  394. continue
  395. # Skip experts that are not assigned to this worker.
  396. if (("mlp.experts." in name or "mlp.shared_expert." in name)
  397. and name not in params_dict):
  398. continue
  399. param = params_dict[name]
  400. weight_loader = param.weight_loader
  401. weight_loader(param, loaded_weight, shard_id)
  402. break
  403. else:
  404. # Skip loading extra bias for GPTQ models.
  405. if name.endswith(".bias") and name not in params_dict:
  406. continue
  407. # Skip experts that are not assigned to this worker.
  408. if (("mlp.experts." in name or "mlp.shared_expert." in name)
  409. and name not in params_dict):
  410. continue
  411. param = params_dict[name]
  412. weight_loader = getattr(param, "weight_loader",
  413. default_weight_loader)
  414. weight_loader(param, loaded_weight)