123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339 |
- # coding=utf-8
- # Adapted from
- # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
- # Copyright 2023 The vLLM team.
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Inference-only LLaMA model compatible with HuggingFace weights.
- The input of the model is flattened to a 1D tensor of tokens. The model uses
- InputMetadata to extract the original 2D shape of the input.
- """
- from typing import List, Optional, Tuple
- import torch
- from torch import nn
- from transformers import LlamaConfig
- from aphrodite.modeling.metadata import InputMetadata
- from aphrodite.modeling.layers.activation import SiluAndMul
- from aphrodite.modeling.layers.layernorm import RMSNorm
- from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
- from aphrodite.modeling.layers.sampler import Sampler
- from aphrodite.modeling.hf_downloader import load_tensor_parallel_weights, load_padded_tensor_parallel_vocab, hf_model_weights_iterator
- from aphrodite.modeling.megatron.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
- from aphrodite.modeling.megatron.tensor_parallel import VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear
- from aphrodite.common.sequence import SamplerOutput
# Per-layer paged KV cache: a (key_cache, value_cache) pair of tensors,
# unpacked in exactly that order by LlamaAttention.forward.
KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaMLP(nn.Module):
    """Tensor-parallel gated feed-forward block used by each LLaMA layer.

    A fused ``gate_up_proj`` produces ``2 * intermediate_size`` features
    (sharded across tensor-parallel ranks). ``SiluAndMul`` reduces them to the
    intermediate activation (presumably SiLU(gate) * up — see that layer).
    ``down_proj`` then maps back to ``hidden_size``.

    Args:
        hidden_size: Model hidden dimension.
        intermediate_size: Width of the (unsharded) intermediate activation.
        hidden_act: Activation name; only ``"silu"`` is supported.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        # Validate the config before allocating any projection weights, so an
        # unsupported activation fails fast instead of after building two
        # large linear layers.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Gate and up projections fused into one column-parallel matmul; the
        # output stays sharded (gather_output=False) for the row-parallel
        # down projection that follows.
        self.gate_up_proj = ColumnParallelLinear(hidden_size,
                                                 2 * intermediate_size,
                                                 bias=False,
                                                 gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           input_is_parallel=True,
                                           perform_initialization=False)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, gated activation and down projection."""
        # Parallel linear layers return (output, bias); bias is unused here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
class LlamaAttention(nn.Module):
    """Tensor-parallel self-attention with rotary position embeddings.

    Query heads and key/value heads are divided evenly across the
    tensor-parallel ranks; this module owns only the local shard of each.
    A single fused projection yields q, k and v. These go to
    ``PagedAttentionWithRoPE`` together with this layer's paged KV cache, and
    the result is projected back by the row-parallel ``o_proj``.

    Args:
        hidden_size: Model hidden dimension.
        num_heads: Total number of query heads (across all ranks).
        num_kv_heads: Total number of key/value heads (across all ranks).
        rope_theta: Base passed to the rotary embedding.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Each rank must receive a whole number of query and KV heads.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = self.total_num_kv_heads // world_size

        head_dim = hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Local (per-rank) widths of the q and k/v slices of the fused output.
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta

        fused_qkv_width = (self.total_num_heads +
                           2 * self.total_num_kv_heads) * head_dim
        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            fused_qkv_width,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_dim,
                                           self.scaling,
                                           base=self.rope_theta,
                                           rotary_dim=self.head_dim,
                                           num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention over the token batch and return the o_proj output.

        ``kv_cache`` is unpacked as (key_cache, value_cache). ``cache_event``
        is forwarded unchanged to the paged-attention layer.
        """
        fused, _ = self.qkv_proj(hidden_states)
        query, key, value = fused.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1)
        key_cache, value_cache = kv_cache
        context = self.attn(positions, query, key, value, key_cache,
                            value_cache, input_metadata, cache_event)
        projected, _ = self.o_proj(context)
        return projected
class LlamaDecoderLayer(nn.Module):
    """A single pre-norm LLaMA transformer block.

    Computes ``x + attn(input_layernorm(x))`` and then
    ``y + mlp(post_attention_layernorm(y))``.

    Args:
        config: HuggingFace ``LlamaConfig`` providing sizes, head counts,
            activation, RMSNorm epsilon and (optionally) ``rope_theta``.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        # ``rope_theta`` is only present on configs from newer transformers
        # releases (> 4.32.0 per the original note); default to 10000.
        theta = getattr(config, "rope_theta", 10000)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=theta,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks, each with a residual add."""
        # Attention sub-block: residual + attn(norm(residual)).
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block: residual + mlp(norm(residual)).
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
class LlamaModel(nn.Module):
    """LLaMA backbone: token embedding, decoder layers, final RMSNorm.

    The embedding table is sized to the vocabulary rounded up to a multiple
    of 64. ``self.vocab_size`` keeps the unpadded size.

    Args:
        config: HuggingFace ``LlamaConfig`` for the model.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        padded_vocab_size = 64 * ((config.vocab_size + 63) // 64)
        self.embed_tokens = VocabParallelEmbedding(
            padded_vocab_size, config.hidden_size,
            perform_initialization=False)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, then normalize.

        ``kv_caches[i]`` and, when ``cache_events`` is not None,
        ``cache_events[i]`` are handed to layer ``i``.
        """
        hidden_states = self.embed_tokens(input_ids)
        for idx, decoder_layer in enumerate(self.layers):
            event = None if cache_events is None else cache_events[idx]
            hidden_states = decoder_layer(
                positions,
                hidden_states,
                kv_caches[idx],
                input_metadata,
                event,
            )
        return self.norm(hidden_states)
class LlamaForCausalLM(nn.Module):
    """LLaMA causal LM: ``LlamaModel`` backbone, column-parallel LM head and
    ``Sampler``.

    ``forward`` runs the backbone and passes the LM-head weight together with
    the final hidden states to the sampler, returning its ``SamplerOutput``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LlamaModel(config)
        # LM head rows use the same multiple-of-64 padded vocab as the
        # backbone's embedding table.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        # The sampler is constructed with the real (unpadded) vocab size.
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the backbone and sample next tokens from the LM head."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
    ]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    @staticmethod
    def _copy_weight_slice(param_slice: torch.Tensor,
                           loaded_weight: torch.Tensor, name: str) -> None:
        """Copy ``loaded_weight`` into ``param_slice`` after a shape check.

        An explicit check is used instead of ``assert`` because asserts are
        stripped under ``python -O``. ``Tensor.copy_`` broadcasts its source,
        so a mismatched checkpoint shard could otherwise be loaded silently.

        Raises:
            ValueError: If the two shapes differ.
        """
        if param_slice.shape != loaded_weight.shape:
            raise ValueError(
                f"Shape mismatch while loading {name!r}: expected "
                f"{tuple(param_slice.shape)}, got "
                f"{tuple(loaded_weight.shape)}")
        param_slice.copy_(loaded_weight)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace LLaMA checkpoint weights into this rank's shards.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are written into the
        fused ``qkv_proj`` parameter. ``gate_proj``/``up_proj`` are written
        into the fused ``gate_up_proj`` parameter. Each is sliced to this
        tensor-parallel rank's rows. Embedding and LM-head weights go through
        ``load_padded_tensor_parallel_vocab``; all others through
        ``load_tensor_parallel_weights``.

        Args:
            model_name_or_path: Model id or local path passed to
                ``hf_model_weights_iterator``.
            cache_dir: Optional download cache directory.
            load_format: Checkpoint format selector (default ``"auto"``).
            revision: Optional model revision.

        Raises:
            ValueError: If a fused q/k/v or gate/up shard does not match the
                shape of its destination slice.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        # Per-rank row counts of the q and k/v parts of the fused qkv_proj
        # weight (head_dim = hidden_size // num_attention_heads).
        q_proj_shard_size = (self.config.hidden_size // tp_size)
        kv_proj_shard_size = (self.config.hidden_size //
                              self.config.num_attention_heads *
                              self.config.num_key_value_heads // tp_size)
        attention_weight_specs = [
            # (weight_name, shard_size, offset into fused qkv_proj rows)
            ("q_proj", q_proj_shard_size, 0),
            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
            ("v_proj", kv_proj_shard_size,
             q_proj_shard_size + kv_proj_shard_size),
        ]
        # Rotary-embedding tensors that some checkpoints serialize but that
        # have no entry in this model's state_dict. Looking them up would
        # raise KeyError, so skip them. (cos/sin caches appear in some
        # third-party checkpoints; upstream vLLM skips them likewise.)
        skipped_substrings = ("rotary_emb.inv_freq", "rotary_emb.cos_cached",
                              "rotary_emb.sin_cached")

        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if any(skip in name for skip in skipped_substrings):
                continue

            # q/k/v -> this rank's rows of the fused qkv_proj parameter.
            is_attention_weight = False
            for weight_name, shard_size, offset in attention_weight_specs:
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "qkv_proj")]
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[offset:offset + shard_size]
                self._copy_weight_slice(param_slice, loaded_weight, name)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            # gate/up -> first/second half of the fused gate_up_proj rows.
            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                self._copy_weight_slice(param_slice, loaded_weight, name)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]
            if "embed_tokens" in name or "lm_head" in name:
                # Vocab-sized weights are padded to the multiple-of-64 size.
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tensor_model_parallel_rank)
                continue
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
|