123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398 |
- # coding=utf-8
- # Adapted from
- # https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
- # https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
- # Copyright 2023 The PygmalionAI team.
- # Copyright 2023 The vLLM team.
- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT license.
- #
- # BSD 3-Clause License
- #
- # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
- # All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
- #
- # * Redistributions of source code must retain the above copyright notice, this
- # list of conditions and the following disclaimer.
- #
- # * Redistributions in binary form must reproduce the above copyright notice,
- # this list of conditions and the following disclaimer in the documentation
- # and/or other materials provided with the distribution.
- #
- # * Neither the name of the copyright holder nor the names of its
- # contributors may be used to endorse or promote products derived from
- # this software without specific prior written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- """Inference-only OLMo model compatible with HuggingFace weights."""
- from typing import List, Optional, Tuple
- import torch
- import torch.nn.functional as F
- from torch import nn
- from aphrodite.attention import Attention, AttentionMetadata
- from aphrodite.modeling.layers.linear import (
- ColumnParallelLinear,
- LinearMethodBase,
- QKVParallelLinear,
- RowParallelLinear,
- )
- from aphrodite.modeling.layers.rotary_embedding import get_rope
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.sampler import Sampler
- from aphrodite.modeling.layers.vocab_parallel_embedding import (
- VocabParallelEmbedding,
- ParallelLMHead,
- )
- from aphrodite.distributed import (
- get_tensor_model_parallel_world_size, )
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.modeling.hf_downloader import (
- default_weight_loader,
- hf_model_weights_iterator,
- )
- from aphrodite.common.sequence import SamplerOutput
- from aphrodite.transformers_utils.configs.olmo import OLMoConfig
class SwiGLU(nn.Module):
    """SwiGLU activation computing ``silu(gate) * value``.

    The input's last dimension is split with ``chunk(2, dim=-1)``: the first
    piece is the value, the second the gate. For an even-sized last
    dimension the output is therefore half as wide as the input, which is
    what ``output_multiplier`` reports.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = x.chunk(2, dim=-1)
        activated_gate = F.silu(gate)
        return activated_gate * value

    @property
    def output_multiplier(self) -> float:
        """Ratio of output width to input width along the last dim."""
        return 0.5
class OlmoAttention(nn.Module):
    """Self-attention sub-layer of an OLMo block.

    Computes ``attn_out(Attention(LN(x)))`` where ``LN`` is a LayerNorm with
    no learnable scale or bias. The normalized input goes through a fused
    QKV projection, rotary embeddings are applied to queries and keys only
    when ``config.rope`` is set, and the attention result is projected back
    to ``d_model``. The residual connection is NOT added here; the caller
    adds it.
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.d_model
        assert config.d_model % config.n_heads == 0
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = self.config.n_heads
        # Each tensor-parallel rank must own a whole number of heads.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads

        # Pre-attention LayerNorm (no affine parameters, no bias).
        self.attn_norm = nn.LayerNorm(config.d_model,
                                      elementwise_affine=False,
                                      bias=False)

        # Fused projection producing queries, keys and values from x.
        self.att_proj = QKVParallelLinear(
            config.d_model,
            self.head_dim,
            self.total_num_heads,
            bias=config.include_bias,
            linear_method=linear_method,
        )

        # Rotary embeddings exist only when the config enables them;
        # forward() checks the same flag before using them.
        if self.config.rope:
            theta = getattr(config, "rope_theta", 10000)
            max_pos = getattr(config, "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_pos,
                base=theta,
            )

        # Standard 1/sqrt(head_dim) softmax scaling.
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling)

        # Output projection back to the model width.
        self.attn_out = RowParallelLinear(
            config.d_model,
            config.d_model,
            bias=config.include_bias,
            linear_method=linear_method,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the projected attention output (without the residual).

        ``positions`` is consumed only by the rotary embedding (when
        enabled); ``kv_cache`` and ``attn_metadata`` are forwarded unchanged
        to the ``Attention`` layer.
        """
        normed = self.attn_norm(hidden_states)
        qkv, _ = self.att_proj(normed)
        # The fused projection output is split evenly into q, k and v.
        query, key, value = qkv.chunk(chunks=3, dim=-1)
        if self.config.rope:
            query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.attn_out(context)
        return projected
class OlmoMLP(nn.Module):
    """Feed-forward sub-layer of an OLMo block, residual included.

    ``forward`` returns ``x + ff_out(SwiGLU(ff_proj(LN(x))))`` where ``LN``
    is a LayerNorm with no learnable scale or bias. ``ff_proj`` widens to
    ``hidden_size``; SwiGLU scales that width by its ``output_multiplier``
    before ``ff_out`` maps back to ``d_model``.
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        # Explicit MLP width wins; otherwise derive it from mlp_ratio.
        if config.mlp_hidden_size is not None:
            self.hidden_size = config.mlp_hidden_size
        else:
            self.hidden_size = config.mlp_ratio * config.d_model

        # Pre-MLP LayerNorm (no affine parameters, no bias).
        self.ff_norm = nn.LayerNorm(config.d_model,
                                    elementwise_affine=False,
                                    bias=False)

        # Up-projection feeding the gated activation.
        self.ff_proj = ColumnParallelLinear(
            config.d_model,
            self.hidden_size,
            bias=config.include_bias,
            linear_method=linear_method,
        )

        self.act = SwiGLU()
        # Width after the activation must be a whole number of features.
        reduced_width = self.act.output_multiplier * self.hidden_size
        assert reduced_width % 1 == 0

        # Down-projection back to the model width.
        self.ff_out = RowParallelLinear(
            int(reduced_width),
            config.d_model,
            bias=config.include_bias,
            linear_method=linear_method,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the MLP to ``x`` and add ``x`` back as a residual."""
        residual = x
        hidden, _ = self.ff_proj(self.ff_norm(x))
        hidden, _ = self.ff_out(self.act(hidden))
        return residual + hidden
class OlmoBlock(nn.Module):
    """A single OLMo transformer layer.

    With ``h = x + Attention(LN(x))``, the layer returns ``h + MLP(LN(h))``.
    The attention residual is added in ``forward`` below; the MLP residual
    is added inside ``OlmoMLP.forward``.
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # Attention sub-layer (pre-norm; its residual is added in forward).
        self.attn = OlmoAttention(config, linear_method)
        # MLP sub-layer (pre-norm; adds its own residual).
        self.mlp = OlmoMLP(config, linear_method)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers over ``hidden_states``.

        Args:
            positions: Token positions, forwarded to the attention sub-layer.
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV cache, forwarded to attention.
            attn_metadata: Attention metadata, forwarded to attention.

        Returns:
            The layer output tensor produced by the MLP sub-layer. This is a
            single tensor, not a tuple, despite the previous annotation.
        """
        # Attention sub-layer plus residual.
        og_x = hidden_states
        x = self.attn(positions, hidden_states, kv_cache, attn_metadata)
        x = x + og_x
        # MLP sub-layer (residual added inside OlmoMLP).
        hidden_states = self.mlp(x)
        return hidden_states
class OlmoModel(nn.Module):
    """OLMo decoder stack: token embedding, transformer blocks, final norm.

    Submodules live in a ``transformer`` ModuleDict (``wte``, ``ln_f``,
    ``ff_out``, ``blocks``). The LM head ``ff_out`` is registered here but
    not applied in ``forward``; ``OLMoForCausalLM.compute_logits`` uses it.
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        # Reject unsupported configs before allocating any layer weights;
        # previously every block was built first and then discarded.
        if self.config.block_group_size > 1:
            raise NotImplementedError("Block group size > 1 not supported yet")

        # The (possibly padded) embedding table size, when set, overrides
        # vocab_size for both the embedding and the LM head.
        vocab_rows = config.embedding_size or config.vocab_size
        self.transformer = nn.ModuleDict(
            dict(
                wte=VocabParallelEmbedding(
                    vocab_rows,
                    config.d_model,
                    linear_method=linear_method,
                ),
                ln_f=nn.LayerNorm(config.d_model,
                                  elementwise_affine=False,
                                  bias=False),
                ff_out=ParallelLMHead(
                    vocab_rows,
                    config.d_model,
                    bias=config.include_bias,
                    linear_method=linear_method,
                ),
            ))
        blocks = [
            OlmoBlock(config, linear_method) for _ in range(config.n_layers)
        ]
        self.transformer.update({"blocks": nn.ModuleList(blocks)})

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every block, and apply the final norm.

        :param input_ids: Token ids to embed. NOTE(review): previously
            documented as shape ``(batch_size, seq_len)``; confirm the layout
            the model runner actually passes.
        :param positions: Token positions, forwarded to every block.
        :param kv_caches: KV caches indexed by block position; block ``i``
            receives ``kv_caches[i]``.
        :param attn_metadata: Attention metadata, forwarded to every block.
        :returns: Hidden states after ``ln_f``; logits are not computed here.
        """
        x = self.transformer.wte(input_ids)  # type: ignore
        for block_idx, block in enumerate(self.transformer.blocks):
            x = block(
                positions,
                x,
                kv_caches[block_idx],
                attn_metadata,
            )
        x = self.transformer.ln_f(x)  # type: ignore
        return x
class OLMoForCausalLM(nn.Module):
    """Minimal causal-LM wrapper around ``OlmoModel``.

    ``forward`` returns final hidden states, ``compute_logits`` projects them
    through ``model.transformer.ff_out``, ``sample`` draws next tokens, and
    ``load_weights`` copies checkpoint tensors into this module's parameters.
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OlmoModel(config, linear_method)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states (no logits)."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to logits via the model's LM head."""
        lm_head = self.model.transformer.ff_out
        return self.logits_processor(lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate token selection to the sampler."""
        return self.sampler(logits, sampling_metadata)

    @staticmethod
    def _remap_checkpoint_name(ckpt_name: str) -> str:
        """Translate a checkpoint parameter name to this module's layout."""
        # Attention weights (att_proj, attn_out) sit under OlmoBlock.attn.
        if ".att" in ckpt_name:
            ckpt_name = ckpt_name.replace(".att", ".attn.att")
        # Per-block MLP weights sit under OlmoBlock.mlp; the top-level LM
        # head transformer.ff_out keeps its name.
        if ".ff" in ckpt_name and "transformer.ff_out" not in ckpt_name:
            ckpt_name = ckpt_name.replace(".ff", ".mlp.ff")
        return ckpt_name

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load checkpoint tensors into this module's parameters.

        When ``config.weight_tying`` is set, each ``wte`` tensor is also
        loaded into the matching ``transformer.ff_out`` parameter (if one
        exists). Checkpoint names are then remapped to this module's layout;
        a name with no matching parameter raises ``KeyError``.
        """
        named_params = dict(self.named_parameters(remove_duplicate=False))
        weights = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                            load_format, revision)
        for ckpt_name, tensor in weights:
            if "wte" in ckpt_name and self.config.weight_tying:
                # Tied weights: mirror the embedding into the LM head.
                tied_name = ckpt_name.replace("model.transformer.wte",
                                              "model.transformer.ff_out")
                if tied_name in named_params:
                    head_param = named_params[tied_name]
                    load_head = getattr(head_param, "weight_loader",
                                        default_weight_loader)
                    load_head(head_param, tensor)
            target = named_params[self._remap_checkpoint_name(ckpt_name)]
            load_target = getattr(target, "weight_loader",
                                  default_weight_loader)
            load_target(target, tensor)
|