# coding=utf-8 # Adapted from # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py # Copyright 2023 The vLLM team. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" from typing import Iterable, List, Optional, Tuple import torch from torch import nn from transformers import GPT2Config from aphrodite.attention import Attention, AttentionMetadata from aphrodite.common.config import CacheConfig from aphrodite.common.sequence import IntermediateTensors from aphrodite.distributed import get_tensor_model_parallel_world_size from aphrodite.modeling.layers.activation import get_act_fn from aphrodite.modeling.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from aphrodite.modeling.layers.logits_processor import LogitsProcessor from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput from aphrodite.modeling.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from aphrodite.modeling.model_loader.weight_utils import default_weight_loader from aphrodite.modeling.sampling_metadata import SamplingMetadata from aphrodite.quantization.base_config import QuantizationConfig class GPT2Attention(nn.Module): def __init__( self, config: GPT2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads tensor_model_parallel_world_size = ( get_tensor_model_parallel_world_size()) assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads self.scale = self.head_dim**-0.5 self.c_attn = QKVParallelLinear( self.hidden_size, self.head_dim, total_num_heads, bias=True, quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, quant_config=quant_config, ) self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale, cache_config=cache_config, quant_config=quant_config) def forward( self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output, _ = self.c_proj(attn_output) return attn_output class GPT2MLP(nn.Module): def __init__( self, intermediate_size: int, config: GPT2Config, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.c_fc = ColumnParallelLinear( hidden_size, intermediate_size, bias=True, quant_config=quant_config, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, quant_config=quant_config, ) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.c_fc(hidden_states) hidden_states = self.act(hidden_states) hidden_states, _ = self.c_proj(hidden_states) return hidden_states class GPT2Block(nn.Module): def __init__( self, config: GPT2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size inner_dim = (config.n_inner if config.n_inner is not None else 4 * hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = GPT2Attention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(inner_dim, config, quant_config) def forward( self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_output = self.attn( hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, ) # residual connection hidden_states = attn_output + residual residual = hidden_states hidden_states = self.ln_2(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states) # residual connection hidden_states = residual + feed_forward_hidden_states return hidden_states class GPT2Model(nn.Module): def __init__( self, config: GPT2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config assert not config.add_cross_attention assert not config.scale_attn_by_inverse_layer_idx assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ GPT2Block(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds for i in range(len(self.h)): layer = self.h[i] hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = self.ln_f(hidden_states) return hidden_states class GPT2LMHeadModel(nn.Module): def __init__( self, config: GPT2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config self.transformer = GPT2Model(config, cache_config, quant_config) self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. continue if ".attn.bias" in name or ".attn.masked_bias" in name: # Skip attention mask. # NOTE: "c_attn.bias" should not be skipped. continue if not name.startswith("transformer."): name = "transformer." + name param = params_dict[name] # The HF's GPT-2 implementation uses Conv1D instead of Linear. # Because of this, we need to transpose the weights. # Note(zhuohan): the logic below might break quantized models. for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: if conv1d_weight_name not in name: continue if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight)