123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419 |
- # coding=utf-8
- # Copyright 2024 The PygmalionAI team.
- # Copyright 2024 The vLLM team.
- # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
- #
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Iterable, List, Optional, Set, Tuple
- import torch
- from loguru import logger
- from torch import nn
- from transformers import Gemma2Config
- from aphrodite.attention import Attention, AttentionMetadata
- from aphrodite.common.config import CacheConfig, LoRAConfig
- from aphrodite.common.sequence import IntermediateTensors
- from aphrodite.distributed import get_tensor_model_parallel_world_size
- from aphrodite.modeling.layers.activation import GeluAndMul
- from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
- from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear)
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.rotary_embedding import GemmaRotaryEmbedding
- from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
- from aphrodite.modeling.layers.vocab_parallel_embedding import (
- VocabParallelEmbedding)
- from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.quantization.base_config import QuantizationConfig
- from .interfaces import SupportsLoRA
class Gemma2MLP(nn.Module):
    """Gated feed-forward block of a Gemma 2 decoder layer.

    The input goes through a merged gate/up projection, then
    ``GeluAndMul(approximate="tanh")``, then the down projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Validate the activation settings and build the projections.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two merged gate/up
                outputs and input size of ``down_proj``.
            hidden_act: Activation name taken from the model config.
            hidden_activation: Second activation name taken from the
                model config; it must agree with ``hidden_act``.
            quant_config: Optional quantization settings handed to both
                linear layers.

        Raises:
            ValueError: If either activation name is not
                ``"gelu_pytorch_tanh"``.
        """
        super().__init__()
        # Check the configuration first so an unsupported activation is
        # rejected before any projection layer is constructed.
        if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
            raise ValueError(
                "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection."""
        # Both linear layers return a pair; only the first item is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
class Gemma2Attention(nn.Module):
    """Self-attention block of a Gemma 2 decoder layer.

    Holds the fused QKV projection, the rotary embedding, the attention
    backend wrapper and the output projection. Head counts are divided by
    the tensor-parallel world size to obtain the per-rank sizes.
    """

    def __init__(self,
                 layer_idx: int,
                 config: Gemma2Config,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int,
                 rope_theta: float,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 attn_logits_soft_cap: Optional[float] = None) -> None:
        """Build the projections, rotary embedding and attention op.

        Args:
            layer_idx: Index of the decoder layer owning this block.
            config: Model config; ``query_pre_attn_scalar``,
                ``attention_bias`` and ``sliding_window`` are read here.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before sharding).
            num_kv_heads: Total number of key/value heads (before
                sharding).
            head_dim: Size of one attention head.
            max_position_embeddings: Passed to the rotary embedding.
            rope_theta: Rotary embedding base.
            cache_config: Optional cache settings for ``Attention``.
            quant_config: Optional quantization settings.
            attn_logits_soft_cap: Optional soft cap forwarded to
                ``Attention`` as ``logits_soft_cap``.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.hidden_size = hidden_size

        world_size = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly across the ranks.
        self.total_num_heads = num_heads
        assert num_heads % world_size == 0
        self.num_heads = num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < world_size:
            # Fewer KV heads than ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert world_size % num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, num_kv_heads // world_size)

        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.rope_theta = rope_theta

        use_bias = config.attention_bias
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            num_heads,
            num_kv_heads,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            num_heads * head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
        )

        # TODO: Use the `get_rope` interface.
        self.rotary_emb = GemmaRotaryEmbedding(
            head_dim,
            head_dim,
            max_position_embeddings,
            base=rope_theta,
            is_neox_style=True,
            dtype=torch.get_default_dtype(),
        )

        # FIXME: While Gemma 2 uses sliding window attention for every
        # odd layer, Aphrodite currently ignores it and uses global
        # attention for all layers.
        is_odd_layer = layer_idx % 2 == 1
        use_sliding_window = (is_odd_layer
                              and config.sliding_window is not None)
        del use_sliding_window  # Unused.

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              logits_soft_cap=attn_logits_soft_cap)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project out."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dimension.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected
class Gemma2DecoderLayer(nn.Module):
    """One Gemma 2 decoder layer: attention and MLP, each wrapped in norms.

    Four ``GemmaRMSNorm`` modules are used: one before and one after the
    attention block, and one before and one after the MLP.
    """

    def __init__(
        self,
        layer_idx: int,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention block, the MLP and the four norm layers.

        Args:
            layer_idx: Position of this layer in the decoder stack.
            config: Model config supplying sizes, rope settings, the
                activation names, ``rms_norm_eps`` and
                ``attn_logit_softcapping``.
            cache_config: Optional cache settings for the attention op.
            quant_config: Optional quantization settings for the linear
                layers.
        """
        super().__init__()
        # Set once; both sub-blocks below read it.
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma2Attention(
            layer_idx=layer_idx,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_logits_soft_cap=config.attn_logit_softcapping,
        )
        self.mlp = Gemma2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
        )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
                                                      eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = GemmaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Positions forwarded to the attention block.
            hidden_states: Layer input.
            kv_cache: Cache tensor forwarded to the attention block.
            attn_metadata: Metadata forwarded to the attention block.
            residual: Residual carried over from the previous layer, or
                ``None`` for the first layer.

        Returns:
            The MLP branch output after the post-feedforward norm, and
            the residual produced by the pre-feedforward norm.
        """
        if residual is None:
            # First layer: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns an updated
            # (hidden_states, residual) pair.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states, residual = self.pre_feedforward_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        return hidden_states, residual
class Gemma2Model(nn.Module):
    """Gemma 2 decoder stack: embedding, decoder layers and final norm."""

    def __init__(
        self,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the embedding, ``num_hidden_layers`` layers and the norm.

        Args:
            config: Model config supplying vocabulary size, hidden size,
                layer count and ``rms_norm_eps``.
            cache_config: Optional cache settings passed to every layer.
            quant_config: Optional quantization settings passed to every
                layer.
        """
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the (unscaled) embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the final hidden states.

        Args:
            input_ids: Token ids; used only when ``inputs_embeds`` is
                ``None``.
            positions: Positions forwarded to every layer.
            kv_caches: One cache entry per decoder layer, indexed by
                layer position.
            attn_metadata: Metadata forwarded to every layer.
            intermediate_tensors: Accepted for interface compatibility;
                not used by this method.
            inputs_embeds: Optional precomputed embeddings. They are
                scaled by the normalizer like looked-up embeddings, but
                the tensor passed in is left unmodified.

        Returns:
            Hidden states after the final norm.
        """
        if inputs_embeds is not None:
            # The tensor belongs to the caller: scale out of place so it
            # is not silently modified.
            hidden_states = inputs_embeds * self.normalizer
        else:
            hidden_states = self.get_input_embeddings(input_ids)
            # The lookup result is local to this call, so scaling it in
            # place avoids an extra allocation.
            hidden_states *= self.normalizer

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
    """Gemma 2 causal language model: decoder stack, logits and sampler.

    Logits are computed against ``model.embed_tokens``; no separate
    ``lm_head`` parameter is created, and ``lm_head.weight`` entries in a
    checkpoint are skipped while loading.
    """

    # Fused parameter name -> names of the checkpoint shards merged in it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    # Checkpoint shard name -> (fused parameter name, shard index).
    bitsandbytes_stacked_params_mapping = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the decoder stack, the logits processor and the sampler.

        ``lora_config`` is accepted but not used by this class.
        """
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma2Model(config, cache_config, quant_config)
        self.logits_processor = LogitsProcessor(
            config.vocab_size, soft_cap=config.final_logit_softcapping)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the decoder hidden states for the given tokens.

        ``intermediate_tensors`` is accepted but not passed on.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits using the token embedding."""
        return self.logits_processor(self.model.embed_tokens, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Shards named ``q_proj``/``k_proj``/``v_proj`` and
        ``gate_proj``/``up_proj`` are routed into the fused ``qkv_proj``
        and ``gate_up_proj`` parameters. A warning is logged for any
        parameter that received no weight.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, tensor in weights:
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    name = name.replace(shard_name, fused_name)
                    # Skip loading extra bias for GPTQ models.
                    extra_bias = (name.endswith(".bias")
                                  and name not in params_dict)
                    if not extra_bias:
                        target = params_dict[name]
                        target.weight_loader(target, tensor, shard_id)
                        break
            else:
                # No fused shard was loaded for this entry.
                # lm_head is not used in Aphrodite as it is tied with
                # embed_token; skipping lm_head.weight prevents errors.
                # Extra biases (GPTQ models) are skipped as well.
                if "lm_head.weight" in name or (name.endswith(".bias") and
                                                name not in params_dict):
                    continue
                target = params_dict[name]
                load = getattr(target, "weight_loader", default_weight_loader)
                load(target, tensor)
            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")
|