# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsLoRA
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig


class MiniCPMMoE(nn.Module):
    """A tensor-parallel MoE implementation that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        """Allocate the router and the stacked expert weights.

        Args:
            num_experts: Total number of experts.
            top_k: Number of experts selected per token (forwarded to
                ``fused_moe``).
            hidden_size: Model hidden dimension.
            intermediate_size: Unsharded expert intermediate dimension.
            params_dtype: Dtype of the parameters; defaults to
                ``torch.get_default_dtype()``.
            tp_size: Tensor-parallel world size; defaults to
                ``get_tensor_model_parallel_world_size()``.

        Raises:
            ValueError: If ``intermediate_size`` is not divisible by the
                tensor-parallel size.
        """
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        # Floor division below would otherwise silently drop the tail of the
        # intermediate dimension when slicing shards in `weight_loader`.
        if intermediate_size % self.tp_size != 0:
            raise ValueError(
                f"intermediate_size ({intermediate_size}) must be divisible "
                f"by the tensor-parallel size ({self.tp_size}).")
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        # Per-rank slice of the intermediate dimension.
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Router: one logit per expert, replicated on every rank.
        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     quant_config=None)

        # `ws` stacks w1 (rows [0, I)) and w3 (rows [I, 2I)) of every expert,
        # where I is the per-rank intermediate size; see `weight_loader`.
        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        # `w2s` holds w2 of every expert, sharded along its last dimension.
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        """Copy this rank's shard of one expert weight into ``param``.

        Args:
            param: Destination parameter (``ws`` or ``w2s``).
            loaded_weight: Full, unsharded checkpoint tensor.
            weight_name: Checkpoint weight name; its suffix (``w1.weight``,
                ``w2.weight`` or ``w3.weight``) selects the destination
                slice. Names with any other suffix are left untouched.
            expert_id: Index of the expert along dimension 0 of ``param``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # The three suffixes are mutually exclusive, hence the elif chain.
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and reduce across ranks.

        Args:
            hidden_states: 2-D tensor; its shape is unpacked as
                ``(num_tokens, hidden_size)``.

        Returns:
            Tensor viewed back to ``(num_tokens, hidden_size)``.
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True)

        # Each rank holds only a slice of every expert, so partial outputs
        # are combined across ranks.
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)


class MiniCPMMLP(nn.Module):
    """Dense feed-forward block: merged gate/up projection, SiLU-and-mul
    activation, then down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the projections.

        Args:
            hidden_size: Model hidden dimension.
            intermediate_size: Feed-forward intermediate dimension.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config passed to the linear
                layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate first so an unsupported config fails before any layer is
        # constructed.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MiniCPMAttention(nn.Module):
    """Self-attention block with rotary position embeddings applied in
    float32."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the QKV/output projections, rotary embedding and attention.

        Args:
            hidden_size: Model hidden dimension.
            num_heads: Total number of query heads (must be divisible by
                the tensor-parallel size).
            num_kv_heads: Total number of key/value heads.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional rotary scaling configuration.
            max_position_embeddings: Maximum position for the rotary cache.
            cache_config: Optional KV-cache config passed to ``Attention``.
            quant_config: Optional quantization config.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths used to split the fused QKV output in `forward`.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        # set rope as fp32 instead of bf16
        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Queries and keys are cast to float32 for the rotary embedding and
        cast back to their original dtype before attention.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        orig_dtype = q.dtype
        q, k = q.float(), k.float()
        q, k = self.rotary_emb(positions, q, k)
        q, k = q.to(orig_dtype), k.to(orig_dtype)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MiniCPMDecoderLayer(nn.Module):
    """One decoder layer: pre-norm attention and pre-norm MLP/MoE, each
    followed by a depth-scaled residual connection."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention, feed-forward and normalization sub-modules.

        A dense ``MiniCPMMLP`` is used when ``config.num_experts`` is absent
        or zero; otherwise a ``MiniCPMMoE`` is used.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = MiniCPMAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.num_experts = getattr(self.config, "num_experts", 0)
        if self.num_experts == 0:
            self.mlp = MiniCPMMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        else:
            self.mlp = MiniCPMMoE(num_experts=config.num_experts,
                                  top_k=config.num_experts_per_tok,
                                  hidden_size=config.hidden_size,
                                  intermediate_size=config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer.

        The incoming ``residual`` argument is not read (it is overwritten
        immediately); the second element of the returned tuple is always
        ``None``.
        """
        # Both sub-block outputs are scaled by scale_depth / sqrt(num_layers)
        # before being added to the residual; compute the factor once.
        residual_scale = (self.config.scale_depth /
                          math.sqrt(self.config.num_hidden_layers))

        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states * residual_scale

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * residual_scale

        return hidden_states, None


class MiniCPMModel(nn.Module):
    """Token embedding, decoder stack and final normalization."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the embedding table, ``config.num_hidden_layers`` decoder
        layers and the final ``RMSNorm``.

        When ``lora_config`` is given, the embedding vocabulary is extended
        by ``lora_extra_vocab_size * (max_loras or 1)`` entries.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            MiniCPMDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and multiply by ``config.scale_emb``."""
        embedding = self.embed_tokens(input_ids)
        return embedding * self.config.scale_emb

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the normalized hidden states.

        ``inputs_embeds``, when provided, is used as-is instead of embedding
        ``input_ids``. ``intermediate_tensors`` is accepted but not read
        here. ``kv_caches`` is indexed by layer position.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
    """MiniCPM decoder with LM head, logits processor and sampler."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the model, the LM head (unless embeddings are tied), the
        logits processor and the sampler."""
        super().__init__()
        self.config = config
        self.lora_config = lora_config

        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPMModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config)
        unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        # With tied embeddings `compute_logits` uses `model.embed_tokens`,
        # so no separate head is created.
        if not self.config.tie_word_embeddings:
            # We need bigger padding if using lora for kernel
            # compatibility
            padding_size = (DEFAULT_VOCAB_PADDING_SIZE if not lora_config else
                            lora_config.lora_vocab_padding_size)
            self.lm_head = ParallelLMHead(
                unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=padding_size,
                quant_config=quant_config,
            )
        # Hidden states are divided by this factor before the LM head.
        self.scale_width = self.config.hidden_size / self.config.dim_model_base

        self.logits_processor = LogitsProcessor(unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the final hidden states produced by the wrapped model."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Divide hidden states by ``scale_width`` and project to logits.

        Uses ``model.embed_tokens`` as the head when
        ``config.tie_word_embeddings`` is set, otherwise ``lm_head``.
        """
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head = self.model.embed_tokens
        else:
            lm_head = self.lm_head
        logits = self.logits_processor(lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using the module's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Checkpoint names containing a stacked shard name (``q_proj``,
        ``gate_proj``, ...) are redirected to the merged parameter; names
        containing ``experts.<id>.<w1|w2|w3>.weight`` are redirected to the
        stacked MoE parameters; everything else is loaded by exact name.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Raises:
            KeyError: If a (possibly redirected) name is not a parameter of
                this module and is not one of the skipped cases below.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                mapped_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. `break` bypasses
                # both `else` branches, so the weight is dropped outright
                # rather than being re-matched under its rewritten name.
                if (mapped_name.endswith(".bias")
                        and mapped_name not in params_dict):
                    break
                param = params_dict[mapped_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked shard name matched: try the MoE expert weights.
                for param_name, weight_name, expert_id in (
                        expert_params_mapping):
                    if weight_name not in name:
                        continue
                    mapped_name = name.replace(weight_name, param_name)
                    param = params_dict[mapped_name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Plain parameter: load by exact name.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)