123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419 |
- # coding=utf-8
- # Copyright 2023 The PygmalionAI team.
- # Copyright 2023 The vLLM team.
- # Copyright (c) Google Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Inference-only Gemma model compatible with HuggingFace weights."""
- from typing import List, Optional, Tuple
- import torch
- from torch import nn
- from transformers import GemmaConfig
- from aphrodite.modeling.metadata import InputMetadata
- from aphrodite.modeling.layers.activation import GeluAndMul
- from aphrodite.modeling.layers.attention import PagedAttention
- from aphrodite.modeling.layers.layernorm import RMSNorm
- from aphrodite.modeling.layers.linear import (
- LinearMethodBase,
- MergedColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear,
- ColumnParallelLinear,
- )
- from aphrodite.modeling.layers.rotary_embedding import get_rope
- from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
- from aphrodite.modeling.layers.vocab_parallel_embedding import (
- VocabParallelEmbedding,
- ParallelLMHead,
- )
- from aphrodite.modeling.megatron.parallel_state import (
- get_tensor_model_parallel_world_size, )
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.modeling.hf_downloader import (
- default_weight_loader,
- hf_model_weights_iterator,
- )
- from aphrodite.common.sequence import SamplerOutput
# Per-layer KV cache: a (key_cache, value_cache) pair of tensors, unpacked in
# GemmaAttention.forward and passed to PagedAttention.
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GemmaMLP(nn.Module):
    """Gated feed-forward block used inside every Gemma decoder layer.

    By default the gate and up projections are fused into one
    ``gate_up_proj`` layer. When a quantization config reports that weights
    must not be merged, two separate ``gate_proj``/``up_proj`` layers are
    built instead. In both layouts the activation receives ``[gate, up]``
    concatenated along the last dimension.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        keep_separate = (linear_method is not None
                         and not linear_method.quant_config.merge_weight())
        if keep_separate:
            self.merge_weight = False
            self.gate_proj, self.up_proj = (
                ColumnParallelLinear(
                    hidden_size,
                    intermediate_size,
                    bias=False,
                    linear_method=linear_method,
                ) for _ in range(2))
        else:
            self.merge_weight = True
            # A single matmul produces the gate half followed by the up half.
            self.gate_up_proj = MergedColumnParallelLinear(
                hidden_size,
                [intermediate_size, intermediate_size],
                bias=False,
                linear_method=linear_method,
            )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.act_fn = GeluAndMul()

    def forward(self, x):
        """Project ``x`` up, apply ``GeluAndMul``, then project back down."""
        if not self.merge_weight:
            up_out, _up_bias = self.up_proj(x)
            gate_out, _gate_bias = self.gate_proj(x)
            # Reproduce the fused layout: gate half first, then up half
            # (the same order as shard ids 0/1 used when loading weights).
            fused = torch.cat([gate_out, up_out], dim=-1)
        else:
            fused, _fused_bias = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _down_bias = self.down_proj(activated)
        return projected
class GemmaAttention(nn.Module):
    """Rotary self-attention for one Gemma decoder layer.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split when there are at least ``tp_size`` of them; otherwise they are
    replicated, and each rank holds one KV head.

    By default the q/k/v projections are fused into a single ``qkv_proj``.
    When the quantization config reports that weights must not be merged,
    separate ``q_proj``/``k_proj``/``v_proj`` layers are built instead.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int = 8192,
        rope_theta: float = 10000,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        # Per-rank widths of the q and k/v activations (used to split qkv).
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            # ColumnParallelLinear takes the full (unsharded) output width and
            # divides it across ranks itself. This is the same convention
            # o_proj follows below with total_num_heads * head_dim. Passing
            # the per-rank q_size/kv_size here would shrink every shard a
            # second time whenever tp_size > 1.
            # NOTE(review): with replicated KV heads
            # (total_num_kv_heads < tp_size), these separate k/v layers would
            # still split a single head across ranks. Confirm whether this
            # path is only used with tp_size == 1 or needs replication
            # support.
            total_q_size = self.total_num_heads * self.head_dim
            total_kv_size = self.total_num_kv_heads * self.head_dim
            self.q_proj = ColumnParallelLinear(
                hidden_size,
                total_q_size,
                bias=False,
                linear_method=linear_method,
            )
            self.k_proj = ColumnParallelLinear(
                hidden_size,
                total_kv_size,
                bias=False,
                linear_method=linear_method,
            )
            self.v_proj = ColumnParallelLinear(
                hidden_size,
                total_kv_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.qkv_proj = QKVParallelLinear(
                hidden_size,
                self.head_dim,
                self.total_num_heads,
                self.total_num_kv_heads,
                bias=False,
                linear_method=linear_method,
            )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )
        self.attn = PagedAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings and paged attention.

        Args:
            positions: token positions, passed to the rotary embedding.
            hidden_states: layer input.
            kv_cache: this layer's ``(key_cache, value_cache)`` pair.
            input_metadata: attention metadata forwarded to PagedAttention.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        if self.merge_weight:
            qkv, _ = self.qkv_proj(hidden_states)
            # The fused output holds this rank's q, k and v back to back.
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        else:
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output
class GemmaDecoderLayer(nn.Module):
    """One Gemma transformer block: normed self-attention, then a normed MLP.

    The residual stream is threaded through the RMSNorm calls. When given a
    residual, ``RMSNorm(x, residual)`` returns a ``(normed, residual)`` pair,
    presumably with the residual add fused in; confirm against
    aphrodite's RMSNorm.
    """

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        width = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.hidden_size = width
        self.self_attn = GemmaAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            linear_method=linear_method,
        )
        self.mlp = GemmaMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(width, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return ``(mlp_output, residual)``."""
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            # First layer: there is no residual stream yet, so the raw input
            # starts it.
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual
class GemmaModel(nn.Module):
    """Gemma decoder stack: scaled token embedding, decoder layers, final norm."""

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size,
                                                   linear_method=linear_method)
        self.layers = nn.ModuleList([
            GemmaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: token ids to embed.
            positions: token positions, forwarded to each layer.
            kv_caches: one ``(key_cache, value_cache)`` pair per layer,
                indexed by layer number.
            input_metadata: attention metadata forwarded to each layer.

        Returns:
            The final-normed hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        # Normalize the embedding by sqrt(hidden_size). The HF reference
        # materializes this scale in the activation dtype before multiplying
        # (e.g. sqrt(3072) ~= 55.43 becomes 55.5 in bfloat16). Rounding it
        # the same way keeps activations consistent with the reference.
        # The rounding is done on CPU so it never triggers a device sync.
        normalizer = torch.tensor(self.config.hidden_size**0.5,
                                  dtype=hidden_states.dtype,
                                  device="cpu").item()
        hidden_states *= normalizer
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class GemmaForCausalLM(nn.Module):
    """Gemma causal LM: decoder stack, LM head, sampling and weight loading."""

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = GemmaModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def _use_merged_weights(self) -> bool:
        """Return False when the quant config forbids fused projections."""
        return bool(self.linear_method is None
                    or self.linear_method.quant_config.merge_weight())

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return the decoder stack's final hidden states (no LM head)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``hidden_states``.

        With fused weights, the sampler receives the raw LM-head weight.
        Otherwise the LM head is applied first and its output is passed to
        the quant sampler.
        """
        if self._use_merged_weights():
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        else:
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load HF-format Gemma weights into this model.

        Args:
            model_name_or_path, cache_dir, load_format, revision: forwarded
                to ``hf_model_weights_iterator``.

        Raises:
            RuntimeError: if any parameter was not initialized from the
                checkpoint.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if not self._use_merged_weights():
            # Separate q/k/v and gate/up layers load under their own names.
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "rotary_emb.inv_freq" in name:
                continue
            if "embed_tokens" in name:
                # Copy word embedding to lm_head
                head_name = name.replace("model.embed_tokens", "lm_head")
                if head_name in params_dict:
                    loaded_params.add(head_name)
                    lm_head_param = params_dict[head_name]
                    weight_loader = getattr(lm_head_param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(lm_head_param, loaded_weight)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Map into a new variable. Rewriting `name` in place would let
                # later entries match the rewritten string ("v_proj" occurs
                # inside "qkv_proj") and fall through to a KeyError.
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(stacked_name)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading extra layer for lora models.
                if "lm_head" in name and name not in params_dict:
                    continue
                # GemmaRMSNorm is different from Llama's in that it multiplies
                # (1 + weight) to the output, instead of just weight.
                # The add is done out of place so the iterator's tensor is
                # not mutated.
                if "norm.weight" in name:
                    loaded_weight = loaded_weight + 1.0
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")
|