# coding=utf-8 # Adapted from # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py # Copyright 2023 The PygmalionAI team. # Copyright 2023 The vLLM team. # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" from typing import Any, Dict, List, Optional import numpy as np import torch from torch import nn from transformers import PretrainedConfig from aphrodite.attention import Attention, AttentionMetadata from aphrodite.modeling.layers.activation import SiluAndMul from aphrodite.modeling.layers.fused_moe import fused_topk from aphrodite.modeling.layers.layernorm import RMSNorm from aphrodite.modeling.layers.linear import ( LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) from aphrodite.modeling.layers.logits_processor import LogitsProcessor from aphrodite.modeling.layers.rotary_embedding import get_rope from aphrodite.modeling.layers.sampler import Sampler from aphrodite.modeling.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from aphrodite.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from aphrodite.modeling.sampling_metadata import SamplingMetadata from aphrodite.modeling.hf_downloader import (default_weight_loader, hf_model_weights_iterator) from aphrodite.common.sequence import SamplerOutput class DeepseekMLP(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, hidden_act: str, linear_method: Optional[LinearMethodBase] = None, reduce_results: bool = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, linear_method=linear_method) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, linear_method=linear_method, reduce_results=reduce_results) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") self.act_fn = SiluAndMul() def forward(self, x): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) return x class DeepseekExpertMLP(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, hidden_act: str, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.gate_proj = ReplicatedLinear(hidden_size, intermediate_size, bias=False, linear_method=linear_method) self.up_proj = ReplicatedLinear(hidden_size, intermediate_size, bias=False, linear_method=linear_method) self.down_proj = ReplicatedLinear(intermediate_size, hidden_size, bias=False, linear_method=linear_method) self.act_fn = nn.SiLU() def forward(self, hidden_states): gate_out, _ = self.gate_proj(hidden_states) gate_out = self.act_fn(gate_out) up_out, _ = self.up_proj(hidden_states) current_hidden_states = gate_out * up_out current_hidden_states, _ = self.down_proj(current_hidden_states) return current_hidden_states class DeepseekMoE(nn.Module): def __init__( self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() self.n_routed_experts = config.n_routed_experts self.top_k = config.num_experts_per_tok self.linear_method = linear_method if self.linear_method is None: self.linear_method = UnquantizedLinearMethod() if not isinstance( self.linear_method, UnquantizedLinearMethod ) and not self.linear_method.quant_config.support_fused_moe(): if self.tp_size > self.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {self.n_routed_experts}.") # Split experts equally between ranks self.expert_indicies = np.array_split(range( self.n_routed_experts), self.tp_size)[self.rank].tolist() if not self.expert_indicies: raise ValueError( f"Rank {self.rank} has no experts assigned to it.") self.experts = nn.ModuleList([ DeepseekExpertMLP( hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, linear_method=linear_method, ) if idx in self.expert_indicies else None for idx in range(self.n_routed_experts) ]) else: self.w1 = MergedColumnParallelLinear( config.hidden_size, [config.moe_intermediate_size] * 2, bias=False, linear_method=linear_method, num_experts=self.n_routed_experts) self.w2 = RowParallelLinear(config.moe_intermediate_size, config.hidden_size, bias=False, linear_method=linear_method, num_experts=self.n_routed_experts) self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, bias=False, linear_method=None) if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) self.shared_experts = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, linear_method=linear_method, reduce_results=False, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) if self.config.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) if not isinstance( self.linear_method, UnquantizedLinearMethod ) and not self.linear_method.quant_config.support_fused_moe(): routing_weights, selected_experts = fused_topk( router_logits, self.top_k, renormalize=self.config.norm_topk_prob) final_hidden_states = None for expert_idx in self.expert_indicies: expert_layer = self.experts[expert_idx] expert_mask = (selected_experts == expert_idx) expert_weights = (routing_weights * expert_mask).sum( dim=-1, keepdim=True) current_hidden_states = expert_layer(hidden_states).mul_( expert_weights) if final_hidden_states is None: final_hidden_states = current_hidden_states else: final_hidden_states.add_(current_hidden_states) else: final_hidden_states = self.linear_method.apply_moe_weights( self.w1.linear_weights, self.w2.linear_weights, hidden_states, router_logits, self.top_k, renormalize=self.config.norm_topk_prob, ) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) class DeepseekAttention(nn.Module): def __init__( self, hidden_size: int, num_heads: int, num_kv_heads: int, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. assert self.total_num_kv_heads % tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, bias=False, linear_method=linear_method, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, linear_method=linear_method, ) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, ) self.attn = Attention(self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output class DeepseekDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, layer_idx: int, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = DeepseekAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, linear_method=linear_method, ) if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): self.mlp = DeepseekMoE(config=config, linear_method=linear_method) else: self.mlp = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, linear_method=linear_method, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm( hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, ) # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class DeepseekModel(nn.Module): def __init__( self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size, linear_method=linear_method) self.layers = nn.ModuleList([ DeepseekDecoderLayer(config, layer_idx, linear_method=linear_method) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, kv_caches[i], attn_metadata, residual) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class DeepseekForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method self.model = DeepseekModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits def sample( self, logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("mlp.gate_up_proj", "mlp.gate_proj", 0), ("mlp.gate_up_proj", "mlp.up_proj", 1), ("shared_experts.gate_up_proj", "shared_experts.gate_proj", 0), ("shared_experts.gate_up_proj", "shared_experts.up_proj", 1), ] expert_params_mapping = [ # (param_name, weight_name, shard_id, expert_id) ("w1" if weight_name in ["gate_proj", "up_proj"] else "w2", f"experts.{expert_id}.{weight_name}", shard_id, expert_id) for expert_id in range(self.config.n_routed_experts) for weight_name, shard_id in [("gate_proj", 0), ("up_proj", 1), ("down_proj", None)] ] if self.linear_method is None or ( self.linear_method.quant_config.support_fused_moe()) else [] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision, self.config, fall_back_to_pt=False): if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue # Skip experts that are not assigned to this worker. if (("mlp.experts." in name or "mlp.shared_experts." in name) and name not in params_dict): continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: for (param_name, weight_name, shard_id, expert_id) in expert_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader if shard_id is None: weight_loader(param, loaded_weight, expert_id=expert_id) else: weight_loader(param, loaded_weight, shard_id, expert_id=expert_id) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue # Skip experts that are not assigned to this worker. if (("mlp.experts." in name or "mlp.shared_experts." in name) and name not in params_dict): continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight)