# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only PhiMoE model."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.fused_moe import FusedMoE
from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsLoRA


class PhiMoEConfig(PretrainedConfig):
    """Configuration object holding the PhiMoE hyperparameters.

    Every constructor argument except the token ids and
    ``tie_word_embeddings`` is stored as an attribute of the same name;
    those four, together with any extra ``kwargs``, are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "phimoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=1e6,
        sliding_window=None,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_local_experts=16,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        router_jitter_noise=0.0,
        attention_bias=False,
        lm_head_bias=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.attention_bias = attention_bias
        self.lm_head_bias = lm_head_bias
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_jitter_noise = router_jitter_noise
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class mp(torch.autograd.Function):
    """Custom autograd function with a hand-written backward pass.

    The forward pass returns ``multiplier * mask_for_one``. The backward
    pass produces a gradient for ``scores`` only; the other four inputs
    receive ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        scores: torch.Tensor,
        multiplier: torch.Tensor,
        selected_experts: torch.Tensor,
        masked_gates: torch.Tensor,
        mask_for_one: torch.Tensor,
    ):
        # ``scores`` is not used in the forward value; it is an input only
        # so that backward can return a gradient for it.
        ctx.save_for_backward(multiplier, selected_experts, masked_gates)
        return multiplier * mask_for_one

    @staticmethod
    def backward(
        ctx,
        grad_at_output: torch.Tensor,
    ):
        multiplier, selected_experts, masked_gates = ctx.saved_tensors

        grad_at_output = grad_at_output * multiplier

        grad_at_scores_expanded = masked_gates * grad_at_output.mul(-1)
        grad_at_scores_expanded.scatter_add_(
            dim=-1,
            index=selected_experts,
            src=grad_at_output,
        )

        # One entry per forward input: only ``scores`` gets a gradient.
        return (
            grad_at_scores_expanded,
            None,
            None,
            None,
            None,
        )


def sparsemixer(scores, jitter_eps=0.01):
    """Select two experts per row of ``scores`` and compute their weights.

    The first expert is the argmax over the last dimension. The second is
    the argmax after the first expert's score has been set to ``-inf``.
    For each pick, entries whose relative gap to the current maximum
    exceeds ``2 * jitter_eps`` are masked to ``-inf`` before a softmax, and
    the weight of the picked expert is gathered from that softmax.

    Args:
        scores: Router scores; the expert axis is the last dimension.
            (The caller in this file passes ``(num_tokens, n_experts)``
            router logits.)
        jitter_eps: Half of the relative-gap threshold used for masking.

    Returns:
        A tuple ``(multiplier, selected_experts)``; each is the first and
        second pick concatenated along the last dimension.
    """
    ################ first expert ################

    with torch.no_grad():
        # compute mask for sparsity
        mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
        factor = scores.abs().clamp(min=mask_logits_threshold)
        mask_logits_threshold = (
            (mask_logits_threshold - scores) / factor) > (2 * jitter_eps)

    # apply mask
    masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
    selected_experts = max_ind

    # compute scores for gradients
    masked_gates = torch.softmax(masked_gates, dim=-1)
    multiplier = masked_gates.gather(dim=-1, index=selected_experts)

    ################ second expert ################

    # masked out first expert
    masked_scores = torch.scatter(
        scores,
        -1,
        selected_experts,
        float("-inf"),
    )
    with torch.no_grad():
        # compute mask for sparsity
        mask_logits_threshold, max_ind = masked_scores.max(dim=-1,
                                                           keepdim=True)
        # NOTE: the gap is measured against the unmasked ``scores`` here,
        # exactly as in the first stage.
        factor = scores.abs().clamp(min=mask_logits_threshold)
        mask_logits_threshold = (
            (mask_logits_threshold - scores) / factor) > (2 * jitter_eps)

    # apply mask
    masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold,
                                                  float("-inf"))
    selected_experts_top2 = max_ind
    # compute scores for gradients
    masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
    multiplier_top2 = masked_gates_top2.gather(dim=-1,
                                               index=selected_experts_top2)

    multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
    selected_experts = torch.concat((selected_experts, selected_experts_top2),
                                    dim=-1)

    return (
        multiplier,
        selected_experts,
    )


def phimoe_routing_function(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Routing callback handed to ``FusedMoE`` as custom_routing_function.

    Validates the arguments and delegates to ``sparsemixer``.

    Args:
        hidden_states: Only its first dimension is inspected; it must match
            the first dimension of ``gating_output``.
        gating_output: Router scores passed to ``sparsemixer``.
        topk: Must be 2.
        renormalize: Must be ``False``.

    Returns:
        ``(topk_weights, topk_ids)`` as produced by ``sparsemixer``.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert topk == 2, "Only top-2 routing is supported"
    assert renormalize is False, "Renormalization is not supported"

    topk_weights, topk_ids = sparsemixer(gating_output)
    return topk_weights, topk_ids


class PhiMoE(nn.Module):
    """A tensor-parallel MoE implementation for PhiMoE that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=False,
            quant_config=quant_config,
            tp_size=tp_size,
            custom_routing_function=phimoe_routing_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the gate and the fused experts, preserving the input shape."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
        return final_hidden_states.view(orig_shape)


class PhiMoEAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding,
    attention, and output projection.

    Both projections are created with ``bias=True`` and
    ``quant_config=None``; only the ``Attention`` layer receives
    ``quant_config``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        rope_scaling: Optional[dict] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths used to split the fused QKV output in forward().
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=None,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=True,
            quant_config=None,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
            rope_scaling=self.rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding to Q and K, attend,
        and apply the output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class PhiMoEDecoderLayer(nn.Module):
    """One decoder layer: LayerNorm -> attention -> residual add, then
    LayerNorm -> MoE -> residual add."""

    def __init__(
        self,
        config: PhiMoEConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        # PhiMoEConfig.__init__ does not assign rope_scaling explicitly, so
        # fall back to None when the attribute is absent instead of raising
        # AttributeError.
        rope_scaling = getattr(config, "rope_scaling", None)
        self.self_attn = PhiMoEAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
        )
        self.block_sparse_moe = PhiMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
        )
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.rms_norm_eps,
                                            elementwise_affine=True)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps,
                                                     elementwise_affine=True)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer.

        The incoming ``residual`` argument is overwritten before use; it is
        kept in the signature for the caller's sake.

        Returns:
            ``(hidden_states, residual)`` where ``residual`` is the input to
            the post-attention LayerNorm.
        """
        residual = hidden_states

        # Self Attention
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.block_sparse_moe(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states, residual


class PhiMoEModel(nn.Module):
    """Token embedding, a stack of ``PhiMoEDecoderLayer``s, and a final
    LayerNorm."""

    def __init__(
        self,
        config: PhiMoEConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        # Extra embedding rows reserved for LoRA-added vocabulary.
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.rms_norm_eps,
                                 elementwise_affine=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer with its matching
        entry of ``kv_caches``, and apply the final norm."""
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], attn_metadata,
                                            residual)
        hidden_states = self.norm(hidden_states)
        return hidden_states


class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
    """PhiMoE backbone plus LM head, logits processor and sampler."""

    fall_back_to_pt_during_load = False

    # Checkpoint q/k/v projections are packed into one qkv_proj parameter.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
        "w1",
        "w2",
        "w3",
        "gate",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: PhiMoEConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.lora_config = lora_config
        self.model = PhiMoEModel(config,
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=(
                DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size),
            quant_config=None,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the backbone's hidden states.

        ``intermediate_tensors`` is accepted but not used.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from ``hidden_states`` via the logits processor
        and the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair is tried, in order, against:
          1. the stacked q/k/v mapping (loaded into ``qkv_proj`` shards),
          2. the expert mapping built by
             ``FusedMoE.make_expert_params_mapping``,
          3. a plain by-name load, after optional kv-scale name remapping.

        Names containing ``rotary_emb.inv_freq`` are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked q/k/v weight: try the expert mapping next.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)