# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTJConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.linear import (
    ColumnParallelLinear,
    LinearMethodBase,
    QKVParallelLinear,
    RowParallelLinear,
)
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
    ParallelLMHead,
)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_world_size,
)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from aphrodite.common.sequence import SamplerOutput

# A per-layer (key_cache, value_cache) pair.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTJAttention(nn.Module):
    """GPT-J self-attention with rotary embeddings and a paged KV cache.

    Q/K/V are projected either by one fused ``QKVParallelLinear``
    (``self.merge_weight = True``) or by three separate
    ``ColumnParallelLinear`` layers when the quantization config's
    ``merge_weight()`` returns False.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        # Per-head dimension.
        self.head_size = self.hidden_size // self.total_num_heads

        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            # Quant config forbids merged weights: keep one projection per
            # checkpoint tensor (q_proj / k_proj / v_proj).
            self.merge_weight = False
            self.q_proj = ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                linear_method=linear_method,
            )
            self.k_proj = ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                linear_method=linear_method,
            )
            self.v_proj = ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            # Fused QKV projection; takes the per-head size, which is stored
            # on this module as ``head_size``.
            self.qkv_proj = QKVParallelLinear(
                config.hidden_size,
                self.head_size,
                self.total_num_heads,
                bias=False,
                linear_method=linear_method,
            )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        # Heads are split evenly across tensor-parallel ranks.
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Positions passed to the rotary embedding.
            hidden_states: Input activations (already layer-normed by the
                caller in ``GPTJBlock``).
            kv_cache: ``(k_cache, v_cache)`` pair for this layer.
            input_metadata: Metadata forwarded to ``PagedAttention``.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        if self.merge_weight:
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.chunk(chunks=3, dim=-1)
        else:
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """Two-layer feed-forward network: ``fc_in`` -> activation -> ``fc_out``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            linear_method=linear_method,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            linear_method=linear_method,
        )
        # linear_method may be None, in which case no quant config is passed.
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc_in``, the activation, then ``fc_out``."""
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    """One GPT-J decoder layer.

    Attention and MLP both read the same ``ln_1`` output and their results
    are summed with the residual (parallel residual; there is no ``ln_2``).
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # Default MLP width is 4x the embedding size unless n_inner is set.
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, linear_method)
        self.mlp = GPTJMLP(inner_dim, config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return ``attn(ln_1(x)) + mlp(ln_1(x)) + x``."""
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


class GPTJModel(nn.Module):
    """Token embedding, ``n_layer`` GPT-J blocks, and a final LayerNorm."""

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          linear_method=linear_method)
        self.h = nn.ModuleList(
            [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run all blocks; returns final hidden states.

        ``kv_caches[i]`` is used by block ``i``; ``kv_caches`` must therefore
        have at least one entry per block.
        """
        hidden_states = self.wte(input_ids)
        for i, layer in enumerate(self.h):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTJForCausalLM(nn.Module):
    """GPT-J with an untied LM head and samplers for next-token generation."""

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        # The LM head is a separate parameter (with bias); tying is not
        # supported here.
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            linear_method=linear_method,
        )
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return the transformer's final hidden states (logits are computed
        later, in ``sample``)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``hidden_states``.

        When the quant config disallows merged weights, the ``lm_head``
        module is called directly and its output goes to ``quant_sampler``;
        otherwise ``sampler`` receives the raw ``lm_head.weight``.
        """
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata,
                                             self.lm_head.bias)
        else:
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load HuggingFace GPT-J checkpoint tensors into this model.

        On the fused-QKV path, tensors whose names contain ``q_proj`` /
        ``k_proj`` / ``v_proj`` are loaded into the matching shard of
        ``qkv_proj``. All other tensors (and all tensors on the unfused path)
        are loaded under their own name.
        """
        # (param_name, shard_name, shard_id). GPT-J's MLP is a plain
        # fc_in/fc_out pair, so only the attention projections are stacked.
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            # Separate q/k/v projections exist; nothing to stack.
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            # These checkpoint entries have no matching parameter here
            # (presumably HF attention mask buffers — the projections above
            # are all bias=False).
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)