123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474 |
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import MixtralConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import (get_pp_group,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.fused_moe import FusedMoE
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from aphrodite.modeling.models.utils import (is_pp_missing_parameter,
                                             make_layers)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsLoRA
class MixtralMoE(nn.Module):
    """Sparse mixture-of-experts feed-forward block for Mixtral.

    Every expert's weights are sharded over all tensor-parallel ranks and
    the forward pass is executed by a single fused MoE kernel; with
    ``reduce_results=True`` the per-rank partial outputs are reduced across
    ranks before being returned.

    The router (``self.gate``) is a replicated linear layer that is always
    built without a quantization config, so it runs at the model's
    half/full precision even when the experts are quantized.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = hidden_size

        # Router: hidden_size -> one logit per expert. quant_config is
        # forced to None on purpose (see class docstring).
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # Fused experts; renormalize=True is forwarded to FusedMoE for the
        # top-k routing weights.
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block to ``hidden_states``.

        NOTE: the input may be 1D or 2D. It is viewed as
        ``(-1, hidden_size)`` for routing and the result is viewed back to
        the exact input shape.
        """
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, self.hidden_size)

        # logits: (num_tokens, num_experts)
        logits, _ = self.gate(tokens)
        mixed = self.experts(tokens, logits)

        return mixed.view(input_shape)
class MixtralAttention(nn.Module):
    """Tensor-parallel grouped-query self-attention used by Mixtral layers.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split as well when there are at least as many KV heads as ranks;
    otherwise they are replicated so that every rank still holds at least
    one. Rotary position embeddings are applied to queries and keys before
    the attention kernel runs.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads are always evenly partitioned across ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        # KV heads: fewer heads than ranks -> replicate them across ranks;
        # otherwise partition them. Either way each rank keeps >= 1 head.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            assert world_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # Per-rank projection widths and the softmax scale 1/sqrt(head_dim).
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # NOTE: rope_theta is truncated to int before being handed to
        # get_rope.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, project out."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)

        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)

        projected, _ = self.o_proj(context)
        return projected
class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer block: self-attention followed by a sparse MoE.

    Both sub-blocks are pre-normalized with RMSNorm. The residual stream is
    threaded explicitly through ``forward`` rather than added in place.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0; older configs lack rope_theta, so
        # fall back to 10000.
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MoE for one layer.

        Args:
            positions: token positions forwarded to rotary embeddings.
            hidden_states: output of the previous layer (or the embeddings).
            kv_cache: this layer's KV cache.
            attn_metadata: attention metadata for the current batch.
            residual: residual stream from the previous layer, or ``None``
                for the first layer, in which case ``hidden_states`` itself
                becomes the residual.

        Returns:
            ``(hidden_states, residual)``. The MoE output is NOT yet added
            to ``residual``; the caller passes both to the next norm call
            (presumably a fused residual-add + RMSNorm).
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Two-argument norm call returns (normalized, updated residual).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected (sparse MoE)
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
class MixtralModel(nn.Module):
    """Mixtral decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism this rank runs only the layers in
    ``[start_layer, end_layer)`` as returned by ``make_layers``.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        # LoRA adapters may add extra vocabulary rows per adapter slot.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: token ids; only read on the first pipeline rank.
            positions: token positions forwarded to every layer.
            kv_caches: KV caches for the local layers only; indexed as
                ``kv_caches[i - start_layer]``.
            attn_metadata: attention metadata for the current batch.
            intermediate_tensors: ``{"hidden_states", "residual"}`` produced
                by the previous pipeline rank; required on non-first ranks.

        Returns:
            On the last pipeline rank, the final-normalized hidden states.
            On every other rank, an ``IntermediateTensors`` carrying
            ``hidden_states`` and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class MixtralForCausalLM(nn.Module, SupportsLoRA):
    """Mixtral causal LM: decoder stack plus LM head, logits and sampling.

    Wraps ``MixtralModel`` with a ``ParallelLMHead``, a ``LogitsProcessor``
    and a ``Sampler``. It also maps checkpoint tensor names onto the fused
    parameters used here (``qkv_proj`` and the fused MoE expert weights).
    """

    fall_back_to_pt_during_load = False

    # Checkpoint projections that are fused into one parameter here.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.model = MixtralModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config,
                                  prefix="model")
        # LoRA may add extra vocabulary entries on top of the base vocab.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack.

        Returns the final hidden states on the last pipeline-parallel rank,
        otherwise the ``IntermediateTensors`` to hand to the next rank.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Zero-filled ``(batch_size, hidden_size)`` placeholders for both
        ``hidden_states`` and ``residual``."""
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's (fused) parameters.

        Each ``(name, tensor)`` pair is routed by name:

        1. ``q_proj`` / ``k_proj`` / ``v_proj`` are loaded into the fused
           ``qkv_proj`` with shard id ``"q"`` / ``"k"`` / ``"v"``.
        2. Per-expert ``w1`` / ``w2`` / ``w3`` tensors are loaded into the
           fused MoE parameters using ``FusedMoE.make_expert_params_mapping``.
        3. Everything else is loaded by name, after FP8 kv-scale name
           remapping, with the parameter's ``weight_loader`` or
           ``default_weight_loader``.

        Skipped: ``rotary_emb.inv_freq`` tensors, extra GPTQ biases that
        have no matching parameter, and layers that live on another
        pipeline-parallel rank.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # NOTE: once a mapping below matches, `name` has been rewritten
            # and this tensor must be either loaded or skipped right there.
            # Skipping uses `break`, not `continue`: continuing would test
            # the rewritten name against later mappings ("v_proj" is a
            # substring of "qkv_proj") and, on exhaustion, fall into the
            # `else` fallbacks with a corrupted name.
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    break
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        break
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale; None means the
                    # tensor has no destination here and is skipped.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
|