qwen2_moe.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
  25. from typing import Any, Dict, Iterable, List, Optional, Tuple
  26. import torch
  27. import torch.nn.functional as F
  28. from torch import nn
  29. from transformers import PretrainedConfig
  30. from aphrodite.attention import Attention, AttentionMetadata
  31. from aphrodite.common.config import CacheConfig
  32. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  33. from aphrodite.common.utils import print_warning_once
  34. from aphrodite.distributed import (get_pp_group,
  35. get_tensor_model_parallel_world_size,
  36. tensor_model_parallel_all_reduce)
  37. from aphrodite.modeling.layers.activation import SiluAndMul
  38. from aphrodite.modeling.layers.fused_moe import FusedMoE
  39. from aphrodite.modeling.layers.layernorm import RMSNorm
  40. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  41. QKVParallelLinear,
  42. ReplicatedLinear,
  43. RowParallelLinear)
  44. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  45. from aphrodite.modeling.layers.rotary_embedding import get_rope
  46. from aphrodite.modeling.layers.sampler import Sampler
  47. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  48. ParallelLMHead, VocabParallelEmbedding)
  49. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  50. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  51. from aphrodite.quantization.base_config import QuantizationConfig
  52. from .utils import is_pp_missing_parameter, make_layers
  53. class Qwen2MoeMLP(nn.Module):
  54. def __init__(
  55. self,
  56. hidden_size: int,
  57. intermediate_size: int,
  58. hidden_act: str,
  59. quant_config: Optional[QuantizationConfig] = None,
  60. reduce_results: bool = True,
  61. ) -> None:
  62. super().__init__()
  63. self.gate_up_proj = MergedColumnParallelLinear(
  64. hidden_size, [intermediate_size] * 2,
  65. bias=False,
  66. quant_config=quant_config)
  67. self.down_proj = RowParallelLinear(intermediate_size,
  68. hidden_size,
  69. bias=False,
  70. quant_config=quant_config,
  71. reduce_results=reduce_results)
  72. if hidden_act != "silu":
  73. raise ValueError(f"Unsupported activation: {hidden_act}. "
  74. "Only silu is supported for now.")
  75. self.act_fn = SiluAndMul()
  76. def forward(self, x):
  77. gate_up, _ = self.gate_up_proj(x)
  78. x = self.act_fn(gate_up)
  79. x, _ = self.down_proj(x)
  80. return x
  81. class Qwen2MoeSparseMoeBlock(nn.Module):
  82. def __init__(
  83. self,
  84. config: PretrainedConfig,
  85. quant_config: Optional[QuantizationConfig] = None,
  86. ):
  87. super().__init__()
  88. self.tp_size = get_tensor_model_parallel_world_size()
  89. if self.tp_size > config.num_experts:
  90. raise ValueError(
  91. f"Tensor parallel size {self.tp_size} is greater than "
  92. f"the number of experts {config.num_experts}.")
  93. self.experts = FusedMoE(num_experts=config.num_experts,
  94. top_k=config.num_experts_per_tok,
  95. hidden_size=config.hidden_size,
  96. intermediate_size=config.moe_intermediate_size,
  97. reduce_results=False,
  98. renormalize=config.norm_topk_prob,
  99. quant_config=quant_config)
  100. self.gate = ReplicatedLinear(config.hidden_size,
  101. config.num_experts,
  102. bias=False,
  103. quant_config=None)
  104. if config.shared_expert_intermediate_size > 0:
  105. self.shared_expert = Qwen2MoeMLP(
  106. hidden_size=config.hidden_size,
  107. intermediate_size=config.shared_expert_intermediate_size,
  108. hidden_act=config.hidden_act,
  109. quant_config=quant_config,
  110. reduce_results=False,
  111. )
  112. else:
  113. self.shared_expert = None
  114. self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
  115. 1,
  116. bias=False)
  117. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  118. # NOTE: hidden_states can have either 1D or 2D shape.
  119. orig_shape = hidden_states.shape
  120. hidden_dim = hidden_states.shape[-1]
  121. hidden_states = hidden_states.view(-1, hidden_dim)
  122. shared_output = None
  123. if self.shared_expert is not None:
  124. shared_output = self.shared_expert(hidden_states)
  125. if self.shared_expert_gate is not None:
  126. shared_output = F.sigmoid(
  127. self.shared_expert_gate(hidden_states)) * shared_output
  128. # router_logits: (num_tokens, n_experts)
  129. router_logits, _ = self.gate(hidden_states)
  130. final_hidden_states = self.experts(hidden_states=hidden_states,
  131. router_logits=router_logits)
  132. if shared_output is not None:
  133. final_hidden_states = final_hidden_states + shared_output
  134. if self.tp_size > 1:
  135. final_hidden_states = tensor_model_parallel_all_reduce(
  136. final_hidden_states)
  137. return final_hidden_states.view(orig_shape)
  138. class Qwen2MoeAttention(nn.Module):
  139. def __init__(
  140. self,
  141. hidden_size: int,
  142. num_heads: int,
  143. num_kv_heads: int,
  144. rope_theta: float = 10000,
  145. rope_scaling: Optional[Dict[str, Any]] = None,
  146. max_position_embeddings: int = 8192,
  147. cache_config: Optional[CacheConfig] = None,
  148. quant_config: Optional[QuantizationConfig] = None,
  149. ) -> None:
  150. super().__init__()
  151. self.hidden_size = hidden_size
  152. tp_size = get_tensor_model_parallel_world_size()
  153. self.total_num_heads = num_heads
  154. assert self.total_num_heads % tp_size == 0
  155. self.num_heads = self.total_num_heads // tp_size
  156. self.total_num_kv_heads = num_kv_heads
  157. if self.total_num_kv_heads >= tp_size:
  158. # Number of KV heads is greater than TP size, so we partition
  159. # the KV heads across multiple tensor parallel GPUs.
  160. assert self.total_num_kv_heads % tp_size == 0
  161. else:
  162. # Number of KV heads is less than TP size, so we replicate
  163. # the KV heads across multiple tensor parallel GPUs.
  164. assert tp_size % self.total_num_kv_heads == 0
  165. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  166. self.head_dim = hidden_size // self.total_num_heads
  167. self.q_size = self.num_heads * self.head_dim
  168. self.kv_size = self.num_kv_heads * self.head_dim
  169. self.scaling = self.head_dim**-0.5
  170. self.rope_theta = rope_theta
  171. self.max_position_embeddings = max_position_embeddings
  172. self.qkv_proj = QKVParallelLinear(
  173. hidden_size,
  174. self.head_dim,
  175. self.total_num_heads,
  176. self.total_num_kv_heads,
  177. bias=True,
  178. quant_config=quant_config,
  179. )
  180. self.o_proj = RowParallelLinear(
  181. self.total_num_heads * self.head_dim,
  182. hidden_size,
  183. bias=False,
  184. quant_config=quant_config,
  185. )
  186. self.rotary_emb = get_rope(
  187. self.head_dim,
  188. rotary_dim=self.head_dim,
  189. max_position=max_position_embeddings,
  190. base=rope_theta,
  191. rope_scaling=rope_scaling,
  192. )
  193. self.attn = Attention(self.num_heads,
  194. self.head_dim,
  195. self.scaling,
  196. num_kv_heads=self.num_kv_heads,
  197. cache_config=cache_config,
  198. quant_config=quant_config)
  199. def forward(
  200. self,
  201. positions: torch.Tensor,
  202. hidden_states: torch.Tensor,
  203. kv_cache: torch.Tensor,
  204. attn_metadata: AttentionMetadata,
  205. ) -> torch.Tensor:
  206. qkv, _ = self.qkv_proj(hidden_states)
  207. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  208. q, k = self.rotary_emb(positions, q, k)
  209. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  210. output, _ = self.o_proj(attn_output)
  211. return output
  212. class Qwen2MoeDecoderLayer(nn.Module):
  213. def __init__(
  214. self,
  215. config: PretrainedConfig,
  216. layer_idx: int,
  217. cache_config: Optional[CacheConfig] = None,
  218. quant_config: Optional[QuantizationConfig] = None,
  219. ) -> None:
  220. super().__init__()
  221. self.hidden_size = config.hidden_size
  222. rope_theta = getattr(config, "rope_theta", 10000)
  223. rope_scaling = getattr(config, "rope_scaling", None)
  224. max_position_embeddings = getattr(config, "max_position_embeddings",
  225. 8192)
  226. self.self_attn = Qwen2MoeAttention(
  227. hidden_size=self.hidden_size,
  228. num_heads=config.num_attention_heads,
  229. num_kv_heads=config.num_key_value_heads,
  230. rope_theta=rope_theta,
  231. rope_scaling=rope_scaling,
  232. max_position_embeddings=max_position_embeddings,
  233. cache_config=cache_config,
  234. quant_config=quant_config,
  235. )
  236. # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
  237. # `mlp_only_layers` in the config.
  238. mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
  239. config.mlp_only_layers)
  240. if (layer_idx not in mlp_only_layers) and (
  241. config.num_experts > 0 and
  242. (layer_idx + 1) % config.decoder_sparse_step == 0):
  243. self.mlp = Qwen2MoeSparseMoeBlock(config=config,
  244. quant_config=quant_config)
  245. else:
  246. self.mlp = Qwen2MoeMLP(
  247. hidden_size=config.hidden_size,
  248. intermediate_size=config.intermediate_size,
  249. hidden_act=config.hidden_act,
  250. quant_config=quant_config,
  251. )
  252. self.input_layernorm = RMSNorm(config.hidden_size,
  253. eps=config.rms_norm_eps)
  254. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  255. eps=config.rms_norm_eps)
  256. def forward(
  257. self,
  258. positions: torch.Tensor,
  259. hidden_states: torch.Tensor,
  260. kv_cache: torch.Tensor,
  261. attn_metadata: AttentionMetadata,
  262. residual: Optional[torch.Tensor],
  263. ) -> torch.Tensor:
  264. # Self Attention
  265. if residual is None:
  266. residual = hidden_states
  267. hidden_states = self.input_layernorm(hidden_states)
  268. else:
  269. hidden_states, residual = self.input_layernorm(
  270. hidden_states, residual)
  271. hidden_states = self.self_attn(
  272. positions=positions,
  273. hidden_states=hidden_states,
  274. kv_cache=kv_cache,
  275. attn_metadata=attn_metadata,
  276. )
  277. # Fully Connected
  278. hidden_states, residual = self.post_attention_layernorm(
  279. hidden_states, residual)
  280. hidden_states = self.mlp(hidden_states)
  281. return hidden_states, residual
  282. class Qwen2MoeModel(nn.Module):
  283. def __init__(
  284. self,
  285. config: PretrainedConfig,
  286. cache_config: Optional[CacheConfig] = None,
  287. quant_config: Optional[QuantizationConfig] = None,
  288. prefix: str = "",
  289. ) -> None:
  290. super().__init__()
  291. self.padding_idx = config.pad_token_id
  292. self.vocab_size = config.vocab_size
  293. self.embed_tokens = VocabParallelEmbedding(
  294. config.vocab_size,
  295. config.hidden_size,
  296. )
  297. self.start_layer, self.end_layer, self.layers = make_layers(
  298. config.num_hidden_layers,
  299. lambda prefix: Qwen2MoeDecoderLayer(config=config,
  300. layer_idx=int(
  301. prefix.split(".")[-1]),
  302. cache_config=cache_config,
  303. quant_config=quant_config),
  304. prefix=f"{prefix}.layers",
  305. )
  306. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  307. def forward(
  308. self,
  309. input_ids: torch.Tensor,
  310. positions: torch.Tensor,
  311. kv_caches: List[torch.Tensor],
  312. attn_metadata: AttentionMetadata,
  313. intermediate_tensors: Optional[IntermediateTensors] = None,
  314. ) -> torch.Tensor:
  315. if get_pp_group().is_first_rank:
  316. hidden_states = self.embed_tokens(input_ids)
  317. residual = None
  318. else:
  319. assert intermediate_tensors is not None
  320. hidden_states = intermediate_tensors["hidden_states"]
  321. residual = intermediate_tensors["residual"]
  322. for i in range(self.start_layer, self.end_layer):
  323. layer = self.layers[i]
  324. hidden_states, residual = layer(positions, hidden_states,
  325. kv_caches[i - self.start_layer],
  326. attn_metadata, residual)
  327. if not get_pp_group().is_last_rank:
  328. return IntermediateTensors({
  329. "hidden_states": hidden_states,
  330. "residual": residual
  331. })
  332. hidden_states, _ = self.norm(hidden_states, residual)
  333. return hidden_states
  334. class Qwen2MoeForCausalLM(nn.Module):
  335. fall_back_to_pt_during_load = False
  336. def __init__(
  337. self,
  338. config: PretrainedConfig,
  339. cache_config: Optional[CacheConfig] = None,
  340. quant_config: Optional[QuantizationConfig] = None,
  341. ) -> None:
  342. super().__init__()
  343. self.config = config
  344. self.quant_config = quant_config
  345. self.model = Qwen2MoeModel(config, cache_config, quant_config)
  346. self.lm_head = ParallelLMHead(config.vocab_size,
  347. config.hidden_size,
  348. quant_config=quant_config)
  349. self.logits_processor = LogitsProcessor(config.vocab_size)
  350. self.sampler = Sampler()
  351. def forward(
  352. self,
  353. input_ids: torch.Tensor,
  354. positions: torch.Tensor,
  355. kv_caches: List[torch.Tensor],
  356. attn_metadata: AttentionMetadata,
  357. intermediate_tensors: Optional[IntermediateTensors] = None,
  358. ) -> torch.Tensor:
  359. hidden_states = self.model(input_ids, positions, kv_caches,
  360. attn_metadata, intermediate_tensors)
  361. return hidden_states
  362. def compute_logits(
  363. self,
  364. hidden_states: torch.Tensor,
  365. sampling_metadata: SamplingMetadata,
  366. ) -> Optional[torch.Tensor]:
  367. logits = self.logits_processor(self.lm_head, hidden_states,
  368. sampling_metadata)
  369. return logits
  370. def make_empty_intermediate_tensors(
  371. self, batch_size: int, dtype: torch.dtype,
  372. device: torch.device) -> IntermediateTensors:
  373. return IntermediateTensors({
  374. "hidden_states":
  375. torch.zeros((batch_size, self.config.hidden_size),
  376. dtype=dtype,
  377. device=device),
  378. "residual":
  379. torch.zeros((batch_size, self.config.hidden_size),
  380. dtype=dtype,
  381. device=device),
  382. })
  383. def sample(
  384. self,
  385. logits: Optional[torch.Tensor],
  386. sampling_metadata: SamplingMetadata,
  387. ) -> Optional[SamplerOutput]:
  388. next_tokens = self.sampler(logits, sampling_metadata)
  389. return next_tokens
  390. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  391. stacked_params_mapping = [
  392. # (param_name, shard_name, shard_id)
  393. ("qkv_proj", "q_proj", "q"),
  394. ("qkv_proj", "k_proj", "k"),
  395. ("qkv_proj", "v_proj", "v"),
  396. ("gate_up_proj", "gate_proj", 0),
  397. ("gate_up_proj", "up_proj", 1),
  398. ]
  399. # Params for weights, fp8 weight scales, fp8 activation scales
  400. # (param_name, weight_name, expert_id, shard_id)
  401. expert_params_mapping = FusedMoE.make_expert_params_mapping(
  402. ckpt_gate_proj_name="gate_proj",
  403. ckpt_down_proj_name="down_proj",
  404. ckpt_up_proj_name="up_proj",
  405. num_experts=self.config.num_experts)
  406. params_dict = dict(self.named_parameters())
  407. for name, loaded_weight in weights:
  408. if "rotary_emb.inv_freq" in name:
  409. continue
  410. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  411. # Skip non-stacked layers and experts (experts handled below).
  412. if weight_name not in name:
  413. continue
  414. # We have mlp.experts[0].gate_proj in the checkpoint.
  415. # Since we handle the experts below in expert_params_mapping,
  416. # we need to skip here BEFORE we update the name, otherwise
  417. # name will be updated to mlp.experts[0].gate_up_proj, which
  418. # will then be updated below in expert_params_mapping
  419. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
  420. if "mlp.experts" in name:
  421. continue
  422. name = name.replace(weight_name, param_name)
  423. # Skip loading extra bias for GPTQ models.
  424. if name.endswith(".bias") and name not in params_dict:
  425. continue
  426. # Skip layers on other devices.
  427. if is_pp_missing_parameter(name, self):
  428. continue
  429. if name not in params_dict:
  430. continue
  431. param = params_dict[name]
  432. weight_loader = param.weight_loader
  433. weight_loader(param, loaded_weight, shard_id)
  434. break
  435. else:
  436. for mapping in expert_params_mapping:
  437. param_name, weight_name, expert_id, shard_id = mapping
  438. if weight_name not in name:
  439. continue
  440. name = name.replace(weight_name, param_name)
  441. # Skip layers on other devices.
  442. if is_pp_missing_parameter(name, self):
  443. continue
  444. param = params_dict[name]
  445. weight_loader = param.weight_loader
  446. weight_loader(param,
  447. loaded_weight,
  448. name,
  449. shard_id=shard_id,
  450. expert_id=expert_id)
  451. break
  452. else:
  453. # Skip loading extra bias for GPTQ models.
  454. if name.endswith(".bias") and name not in params_dict:
  455. continue
  456. # Skip layers on other devices.
  457. if is_pp_missing_parameter(name, self):
  458. continue
  459. # Remapping the name of FP8 kv-scale.
  460. if name.endswith("kv_scale"):
  461. remapped_kv_scale_name = name.replace(
  462. ".kv_scale", ".attn.kv_scale")
  463. if remapped_kv_scale_name not in params_dict:
  464. print_warning_once(
  465. "Found kv scale in the checkpoint "
  466. f"(e.g. {name}), but not found the expected "
  467. f"name in the model "
  468. f"(e.g. {remapped_kv_scale_name}). "
  469. "kv-scale is not loaded.")
  470. continue
  471. else:
  472. name = remapped_kv_scale_name
  473. param = params_dict[name]
  474. weight_loader = getattr(param, "weight_loader",
  475. default_weight_loader)
  476. weight_loader(param, loaded_weight)