phimoe.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only PhiMoE model."""
  24. from typing import Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers.configuration_utils import PretrainedConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import IntermediateTensors
  31. from aphrodite.distributed import get_tensor_model_parallel_world_size
  32. from aphrodite.modeling.layers.fused_moe import FusedMoE
  33. from aphrodite.modeling.layers.linear import (QKVParallelLinear,
  34. ReplicatedLinear,
  35. RowParallelLinear)
  36. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  37. from aphrodite.modeling.layers.rotary_embedding import get_rope
  38. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  39. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  40. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  41. from aphrodite.modeling.model_loader.weight_utils import (
  42. default_weight_loader, maybe_remap_kv_scale_name)
  43. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  44. from aphrodite.quantization.base_config import QuantizationConfig
  45. from .interfaces import SupportsLoRA
  46. class PhiMoEConfig(PretrainedConfig):
  47. model_type = "phimoe"
  48. keys_to_ignore_at_inference = ["past_key_values"]
  49. def __init__(
  50. self,
  51. vocab_size=32000,
  52. hidden_size=4096,
  53. intermediate_size=14336,
  54. num_hidden_layers=32,
  55. num_attention_heads=32,
  56. num_key_value_heads=8,
  57. hidden_act="silu",
  58. max_position_embeddings=4096 * 32,
  59. initializer_range=0.02,
  60. rms_norm_eps=1e-5,
  61. use_cache=True,
  62. pad_token_id=None,
  63. bos_token_id=1,
  64. eos_token_id=2,
  65. tie_word_embeddings=False,
  66. rope_theta=1e6,
  67. sliding_window=None,
  68. attention_dropout=0.0,
  69. num_experts_per_tok=2,
  70. num_local_experts=16,
  71. output_router_logits=False,
  72. router_aux_loss_coef=0.001,
  73. router_jitter_noise=0.0,
  74. attention_bias=False,
  75. lm_head_bias=False,
  76. **kwargs,
  77. ):
  78. self.vocab_size = vocab_size
  79. self.max_position_embeddings = max_position_embeddings
  80. self.hidden_size = hidden_size
  81. self.intermediate_size = intermediate_size
  82. self.num_hidden_layers = num_hidden_layers
  83. self.num_attention_heads = num_attention_heads
  84. self.sliding_window = sliding_window
  85. self.attention_bias = attention_bias
  86. self.lm_head_bias = lm_head_bias
  87. # for backward compatibility
  88. if num_key_value_heads is None:
  89. num_key_value_heads = num_attention_heads
  90. self.num_key_value_heads = num_key_value_heads
  91. self.hidden_act = hidden_act
  92. self.initializer_range = initializer_range
  93. self.rms_norm_eps = rms_norm_eps
  94. self.use_cache = use_cache
  95. self.rope_theta = rope_theta
  96. self.attention_dropout = attention_dropout
  97. self.num_experts_per_tok = num_experts_per_tok
  98. self.num_local_experts = num_local_experts
  99. self.output_router_logits = output_router_logits
  100. self.router_aux_loss_coef = router_aux_loss_coef
  101. self.router_jitter_noise = router_jitter_noise
  102. super().__init__(
  103. pad_token_id=pad_token_id,
  104. bos_token_id=bos_token_id,
  105. eos_token_id=eos_token_id,
  106. tie_word_embeddings=tie_word_embeddings,
  107. **kwargs,
  108. )
  109. class mp(torch.autograd.Function):
  110. @staticmethod
  111. def forward(
  112. ctx,
  113. scores: torch.Tensor,
  114. multiplier: torch.Tensor,
  115. selected_experts: torch.Tensor,
  116. masked_gates: torch.Tensor,
  117. mask_for_one: torch.Tensor,
  118. ):
  119. ctx.save_for_backward(multiplier, selected_experts, masked_gates)
  120. return multiplier * mask_for_one
  121. @staticmethod
  122. def backward(
  123. ctx,
  124. grad_at_output: torch.Tensor,
  125. ):
  126. multiplier, selected_experts, masked_gates = ctx.saved_tensors
  127. grad_at_output = grad_at_output * multiplier
  128. grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1)
  129. grad_at_scores_expaned.scatter_add_(
  130. dim=-1,
  131. index=selected_experts,
  132. src=grad_at_output,
  133. )
  134. return (
  135. grad_at_scores_expaned,
  136. None,
  137. None,
  138. None,
  139. None,
  140. )
  141. def sparsemixer(scores, jitter_eps=0.01):
  142. ################ first expert ################
  143. with torch.no_grad():
  144. # compute mask for sparsity
  145. mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
  146. factor = scores.abs().clamp(min=mask_logits_threshold)
  147. mask_logits_threshold = (
  148. (mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
  149. # apply mask
  150. masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
  151. selected_experts = max_ind
  152. # compute scores for gradients
  153. masked_gates = torch.softmax(masked_gates, dim=-1)
  154. multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
  155. multiplier = multiplier_o
  156. # masked out first expert
  157. masked_scores = torch.scatter(
  158. scores,
  159. -1,
  160. selected_experts,
  161. float("-inf"),
  162. )
  163. with torch.no_grad():
  164. # compute mask for sparsity
  165. mask_logits_threshold, max_ind = masked_scores.max(dim=-1,
  166. keepdim=True)
  167. factor = scores.abs().clamp(min=mask_logits_threshold)
  168. mask_logits_threshold = (
  169. (mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
  170. # apply mask
  171. masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold,
  172. float("-inf"))
  173. selected_experts_top2 = max_ind
  174. # compute scores for gradients
  175. masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
  176. multiplier_top2 = masked_gates_top2.gather(dim=-1,
  177. index=selected_experts_top2)
  178. multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
  179. selected_experts = torch.concat((selected_experts, selected_experts_top2),
  180. dim=-1)
  181. return (
  182. multiplier,
  183. selected_experts,
  184. )
  185. def phimoe_routing_function(
  186. hidden_states: torch.Tensor,
  187. gating_output: torch.Tensor,
  188. topk: int,
  189. renormalize: bool,
  190. ):
  191. assert hidden_states.shape[0] == gating_output.shape[0], (
  192. "Number of tokens mismatch")
  193. assert topk == 2, "Only top-2 routing is supported"
  194. assert renormalize is False, "Renormalization is not supported"
  195. topk_weights, topk_ids = sparsemixer(gating_output)
  196. return topk_weights, topk_ids
  197. class PhiMoE(nn.Module):
  198. """A tensor-parallel MoE implementation for PhiMoE that shards each expert
  199. across all ranks.
  200. Each expert's weights are sharded across all ranks and a fused MoE
  201. kernel is used for the forward pass, and finally we reduce the outputs
  202. across ranks.
  203. """
  204. def __init__(
  205. self,
  206. num_experts: int,
  207. top_k: int,
  208. hidden_size: int,
  209. intermediate_size: int,
  210. params_dtype: Optional[torch.dtype] = None,
  211. quant_config: Optional[QuantizationConfig] = None,
  212. tp_size: Optional[int] = None,
  213. ):
  214. super().__init__()
  215. self.hidden_size = hidden_size
  216. # Gate always runs at half / full precision for now.
  217. self.gate = ReplicatedLinear(
  218. hidden_size,
  219. num_experts,
  220. bias=False,
  221. params_dtype=params_dtype,
  222. quant_config=None,
  223. )
  224. self.experts = FusedMoE(
  225. num_experts=num_experts,
  226. top_k=top_k,
  227. hidden_size=hidden_size,
  228. intermediate_size=intermediate_size,
  229. params_dtype=params_dtype,
  230. reduce_results=True,
  231. renormalize=False,
  232. quant_config=quant_config,
  233. tp_size=tp_size,
  234. custom_routing_function=phimoe_routing_function)
  235. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  236. # NOTE: hidden_states can have either 1D or 2D shape.
  237. orig_shape = hidden_states.shape
  238. hidden_states = hidden_states.view(-1, self.hidden_size)
  239. # router_logits: (num_tokens, n_experts)
  240. router_logits, _ = self.gate(hidden_states)
  241. final_hidden_states = self.experts(hidden_states, router_logits)
  242. return final_hidden_states.view(orig_shape)
  243. class PhiMoEAttention(nn.Module):
  244. def __init__(
  245. self,
  246. hidden_size: int,
  247. num_heads: int,
  248. num_kv_heads: int,
  249. max_position: int = 4096 * 32,
  250. rope_theta: float = 10000,
  251. cache_config: Optional[CacheConfig] = None,
  252. quant_config: Optional[QuantizationConfig] = None,
  253. rope_scaling: Optional[dict] = None,
  254. ) -> None:
  255. super().__init__()
  256. self.hidden_size = hidden_size
  257. tp_size = get_tensor_model_parallel_world_size()
  258. self.total_num_heads = num_heads
  259. assert self.total_num_heads % tp_size == 0
  260. self.num_heads = self.total_num_heads // tp_size
  261. self.total_num_kv_heads = num_kv_heads
  262. if self.total_num_kv_heads >= tp_size:
  263. # Number of KV heads is greater than TP size, so we partition
  264. # the KV heads across multiple tensor parallel GPUs.
  265. assert self.total_num_kv_heads % tp_size == 0
  266. else:
  267. # Number of KV heads is less than TP size, so we replicate
  268. # the KV heads across multiple tensor parallel GPUs.
  269. assert tp_size % self.total_num_kv_heads == 0
  270. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  271. self.head_dim = hidden_size // self.total_num_heads
  272. self.q_size = self.num_heads * self.head_dim
  273. self.kv_size = self.num_kv_heads * self.head_dim
  274. self.scaling = self.head_dim**-0.5
  275. self.rope_theta = rope_theta
  276. self.rope_scaling = rope_scaling
  277. self.qkv_proj = QKVParallelLinear(
  278. hidden_size,
  279. self.head_dim,
  280. self.total_num_heads,
  281. self.total_num_kv_heads,
  282. bias=True,
  283. quant_config=None,
  284. )
  285. self.o_proj = RowParallelLinear(
  286. self.total_num_heads * self.head_dim,
  287. hidden_size,
  288. bias=True,
  289. quant_config=None,
  290. )
  291. self.rotary_emb = get_rope(
  292. self.head_dim,
  293. rotary_dim=self.head_dim,
  294. max_position=max_position,
  295. base=int(self.rope_theta),
  296. is_neox_style=True,
  297. rope_scaling=self.rope_scaling,
  298. )
  299. self.attn = Attention(
  300. self.num_heads,
  301. self.head_dim,
  302. self.scaling,
  303. num_kv_heads=self.num_kv_heads,
  304. cache_config=cache_config,
  305. quant_config=quant_config,
  306. )
  307. def forward(
  308. self,
  309. positions: torch.Tensor,
  310. hidden_states: torch.Tensor,
  311. kv_cache: torch.Tensor,
  312. attn_metadata: AttentionMetadata,
  313. ) -> torch.Tensor:
  314. qkv, _ = self.qkv_proj(hidden_states)
  315. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  316. q, k = self.rotary_emb(positions, q, k)
  317. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  318. output, _ = self.o_proj(attn_output)
  319. return output
  320. class PhiMoEDecoderLayer(nn.Module):
  321. def __init__(
  322. self,
  323. config: PhiMoEConfig,
  324. cache_config: Optional[CacheConfig] = None,
  325. quant_config: Optional[QuantizationConfig] = None,
  326. ) -> None:
  327. super().__init__()
  328. self.hidden_size = config.hidden_size
  329. # Requires transformers > 4.32.0
  330. rope_theta = getattr(config, "rope_theta", 10000)
  331. self.self_attn = PhiMoEAttention(
  332. hidden_size=self.hidden_size,
  333. num_heads=config.num_attention_heads,
  334. max_position=config.max_position_embeddings,
  335. num_kv_heads=config.num_key_value_heads,
  336. rope_theta=rope_theta,
  337. cache_config=cache_config,
  338. quant_config=quant_config,
  339. rope_scaling=config.rope_scaling,
  340. )
  341. self.block_sparse_moe = PhiMoE(
  342. num_experts=config.num_local_experts,
  343. top_k=config.num_experts_per_tok,
  344. hidden_size=config.hidden_size,
  345. intermediate_size=config.intermediate_size,
  346. quant_config=quant_config,
  347. )
  348. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  349. eps=config.rms_norm_eps,
  350. elementwise_affine=True)
  351. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  352. eps=config.rms_norm_eps,
  353. elementwise_affine=True)
  354. def forward(
  355. self,
  356. positions: torch.Tensor,
  357. hidden_states: torch.Tensor,
  358. kv_cache: torch.Tensor,
  359. attn_metadata: AttentionMetadata,
  360. residual: Optional[torch.Tensor],
  361. ) -> torch.Tensor:
  362. residual = hidden_states
  363. # Self Attention
  364. hidden_states = self.input_layernorm(hidden_states)
  365. hidden_states = self.self_attn(
  366. positions=positions,
  367. hidden_states=hidden_states,
  368. kv_cache=kv_cache,
  369. attn_metadata=attn_metadata,
  370. )
  371. hidden_states = hidden_states + residual
  372. # Fully Connected
  373. residual = hidden_states
  374. hidden_states = self.post_attention_layernorm(hidden_states)
  375. hidden_states = self.block_sparse_moe(hidden_states)
  376. hidden_states = hidden_states + residual
  377. return hidden_states, residual
  378. class PhiMoEModel(nn.Module):
  379. def __init__(
  380. self,
  381. config: PhiMoEConfig,
  382. cache_config: Optional[CacheConfig] = None,
  383. quant_config: Optional[QuantizationConfig] = None,
  384. lora_config: Optional[LoRAConfig] = None,
  385. ) -> None:
  386. super().__init__()
  387. self.padding_idx = config.pad_token_id
  388. lora_vocab = ((lora_config.lora_extra_vocab_size *
  389. (lora_config.max_loras or 1)) if lora_config else 0)
  390. self.vocab_size = config.vocab_size + lora_vocab
  391. self.org_vocab_size = config.vocab_size
  392. self.embed_tokens = VocabParallelEmbedding(
  393. self.vocab_size,
  394. config.hidden_size,
  395. org_num_embeddings=config.vocab_size,
  396. )
  397. self.layers = nn.ModuleList([
  398. PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config)
  399. for _ in range(config.num_hidden_layers)
  400. ])
  401. self.norm = nn.LayerNorm(config.hidden_size,
  402. eps=config.rms_norm_eps,
  403. elementwise_affine=True)
  404. def forward(
  405. self,
  406. input_ids: torch.Tensor,
  407. positions: torch.Tensor,
  408. kv_caches: List[torch.Tensor],
  409. attn_metadata: AttentionMetadata,
  410. ) -> torch.Tensor:
  411. hidden_states = self.embed_tokens(input_ids)
  412. residual = None
  413. for i in range(len(self.layers)):
  414. layer = self.layers[i]
  415. hidden_states, residual = layer(positions, hidden_states,
  416. kv_caches[i], attn_metadata,
  417. residual)
  418. hidden_states = self.norm(hidden_states)
  419. return hidden_states
  420. class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
  421. fall_back_to_pt_during_load = False
  422. packed_modules_mapping = {
  423. "qkv_proj": [
  424. "q_proj",
  425. "k_proj",
  426. "v_proj",
  427. ],
  428. }
  429. # LoRA specific attributes
  430. supported_lora_modules = [
  431. "qkv_proj",
  432. "o_proj",
  433. "embed_tokens",
  434. "lm_head",
  435. "w1",
  436. "w2",
  437. "w3",
  438. "gate",
  439. ]
  440. embedding_modules = {
  441. "embed_tokens": "input_embeddings",
  442. "lm_head": "output_embeddings",
  443. }
  444. embedding_padding_modules = ["lm_head"]
  445. def __init__(
  446. self,
  447. config: PhiMoEConfig,
  448. cache_config: Optional[CacheConfig] = None,
  449. quant_config: Optional[QuantizationConfig] = None,
  450. lora_config: Optional[LoRAConfig] = None,
  451. ) -> None:
  452. super().__init__()
  453. self.config = config
  454. self.lora_config = lora_config
  455. self.model = PhiMoEModel(config,
  456. cache_config,
  457. quant_config,
  458. lora_config=lora_config)
  459. self.unpadded_vocab_size = config.vocab_size
  460. if lora_config:
  461. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  462. self.lm_head = ParallelLMHead(
  463. self.unpadded_vocab_size,
  464. config.hidden_size,
  465. org_num_embeddings=config.vocab_size,
  466. padding_size=(
  467. DEFAULT_VOCAB_PADDING_SIZE
  468. # We need bigger padding if using lora for kernel
  469. # compatibility
  470. if not lora_config else lora_config.lora_vocab_padding_size),
  471. quant_config=None,
  472. bias=True,
  473. )
  474. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  475. config.vocab_size)
  476. self.sampler = Sampler()
  477. def forward(
  478. self,
  479. input_ids: torch.Tensor,
  480. positions: torch.Tensor,
  481. kv_caches: List[torch.Tensor],
  482. attn_metadata: AttentionMetadata,
  483. intermediate_tensors: Optional[IntermediateTensors] = None,
  484. ) -> torch.Tensor:
  485. hidden_states = self.model(input_ids, positions, kv_caches,
  486. attn_metadata)
  487. return hidden_states
  488. def compute_logits(self, hidden_states: torch.Tensor,
  489. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  490. logits = self.logits_processor(self.lm_head, hidden_states,
  491. sampling_metadata)
  492. return logits
  493. def sample(
  494. self,
  495. logits: Optional[torch.Tensor],
  496. sampling_metadata: SamplingMetadata,
  497. ) -> Optional[SamplerOutput]:
  498. next_tokens = self.sampler(logits, sampling_metadata)
  499. return next_tokens
  500. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  501. stacked_params_mapping = [
  502. # (param_name, shard_name, shard_id)
  503. ("qkv_proj", "q_proj", "q"),
  504. ("qkv_proj", "k_proj", "k"),
  505. ("qkv_proj", "v_proj", "v"),
  506. ]
  507. expert_params_mapping = FusedMoE.make_expert_params_mapping(
  508. ckpt_gate_proj_name="w1",
  509. ckpt_down_proj_name="w2",
  510. ckpt_up_proj_name="w3",
  511. num_experts=self.config.num_local_experts)
  512. params_dict = dict(self.named_parameters())
  513. for name, loaded_weight in weights:
  514. if "rotary_emb.inv_freq" in name:
  515. continue
  516. for param_name, weight_name, shard_id in stacked_params_mapping:
  517. if weight_name not in name:
  518. continue
  519. name = name.replace(weight_name, param_name)
  520. # Skip loading extra bias for GPTQ models.
  521. if name.endswith(".bias") and name not in params_dict:
  522. continue
  523. param = params_dict[name]
  524. weight_loader = param.weight_loader
  525. weight_loader(param, loaded_weight, shard_id)
  526. break
  527. else:
  528. for mapping in expert_params_mapping:
  529. param_name, weight_name, expert_id, shard_id = mapping
  530. if weight_name not in name:
  531. continue
  532. name = name.replace(weight_name, param_name)
  533. param = params_dict[name]
  534. weight_loader = param.weight_loader
  535. weight_loader(
  536. param,
  537. loaded_weight,
  538. name,
  539. shard_id=shard_id,
  540. expert_id=expert_id,
  541. )
  542. break
  543. else:
  544. # Skip loading extra bias for GPTQ models.
  545. if name.endswith(".bias") and name not in params_dict:
  546. continue
  547. # Remapping the name of FP8 kv-scale.
  548. name = maybe_remap_kv_scale_name(name, params_dict)
  549. if name is None:
  550. continue
  551. param = params_dict[name]
  552. weight_loader = getattr(param, "weight_loader",
  553. default_weight_loader)
  554. weight_loader(param, loaded_weight)