# coding=utf-8 # Adapted from # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py # Copyright 2023 The PygmalionAI team. # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model."""
from typing import List, Optional

import numpy as np
import torch
from torch import nn
from transformers import MixtralConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import LoRAConfig
from aphrodite.modeling.layers.fused_moe import fused_topk
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (
    ColumnParallelLinear, LinearMethodBase, MergedColumnParallelLinear,
    QKVParallelLinear, ReplicatedLinear, RowParallelLinear,
    UnquantizedLinearMethod)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.distributed import (tensor_model_parallel_all_reduce)
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput


class MixtralMLP(nn.Module):
    """A single Mixtral expert computing ``w2(silu(w1(x)) * w3(x))``.

    All three projections are ``ReplicatedLinear``, so a rank that owns this
    expert holds its full, unsharded weights. In this file it is only
    instantiated by the non-fused (expert-parallel) path of ``MixtralMoE``.
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        self.w1 = ReplicatedLinear(self.hidden_dim,
                                   self.ffn_dim,
                                   bias=False,
                                   linear_method=linear_method)
        self.w2 = ReplicatedLinear(self.ffn_dim,
                                   self.hidden_dim,
                                   bias=False,
                                   linear_method=linear_method)
        self.w3 = ReplicatedLinear(self.hidden_dim,
                                   self.ffn_dim,
                                   bias=False,
                                   linear_method=linear_method)

        # TODO: Use aphrodite's SiluAndMul
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated SiLU feed-forward to ``hidden_states``."""
        w1_out, _ = self.w1(hidden_states)
        w1_out = self.act_fn(w1_out)
        w3_out, _ = self.w3(hidden_states)
        current_hidden_states = w1_out * w3_out
        current_hidden_states, _ = self.w2(current_hidden_states)
        return current_hidden_states


class MixtralMoE(nn.Module):
    """Tensor-parallel mixture-of-experts block for Mixtral.

    The execution strategy is chosen from ``linear_method``:

    * Fused: used when unquantized, or when the quant config reports
      ``support_fused_moe()``. All experts are stacked into ``ws`` (w1 and
      w3 merged) and ``w2s``, and each expert is sharded across all ranks.
      The forward pass delegates to ``linear_method.apply_moe_weights``.
    * Expert-parallel: used for quant configs without fused-MoE support.
      Whole experts are split between ranks with ``np.array_split``, and
      each rank runs only the ``MixtralMLP`` experts it owns.

    In both modes, per-rank partial outputs are summed with an all-reduce
    when ``tp_size > 1``.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        tp_size: Optional[int] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size
        self.linear_method = linear_method
        if self.linear_method is None:
            self.linear_method = UnquantizedLinearMethod()

        # The router is always built with linear_method=None, i.e. it is
        # never quantized.
        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     linear_method=None)

        if self._use_expert_parallel():
            if self.tp_size > self.num_total_experts:
                raise ValueError(
                    f"Tensor parallel size {self.tp_size} is greater than "
                    f"the number of experts {self.num_total_experts}.")
            # Split experts equally between ranks
            self.expert_indicies = np.array_split(
                range(self.num_total_experts),
                self.tp_size)[self.rank].tolist()
            if not self.expert_indicies:
                raise ValueError(
                    f"Rank {self.rank} has no experts assigned to it.")

            # Experts owned by other ranks stay as None placeholders so that
            # ModuleList indices line up with global expert ids.
            self.experts = nn.ModuleList([
                MixtralMLP(self.num_total_experts,
                           hidden_size,
                           intermediate_size,
                           linear_method=linear_method)
                if idx in self.expert_indicies else None
                for idx in range(self.num_total_experts)
            ])
        else:
            # Full (unsharded) sizes are passed; the parallel linear layers
            # take the total sizes and shard internally.
            self.ws = MergedColumnParallelLinear(hidden_size,
                                                 [intermediate_size] * 2,
                                                 bias=False,
                                                 linear_method=linear_method,
                                                 num_experts=num_experts)
            self.w2s = RowParallelLinear(intermediate_size,
                                         hidden_size,
                                         bias=False,
                                         linear_method=linear_method,
                                         num_experts=num_experts)

    def _use_expert_parallel(self) -> bool:
        """Return True when experts must run unfused, one MLP per expert.

        This is the case for a quantized linear method whose quant config
        does not support the fused MoE kernel. Shared by ``__init__`` and
        ``forward`` so both always agree on which parameters exist.
        """
        return (not isinstance(self.linear_method, UnquantizedLinearMethod)
                and not self.linear_method.quant_config.support_fused_moe())

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and combine their outputs.

        Leading dimensions are flattened into a single token dimension for
        routing. The input shape is restored on return.
        """
        # Remember the caller's shape so inputs with any number of leading
        # dims round-trip. Tuple-unpacking the shape would reject
        # non-2-D inputs before the flattening view below could help.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        if self._use_expert_parallel():
            routing_weights, selected_experts = fused_topk(router_logits,
                                                           self.top_k,
                                                           renormalize=True)

            final_hidden_states = None
            for expert_idx in self.expert_indicies:
                expert_layer = self.experts[expert_idx]
                expert_mask = (selected_experts == expert_idx)
                # Per-token weight for this expert. It is 0 for tokens that
                # did not select the expert.
                expert_weights = (routing_weights * expert_mask).sum(
                    dim=-1, keepdim=True)

                current_hidden_states = expert_layer(hidden_states).mul_(
                    expert_weights)
                if final_hidden_states is None:
                    final_hidden_states = current_hidden_states
                else:
                    final_hidden_states.add_(current_hidden_states)
        else:
            final_hidden_states = self.linear_method.apply_moe_weights(
                self.ws.linear_weights,
                self.w2s.linear_weights,
                hidden_states,
                router_logits,
                self.top_k,
                renormalize=True,
            )

        if self.tp_size > 1:
            # Each rank holds a partial sum: either its own experts, or its
            # own shard of every expert.
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(orig_shape)


class MixtralAttention(nn.Module):
    """Rotary-embedded grouped-query self-attention for Mixtral.

    Query heads are partitioned across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as ranks, and otherwise
    replicated. Replication is only possible through the merged
    ``QKVParallelLinear`` path.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            # Separate ColumnParallelLinear projections only shard their
            # output; they cannot replicate KV heads. With more ranks than
            # KV heads, each rank would get a fraction of a head while
            # kv_size assumes a whole one, so fail early and clearly.
            if tp_size > self.total_num_kv_heads:
                raise ValueError(
                    f"Tensor parallel size {tp_size} is greater than the "
                    f"number of KV heads {self.total_num_kv_heads}, which "
                    "is not supported when q/k/v weights are not merged.")
            self.merge_weight = False
            self.q_proj = ColumnParallelLinear(hidden_size,
                                               self.total_num_heads *
                                               self.head_dim,
                                               bias=False,
                                               linear_method=linear_method)
            self.k_proj = ColumnParallelLinear(hidden_size,
                                               self.total_num_kv_heads *
                                               self.head_dim,
                                               bias=False,
                                               linear_method=linear_method)
            self.v_proj = ColumnParallelLinear(hidden_size,
                                               self.total_num_kv_heads *
                                               self.head_dim,
                                               bias=False,
                                               linear_method=linear_method)
        else:
            self.merge_weight = True
            self.qkv_proj = QKVParallelLinear(
                hidden_size,
                self.head_dim,
                self.total_num_heads,
                self.total_num_kv_heads,
                bias=False,
                linear_method=linear_method,
            )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, project out."""
        if self.merge_weight:
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        else:
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer layer: RMSNorm, attention, RMSNorm, sparse MoE."""

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method)
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            linear_method=linear_method)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is None only for the first layer. Otherwise the
        two-argument RMSNorm call presumably folds the residual add into
        the norm and returns the updated residual (NOTE(review): confirm
        against RMSNorm). The MoE output is returned without its residual
        added; the caller's next norm consumes both.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):
    """Mixtral decoder stack: token embedding, decoder layers, final RMSNorm."""

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        # Extra embedding rows reserved for LoRA-added tokens, one slice
        # per LoRA slot (at least one).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            linear_method=linear_method,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run every layer with its own KV cache.

        ``kv_caches[i]`` is used by layer ``i``.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], attn_metadata,
                                            residual)
        # The final norm also consumes the last layer's pending residual.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    """Mixtral language model with LM head, logits processing and sampling."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config,
                                  linear_method,
                                  lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            linear_method=linear_method,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return final hidden states; logits are computed separately."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint tensors into this model's parameters.

        Each checkpoint tensor is tried in order against:

        1. the q/k/v -> ``qkv_proj`` stacking (only when weights are merged);
        2. the per-expert ``experts.{i}.w1/w3/w2`` -> ``ws``/``w2s`` stacking
           (only for the fused MoE layout);
        3. a direct name match, using the parameter's ``weight_loader`` or
           ``default_weight_loader``.

        Extra biases absent from the model (GPTQ), and experts not owned by
        this rank (expert-parallel layout), are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []

        # Only the fused MoE layout has stacked ``ws``/``w2s`` parameters;
        # the expert-parallel layout keeps per-expert names as-is.
        expert_params_mapping = [
            # (param_name, weight_name, shard_id, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}", shard_id, expert_id)
            for expert_id in range(self.config.num_local_experts)
            for weight_name, shard_id in [("w1", 0), ("w3", 1), ("w2", None)]
        ] if self.linear_method is None or (
            self.linear_method.quant_config.support_fused_moe()) else []

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
                cache_dir,
                load_format,
                revision,
                self.config,
                fall_back_to_pt=False):
            if "rotary_emb.inv_freq" in name:
                continue

            # for/else: each ``else`` runs only when the preceding loop
            # finished without ``break``, i.e. no mapping handled the tensor.
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for (param_name, weight_name, shard_id,
                     expert_id) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    if shard_id is None:
                        # w2 maps onto w2s as a whole, with no shard id.
                        weight_loader(param,
                                      loaded_weight,
                                      expert_id=expert_id)
                    else:
                        weight_loader(param,
                                      loaded_weight,
                                      shard_id,
                                      expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Skip experts that are not assigned to this worker.
                    if ("block_sparse_moe.experts." in name
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)