123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371 |
- # coding=utf-8
- # Adapted from
- # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
- # Copyright 2024 The vLLM team.
- # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Inference-only OLMo model compatible with HuggingFace weights."""
- from typing import Iterable, List, Optional, Tuple
- import torch
- from torch import nn
- from transformers import OlmoConfig
- from aphrodite.attention import Attention, AttentionMetadata
- from aphrodite.common.config import CacheConfig
- from aphrodite.common.sequence import IntermediateTensors
- from aphrodite.distributed import get_tensor_model_parallel_world_size
- from aphrodite.modeling.layers.activation import SiluAndMul
- from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear)
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.rotary_embedding import get_rope
- from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
- from aphrodite.modeling.layers.vocab_parallel_embedding import (
- ParallelLMHead, VocabParallelEmbedding)
- from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.quantization.base_config import QuantizationConfig
class OlmoAttention(nn.Module):
    """
    This is the attention block where the output is computed as
    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).

    Q/K/V come from a single fused tensor-parallel projection, get rotary
    position embeddings, run through the paged ``Attention`` layer, and are
    projected back to ``hidden_size`` by a row-parallel output projection.
    """

    def __init__(
        self,
        config: OlmoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = config.num_attention_heads

        # Heads must divide the hidden size evenly, and be split evenly
        # across tensor-parallel ranks.
        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tensor_model_parallel_world_size == 0

        # Number of heads handled by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # RoPE scaling declared by the HF config. It was previously ignored,
        # so scaled checkpoints silently ran with unscaled rotary embeddings.
        # getattr keeps configs without the field working (None == no
        # scaling, identical to the old behaviour).
        self.rope_scaling = getattr(config, "rope_scaling", None)
        # Optional symmetric clamp applied to the fused QKV activations.
        self.clip_qkv = config.clip_qkv

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
        )

        # Rotary embeddings over the full head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute the attention output for ``hidden_states``.

        :param positions: token positions used by the rotary embedding.
        :param hidden_states: already-normalized input activations.
        :param kv_cache: this layer's KV cache, passed to ``self.attn``.
        :param attn_metadata: batch metadata for the attention backend.
        :returns: the output of ``o_proj`` (bias term discarded).
        """
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp to [-clip_qkv, clip_qkv].
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        # Q, K and V have equal width here (one KV head per query head),
        # so an even three-way split along the last dim separates them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
class OlmoMLP(nn.Module):
    """Gated feed-forward sub-block of an OLMo decoder layer.

    Produces the ``MLP(LN(x))`` term of ``MLP(LN(x + Attention(LN(x))))``;
    the caller adds the surrounding skip connection.
    """

    def __init__(
        self,
        config: OlmoConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner

        # One column-parallel matmul producing two intermediate-size
        # outputs (gate and up) fused together.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden, [inner, inner], bias=False, quant_config=quant_config)
        # Gated activation consuming the fused gate/up output.
        self.act_fn = SiluAndMul()
        # Row-parallel projection back down to the hidden size.
        self.down_proj = RowParallelLinear(
            inner, hidden, bias=False, quant_config=quant_config)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply gate/up projection, gated activation, then down projection."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.down_proj(activated)
        return out
class OlmoDecoderLayer(nn.Module):
    """
    This is a typical transformer block where the output is
    computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(self,
                 config: OlmoConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        # Attention block.
        self.self_attn = OlmoAttention(config, cache_config, quant_config)

        # MLP block.
        self.mlp = OlmoMLP(config, quant_config)

        # LayerNorm without learnable scale or bias (elementwise_affine=False,
        # bias=False), so these modules own no parameters.
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            elementwise_affine=False,
                                            bias=False)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     elementwise_affine=False,
                                                     bias=False)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run one pre-norm attention + MLP block with residual connections.

        The return annotation previously claimed a tuple; the method has
        always returned a single tensor, which is what callers consume.

        :param positions: token positions forwarded to the attention block.
        :param hidden_states: input activations to this layer.
        :param kv_cache: this layer's KV cache.
        :param attn_metadata: batch metadata for the attention backend.
        :returns: the layer's output hidden states.
        """
        # Attention block: x + Attention(LN(x)).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions, hidden_states, kv_cache,
                                       attn_metadata)
        hidden_states = hidden_states + residual

        # MLP block: h + MLP(LN(h)).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
class OlmoModel(nn.Module):
    """OLMo backbone: token embedding, a stack of decoder layers, and a
    final LayerNorm without learnable parameters."""

    def __init__(self,
                 config: OlmoConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # One decoder layer per configured hidden layer, in order.
        self.layers = nn.ModuleList(
            OlmoDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers))
        self.norm = nn.LayerNorm(config.hidden_size,
                                 elementwise_affine=False,
                                 bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, apply the final norm.

        :param input_ids: token ids to embed. The previous docstring
            described these as ``(batch_size, seq_len)``; NOTE(review):
            confirm the layout the engine actually passes.
        :param positions: token positions forwarded to every layer.
        :param kv_caches: per-layer KV caches; entry ``i`` goes to layer ``i``.
        :param attn_metadata: batch metadata for the attention backend.
        :returns: normalized hidden states from the last layer.
        """
        hidden = self.embed_tokens(input_ids)
        for idx, layer in enumerate(self.layers):
            hidden = layer(positions, hidden, kv_caches[idx], attn_metadata)
        return self.norm(hidden)
class OlmoForCausalLM(nn.Module):
    """
    Extremely barebones HF model wrapper.

    Combines :class:`OlmoModel` with an LM head (tied to the input embedding
    when ``config.tie_word_embeddings`` is set), a logits processor and a
    sampler, and maps HuggingFace checkpoint weight names onto this module's
    fused parameters in :meth:`load_weights`.
    """

    def __init__(self,
                 config: OlmoConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.model = OlmoModel(config, cache_config, quant_config)
        # Defined unconditionally: it used to exist only when embeddings were
        # untied, so code reading it broke on tied-embedding models.
        self.unpadded_vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Run the backbone and return its final hidden states.

        ``intermediate_tensors`` is accepted for interface compatibility and
        is not used by this model.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs into this module.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with the matching
        shard id. All other tensors load by name. An unknown, non-skipped
        name raises ``KeyError``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model
            # is processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            # Resolve the target parameter name and shard exactly once. The
            # old code `continue`d the *inner* loop on a GPTQ extra bias, so
            # the renamed name could be rewritten again by later mappings
            # (e.g. "gate_up_proj" -> "gate_gate_up_proj") and the skip only
            # worked because the for-else re-checked it.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    break
            else:
                shard_id = None  # not a stacked (fused) parameter

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Shard ids include 0, so compare against None explicitly.
                param.weight_loader(param, loaded_weight, shard_id)
|