123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415 |
- # coding=utf-8
- # Adapted from
- # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
- # Copyright 2023 The PygmalionAI team.
- # Copyright 2023 The vLLM team.
- # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
- # reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Inference-only OPT model compatible with HuggingFace weights."""
- from typing import List, Optional
- import torch
- from torch import nn
- from transformers import OPTConfig
- from aphrodite.attention import Attention, AttentionMetadata
- from aphrodite.modeling.layers.activation import get_act_fn
- from aphrodite.modeling.layers.linear import (
- ColumnParallelLinear,
- LinearMethodBase,
- QKVParallelLinear,
- ReplicatedLinear,
- RowParallelLinear,
- )
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.sampler import Sampler
- from aphrodite.modeling.layers.vocab_parallel_embedding import (
- VocabParallelEmbedding,
- ParallelLMHead,
- )
- from aphrodite.distributed import (
- get_tensor_model_parallel_world_size, )
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.modeling.hf_downloader import (
- default_weight_loader,
- hf_model_weights_iterator,
- )
- from aphrodite.common.sequence import SamplerOutput
class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned position table whose lookups are shifted by a fixed offset.

    The table is allocated with ``num_embeddings + offset`` rows and every
    position id is increased by ``offset`` before the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT reserves the first two rows of the table: when a padding_idx is
        # used the embedding ids are shifted by 2 and the table is enlarged
        # to match. Other models do not need this hack.
        self.offset = 2
        table_rows = num_embeddings + self.offset
        super().__init__(table_rows, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Return the embeddings for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)
class OPTAttention(nn.Module):
    """Self-attention sub-layer of an OPT decoder block.

    The attention heads are divided evenly by the tensor-model-parallel world
    size. Depending on the quantization config, the query/key/value
    projections are built either as one fused ``qkv_proj`` layer or as three
    separate ``q_proj``/``k_proj``/``v_proj`` layers.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim

        tp_size = get_tensor_model_parallel_world_size()
        # Every rank must own the same whole number of heads.
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5

        # Fuse q/k/v unless a quantization config explicitly opts out.
        if (linear_method is None
                or linear_method.quant_config.merge_weight()):
            self.merge_weight = True
            self.qkv_proj = QKVParallelLinear(
                embed_dim,
                self.head_dim,
                num_heads,
                bias=bias,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = False
            # Registered in q, k, v order as individual sub-modules.
            for proj_name in ("q_proj", "k_proj", "v_proj"):
                setattr(
                    self,
                    proj_name,
                    ColumnParallelLinear(embed_dim,
                                         embed_dim,
                                         bias=bias,
                                         linear_method=linear_method),
                )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            linear_method=linear_method,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, run attention, and apply the output projection."""
        if not self.merge_weight:
            query, _ = self.q_proj(hidden_states)
            key, _ = self.k_proj(hidden_states)
            value, _ = self.v_proj(hidden_states)
        else:
            fused, _ = self.qkv_proj(hidden_states)
            # The fused output is split into three equal parts on the last
            # dimension.
            query, key, value = fused.chunk(chunks=3, dim=-1)

        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.out_proj(context)
        return projected
class OPTDecoderLayer(nn.Module):
    """One OPT decoder block: self-attention followed by a feed-forward MLP.

    ``config.do_layer_norm_before`` selects whether each layer norm is applied
    before its sub-layer or after the residual addition.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        use_bias = config.enable_bias
        affine = config.layer_norm_elementwise_affine

        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            linear_method=linear_method,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim,
                                                 elementwise_affine=affine)

        # Feed-forward: embed_dim -> ffn_dim -> activation -> embed_dim.
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=use_bias,
            linear_method=linear_method,
        )
        # `linear_method` may be None, in which case no quant config is used.
        quant_config = getattr(linear_method, "quant_config", None)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=use_bias,
            linear_method=linear_method,
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim,
                                             elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers with residuals."""
        pre_norm = self.do_layer_norm_before

        # --- Self-attention sub-layer ---
        shortcut = hidden_states
        # 125m, 1.7B, ..., 175B normalise BEFORE attention.
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        attn_out = self.self_attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = shortcut + attn_out
        # 350m normalises AFTER attention (post residual add).
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- Feed-forward sub-layer ---
        shortcut = hidden_states
        # 125m, 1.7B, ..., 175B normalise BEFORE the feed-forward block.
        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        ffn_out, _ = self.fc1(hidden_states)
        ffn_out = self.activation_fn(ffn_out)
        ffn_out, _ = self.fc2(ffn_out)
        hidden_states = shortcut + ffn_out
        # 350m normalises AFTER the feed-forward block (post residual add).
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
class OPTDecoder(nn.Module):
    """Stack of OPT decoder layers with token and position embeddings.

    When ``config.word_embed_proj_dim`` differs from ``config.hidden_size``,
    an input projection and an output projection are added around the layer
    stack; otherwise both are ``None``.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
            linear_method=linear_method,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist.
        needs_projection = config.word_embed_proj_dim != config.hidden_size
        self.project_out = (ReplicatedLinear(
            config.hidden_size,
            config.word_embed_proj_dim,
            bias=False,
            linear_method=linear_method,
        ) if needs_projection else None)
        self.project_in = (ReplicatedLinear(
            config.word_embed_proj_dim,
            config.hidden_size,
            bias=False,
            linear_method=linear_method,
        ) if needs_projection else None)

        # Note that the only purpose of `config._remove_final_layer_norm` is
        # to keep backward compatibility with checkpoints that have been
        # fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine,
            )
        else:
            self.final_layer_norm = None

        decoder_layers = [
            OPTDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the inputs and run them through every decoder layer.

        ``kv_caches`` is indexed by layer position, one entry per layer.
        """
        token_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            token_embeds, _ = self.project_in(token_embeds)

        hidden_states = token_embeds + position_embeds
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(hidden_states,
                                          kv_caches[layer_idx],
                                          attn_metadata)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states
class OPTModel(nn.Module):
    """Thin wrapper that owns the decoder under the ``decoder`` attribute."""

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        decoder = OPTDecoder(config, linear_method)
        self.decoder = decoder

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Delegate to the decoder and return its hidden states unchanged."""
        decoder_output = self.decoder(input_ids, positions, kv_caches,
                                      attn_metadata)
        return decoder_output
class OPTForCausalLM(nn.Module):
    """OPT decoder with a language-model head, logits processor and sampler.

    ``forward`` returns hidden states only; logits and sampled tokens are
    produced separately by ``compute_logits`` and ``sample``.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OPTModel(config, linear_method)
        # The decoder emits `word_embed_proj_dim` features (it applies
        # `project_out` whenever that differs from `hidden_size`), and
        # `load_weights` copies the `embed_tokens` weight, which has
        # `word_embed_proj_dim` columns, into this head. The head must
        # therefore use `word_embed_proj_dim`, not `hidden_size`.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.word_embed_proj_dim,
                                      linear_method=linear_method)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the decoder and return its hidden states (not logits)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load checkpoint weights into this module's parameters.

        Checkpoint names starting with ``decoder.`` are prefixed with
        ``model.``. The token-embedding weight is also copied into
        ``lm_head`` when that parameter exists. Unless the quantization
        config disables weight merging, ``q_proj``/``k_proj``/``v_proj``
        weights are loaded as shards of the fused ``qkv_proj`` parameter.

        Raises:
            KeyError: if a checkpoint tensor (other than a skipped
                ``lm_head`` entry or extra ``.bias``) has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            # q/k/v are separate parameters, so nothing is stacked.
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "lm_head" in name and name not in params_dict:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            if "embed_tokens" in name:
                # Copy word embedding to lm_head
                head_name = name.replace("model.decoder.embed_tokens",
                                         "lm_head")
                if head_name in params_dict:
                    lm_head_param = params_dict[head_name]
                    weight_loader = getattr(lm_head_param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(lm_head_param, loaded_weight)

            # Map a per-projection checkpoint name onto its fused parameter;
            # only the first matching entry applies.
            shard_id = None
            for param_name, weight_name, candidate_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = candidate_id
                    break

            # Skip loading extra bias for GPTQ models. This is an explicit
            # skip of the current checkpoint tensor, for fused and plain
            # parameters alike.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            if shard_id is not None:
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
|