|
@@ -1,65 +1,47 @@
|
|
# coding=utf-8
|
|
# coding=utf-8
|
|
# Adapted from
|
|
# Adapted from
|
|
-# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
|
|
|
|
-# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
|
|
|
|
-# Copyright 2023 The PygmalionAI team.
|
|
|
|
-# Copyright 2023 The vLLM team.
|
|
|
|
-# Copyright (c) Microsoft Corporation.
|
|
|
|
-# Licensed under the MIT license.
|
|
|
|
|
|
+# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
|
|
|
|
+# Copyright 2024 The vLLM team.
|
|
|
|
+# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
#
|
|
-# BSD 3-Clause License
|
|
|
|
|
|
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
|
|
+# and OPT implementations in this library. It has been modified from its
|
|
|
|
+# original forms to accommodate minor architectural differences compared
|
|
|
|
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
#
|
|
-# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
|
|
|
|
-# All rights reserved.
|
|
|
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
|
+# You may obtain a copy of the License at
|
|
#
|
|
#
|
|
-# Redistribution and use in source and binary forms, with or without
|
|
|
|
-# modification, are permitted provided that the following conditions are met:
|
|
|
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
#
|
|
-# * Redistributions of source code must retain the above copyright notice, this
|
|
|
|
-# list of conditions and the following disclaimer.
|
|
|
|
-#
|
|
|
|
-# * Redistributions in binary form must reproduce the above copyright notice,
|
|
|
|
-# this list of conditions and the following disclaimer in the documentation
|
|
|
|
-# and/or other materials provided with the distribution.
|
|
|
|
-#
|
|
|
|
-# * Neither the name of the copyright holder nor the names of its
|
|
|
|
-# contributors may be used to endorse or promote products derived from
|
|
|
|
-# this software without specific prior written permission.
|
|
|
|
-#
|
|
|
|
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
|
|
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
|
|
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
|
|
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
|
|
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
|
|
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
|
|
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
|
|
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
|
|
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
|
|
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
|
+# limitations under the License.
|
|
"""Inference-only OLMo model compatible with HuggingFace weights."""
|
|
"""Inference-only OLMo model compatible with HuggingFace weights."""
|
|
from typing import Iterable, List, Optional, Tuple
|
|
from typing import Iterable, List, Optional, Tuple
|
|
|
|
|
|
import torch
|
|
import torch
|
|
-# this model must need this dependency
|
|
|
|
-from hf_olmo import OLMoConfig
|
|
|
|
from torch import nn
|
|
from torch import nn
|
|
|
|
+from transformers import OlmoConfig
|
|
|
|
|
|
from aphrodite.attention import Attention, AttentionMetadata
|
|
from aphrodite.attention import Attention, AttentionMetadata
|
|
from aphrodite.common.sequence import SamplerOutput
|
|
from aphrodite.common.sequence import SamplerOutput
|
|
from aphrodite.distributed import get_tensor_model_parallel_world_size
|
|
from aphrodite.distributed import get_tensor_model_parallel_world_size
|
|
from aphrodite.modeling.layers.activation import SiluAndMul
|
|
from aphrodite.modeling.layers.activation import SiluAndMul
|
|
-from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
|
|
|
|
- LinearMethodBase,
|
|
|
|
- MergedColumnParallelLinear,
|
|
|
|
|
|
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear)
|
|
RowParallelLinear)
|
|
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
|
|
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
|
|
from aphrodite.modeling.layers.rotary_embedding import get_rope
|
|
from aphrodite.modeling.layers.rotary_embedding import get_rope
|
|
from aphrodite.modeling.layers.sampler import Sampler
|
|
from aphrodite.modeling.layers.sampler import Sampler
|
|
-from aphrodite.modeling.layers.vocab_parallel_embedding import \
|
|
|
|
- VocabParallelEmbedding
|
|
|
|
|
|
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
|
|
|
|
+ ParallelLMHead, VocabParallelEmbedding)
|
|
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
|
|
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
|
|
from aphrodite.modeling.sampling_metadata import SamplingMetadata
|
|
from aphrodite.modeling.sampling_metadata import SamplingMetadata
|
|
|
|
+from aphrodite.quantization.base_config import QuantizationConfig
|
|
|
|
|
|
|
|
|
|
class OlmoAttention(nn.Module):
|
|
class OlmoAttention(nn.Module):
|
|
@@ -71,56 +53,53 @@ class OlmoAttention(nn.Module):
|
|
|
|
|
|
def __init__(
|
|
def __init__(
|
|
self,
|
|
self,
|
|
- config: OLMoConfig,
|
|
|
|
- linear_method: Optional[LinearMethodBase] = None,
|
|
|
|
|
|
+ config: OlmoConfig,
|
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
):
|
|
):
|
|
super().__init__()
|
|
super().__init__()
|
|
self.config = config
|
|
self.config = config
|
|
- self.hidden_size = config.d_model
|
|
|
|
- assert config.d_model % config.n_heads == 0
|
|
|
|
|
|
+ self.hidden_size = config.hidden_size
|
|
tensor_model_parallel_world_size = (
|
|
tensor_model_parallel_world_size = (
|
|
get_tensor_model_parallel_world_size())
|
|
get_tensor_model_parallel_world_size())
|
|
- self.total_num_heads = self.config.n_heads
|
|
|
|
|
|
+ self.total_num_heads = config.num_attention_heads
|
|
|
|
+
|
|
|
|
+ assert self.hidden_size % self.total_num_heads == 0
|
|
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
|
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
|
|
|
+
|
|
self.num_heads = (self.total_num_heads //
|
|
self.num_heads = (self.total_num_heads //
|
|
tensor_model_parallel_world_size)
|
|
tensor_model_parallel_world_size)
|
|
self.head_dim = self.hidden_size // self.total_num_heads
|
|
self.head_dim = self.hidden_size // self.total_num_heads
|
|
|
|
+ self.max_position_embeddings = config.max_position_embeddings
|
|
|
|
+ self.rope_theta = config.rope_theta
|
|
|
|
+ self.clip_qkv = config.clip_qkv
|
|
|
|
|
|
- # Layer norms.
|
|
|
|
- self.attn_norm = nn.LayerNorm(config.d_model,
|
|
|
|
- elementwise_affine=False,
|
|
|
|
- bias=False)
|
|
|
|
# Attention input projection. Projects x -> (q, k, v)
|
|
# Attention input projection. Projects x -> (q, k, v)
|
|
- self.att_proj = QKVParallelLinear(
|
|
|
|
- config.d_model,
|
|
|
|
|
|
+ self.qkv_proj = QKVParallelLinear(
|
|
|
|
+ self.hidden_size,
|
|
self.head_dim,
|
|
self.head_dim,
|
|
self.total_num_heads,
|
|
self.total_num_heads,
|
|
- bias=config.include_bias,
|
|
|
|
- linear_method=linear_method,
|
|
|
|
|
|
+ bias=config.attention_bias,
|
|
|
|
+ quant_config=quant_config,
|
|
)
|
|
)
|
|
|
|
|
|
# Rotary embeddings.
|
|
# Rotary embeddings.
|
|
- if self.config.rope:
|
|
|
|
- rope_theta = getattr(config, "rope_theta", 10000)
|
|
|
|
- max_position_embeddings = getattr(config,
|
|
|
|
- "max_position_embeddings", 8192)
|
|
|
|
- self.rotary_emb = get_rope(
|
|
|
|
- self.head_dim,
|
|
|
|
- rotary_dim=self.head_dim,
|
|
|
|
- max_position=max_position_embeddings,
|
|
|
|
- base=rope_theta,
|
|
|
|
- )
|
|
|
|
|
|
+ self.rotary_emb = get_rope(
|
|
|
|
+ self.head_dim,
|
|
|
|
+ rotary_dim=self.head_dim,
|
|
|
|
+ max_position=self.max_position_embeddings,
|
|
|
|
+ base=self.rope_theta,
|
|
|
|
+ )
|
|
self.scaling = self.head_dim**-0.5
|
|
self.scaling = self.head_dim**-0.5
|
|
self.attn = Attention(self.num_heads,
|
|
self.attn = Attention(self.num_heads,
|
|
self.head_dim,
|
|
self.head_dim,
|
|
scale=self.scaling)
|
|
scale=self.scaling)
|
|
|
|
|
|
# Attention output projection.
|
|
# Attention output projection.
|
|
- self.attn_out = RowParallelLinear(
|
|
|
|
- config.d_model,
|
|
|
|
- config.d_model,
|
|
|
|
- bias=config.include_bias,
|
|
|
|
- linear_method=linear_method,
|
|
|
|
|
|
+ self.o_proj = RowParallelLinear(
|
|
|
|
+ self.hidden_size,
|
|
|
|
+ self.hidden_size,
|
|
|
|
+ bias=config.attention_bias,
|
|
|
|
+ quant_config=quant_config,
|
|
)
|
|
)
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
@@ -130,13 +109,13 @@ class OlmoAttention(nn.Module):
|
|
kv_cache: torch.Tensor,
|
|
kv_cache: torch.Tensor,
|
|
attn_metadata: AttentionMetadata,
|
|
attn_metadata: AttentionMetadata,
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
- hidden_states = self.attn_norm(hidden_states)
|
|
|
|
- qkv, _ = self.att_proj(hidden_states)
|
|
|
|
|
|
+ qkv, _ = self.qkv_proj(hidden_states)
|
|
|
|
+ if self.clip_qkv is not None:
|
|
|
|
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
|
- if self.config.rope:
|
|
|
|
- q, k = self.rotary_emb(positions, q, k)
|
|
|
|
|
|
+ q, k = self.rotary_emb(positions, q, k)
|
|
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
|
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
|
- output, _ = self.attn_out(attn_output)
|
|
|
|
|
|
+ output, _ = self.o_proj(attn_output)
|
|
return output
|
|
return output
|
|
|
|
|
|
|
|
|
|
@@ -149,57 +128,44 @@ class OlmoMLP(nn.Module):
|
|
|
|
|
|
def __init__(
|
|
def __init__(
|
|
self,
|
|
self,
|
|
- config: OLMoConfig,
|
|
|
|
- linear_method: Optional[LinearMethodBase] = None,
|
|
|
|
|
|
+ config: OlmoConfig,
|
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
):
|
|
):
|
|
super().__init__()
|
|
super().__init__()
|
|
self.config = config
|
|
self.config = config
|
|
- self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
|
|
|
|
- is not None else config.mlp_ratio * config.d_model)
|
|
|
|
-
|
|
|
|
- # Layer norms.
|
|
|
|
- self.ff_norm = nn.LayerNorm(config.d_model,
|
|
|
|
- elementwise_affine=False,
|
|
|
|
- bias=False)
|
|
|
|
|
|
+ self.hidden_size = config.hidden_size
|
|
|
|
+ self.intermediate_size = config.intermediate_size
|
|
|
|
|
|
# Feed-forward input projection.
|
|
# Feed-forward input projection.
|
|
- self.ff_proj = MergedColumnParallelLinear(
|
|
|
|
- config.d_model,
|
|
|
|
- [self.hidden_size // 2] * 2,
|
|
|
|
- bias=config.include_bias,
|
|
|
|
- linear_method=linear_method,
|
|
|
|
|
|
+ self.gate_up_proj = MergedColumnParallelLinear(
|
|
|
|
+ self.hidden_size,
|
|
|
|
+ [self.intermediate_size] * 2,
|
|
|
|
+ bias=False,
|
|
|
|
+ quant_config=quant_config,
|
|
)
|
|
)
|
|
|
|
|
|
# Activation function.
|
|
# Activation function.
|
|
- self.act = SiluAndMul()
|
|
|
|
- self.act.output_multiplier = 0.5
|
|
|
|
- assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
|
|
|
|
|
|
+ self.act_fn = SiluAndMul()
|
|
|
|
|
|
# Feed-forward output projection.
|
|
# Feed-forward output projection.
|
|
- self.ff_out = RowParallelLinear(
|
|
|
|
- int(self.act.output_multiplier * self.hidden_size),
|
|
|
|
- config.d_model,
|
|
|
|
- bias=config.include_bias,
|
|
|
|
- linear_method=linear_method,
|
|
|
|
|
|
+ self.down_proj = RowParallelLinear(
|
|
|
|
+ self.intermediate_size,
|
|
|
|
+ self.hidden_size,
|
|
|
|
+ bias=False,
|
|
|
|
+ quant_config=quant_config,
|
|
)
|
|
)
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
self,
|
|
self,
|
|
x: torch.Tensor,
|
|
x: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
- # Add feed-forward projection.
|
|
|
|
- # shape: (batch_size, seq_len, d_model)
|
|
|
|
- og_x = x
|
|
|
|
- x = self.ff_norm(x)
|
|
|
|
- x, _ = self.ff_proj(x)
|
|
|
|
- x = self.act(x)
|
|
|
|
- x, _ = self.ff_out(x)
|
|
|
|
- x = og_x + x
|
|
|
|
-
|
|
|
|
|
|
+ gate_up, _ = self.gate_up_proj(x)
|
|
|
|
+ x = self.act_fn(gate_up)
|
|
|
|
+ x, _ = self.down_proj(x)
|
|
return x
|
|
return x
|
|
|
|
|
|
|
|
|
|
-class OlmoBlock(nn.Module):
|
|
|
|
|
|
+class OlmoDecoderLayer(nn.Module):
|
|
"""
|
|
"""
|
|
This is a typical transformer block where the output is
|
|
This is a typical transformer block where the output is
|
|
computed as ``MLP(LN(x + Attention(LN(x))))``
|
|
computed as ``MLP(LN(x + Attention(LN(x))))``
|
|
@@ -207,14 +173,22 @@ class OlmoBlock(nn.Module):
|
|
"""
|
|
"""
|
|
|
|
|
|
def __init__(self,
|
|
def __init__(self,
|
|
- config: OLMoConfig,
|
|
|
|
- linear_method: Optional[LinearMethodBase] = None):
|
|
|
|
|
|
+ config: OlmoConfig,
|
|
|
|
+ quant_config: Optional[QuantizationConfig] = None):
|
|
super().__init__()
|
|
super().__init__()
|
|
# Attention block.
|
|
# Attention block.
|
|
- self.attn = OlmoAttention(config, linear_method)
|
|
|
|
|
|
+ self.self_attn = OlmoAttention(config, quant_config)
|
|
|
|
|
|
# MLP block.
|
|
# MLP block.
|
|
- self.mlp = OlmoMLP(config, linear_method)
|
|
|
|
|
|
+ self.mlp = OlmoMLP(config, quant_config)
|
|
|
|
+
|
|
|
|
+ # LayerNorm
|
|
|
|
+ self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
|
|
|
+ elementwise_affine=False,
|
|
|
|
+ bias=False)
|
|
|
|
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
|
|
|
+ elementwise_affine=False,
|
|
|
|
+ bias=False)
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
self,
|
|
self,
|
|
@@ -224,52 +198,37 @@ class OlmoBlock(nn.Module):
|
|
attn_metadata: AttentionMetadata,
|
|
attn_metadata: AttentionMetadata,
|
|
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
|
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
|
# Attention block.
|
|
# Attention block.
|
|
- og_x = hidden_states
|
|
|
|
- x = self.attn(positions, hidden_states, kv_cache, attn_metadata)
|
|
|
|
- x = x + og_x
|
|
|
|
|
|
+ residual = hidden_states
|
|
|
|
+ hidden_states = self.input_layernorm(hidden_states)
|
|
|
|
+ hidden_states = self.self_attn(positions, hidden_states, kv_cache,
|
|
|
|
+ attn_metadata)
|
|
|
|
+ hidden_states = hidden_states + residual
|
|
|
|
|
|
# MLP block.
|
|
# MLP block.
|
|
- hidden_states = self.mlp(x)
|
|
|
|
|
|
+ residual = hidden_states
|
|
|
|
+ hidden_states = self.post_attention_layernorm(hidden_states)
|
|
|
|
+ hidden_states = self.mlp(hidden_states)
|
|
|
|
+ hidden_states = residual + hidden_states
|
|
return hidden_states
|
|
return hidden_states
|
|
|
|
|
|
|
|
|
|
class OlmoModel(nn.Module):
|
|
class OlmoModel(nn.Module):
|
|
|
|
|
|
def __init__(self,
|
|
def __init__(self,
|
|
- config: OLMoConfig,
|
|
|
|
- linear_method: Optional[LinearMethodBase] = None):
|
|
|
|
|
|
+ config: OlmoConfig,
|
|
|
|
+ quant_config: Optional[QuantizationConfig] = None):
|
|
super().__init__()
|
|
super().__init__()
|
|
self.config = config
|
|
self.config = config
|
|
|
|
|
|
- self.transformer = nn.ModuleDict(
|
|
|
|
- dict(
|
|
|
|
- wte=VocabParallelEmbedding(
|
|
|
|
- config.embedding_size or config.vocab_size,
|
|
|
|
- config.d_model,
|
|
|
|
- ),
|
|
|
|
- ln_f=nn.LayerNorm(config.d_model,
|
|
|
|
- elementwise_affine=False,
|
|
|
|
- bias=False),
|
|
|
|
- ))
|
|
|
|
-
|
|
|
|
- blocks = [
|
|
|
|
- OlmoBlock(config, linear_method) for i in range(config.n_layers)
|
|
|
|
- ]
|
|
|
|
- if self.config.block_group_size > 1:
|
|
|
|
- raise NotImplementedError("Block group size > 1 not supported yet")
|
|
|
|
- else:
|
|
|
|
- self.transformer.update({"blocks": nn.ModuleList(blocks)})
|
|
|
|
-
|
|
|
|
- if not config.weight_tying:
|
|
|
|
- self.transformer.update({
|
|
|
|
- "ff_out":
|
|
|
|
- ColumnParallelLinear(
|
|
|
|
- config.d_model,
|
|
|
|
- config.embedding_size or config.vocab_size,
|
|
|
|
- bias=config.include_bias,
|
|
|
|
- linear_method=linear_method,
|
|
|
|
- )
|
|
|
|
- })
|
|
|
|
|
|
+ self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
|
|
|
+ config.hidden_size)
|
|
|
|
+ self.layers = nn.ModuleList([
|
|
|
|
+ OlmoDecoderLayer(config, quant_config)
|
|
|
|
+ for layer_idx in range(config.num_hidden_layers)
|
|
|
|
+ ])
|
|
|
|
+ self.norm = nn.LayerNorm(config.hidden_size,
|
|
|
|
+ elementwise_affine=False,
|
|
|
|
+ bias=False)
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
self,
|
|
self,
|
|
@@ -283,39 +242,48 @@ class OlmoModel(nn.Module):
|
|
"""
|
|
"""
|
|
# Get embeddings of input.
|
|
# Get embeddings of input.
|
|
# shape: (batch_size, seq_len, d_model)
|
|
# shape: (batch_size, seq_len, d_model)
|
|
- x = self.transformer.wte(input_ids) # type: ignore
|
|
|
|
|
|
+ inputs_embeds = self.embed_tokens(input_ids)
|
|
|
|
+
|
|
|
|
+ # embed positions
|
|
|
|
+ hidden_states = inputs_embeds
|
|
|
|
|
|
# Apply blocks one-by-one.
|
|
# Apply blocks one-by-one.
|
|
- for block_idx, block in enumerate(self.transformer.blocks):
|
|
|
|
|
|
+ for layer_idx, decoder_layer in enumerate(self.layers):
|
|
# shape: (batch_size, seq_len, d_model)
|
|
# shape: (batch_size, seq_len, d_model)
|
|
- x = block(
|
|
|
|
|
|
+ hidden_states = decoder_layer(
|
|
positions,
|
|
positions,
|
|
- x,
|
|
|
|
- kv_caches[block_idx],
|
|
|
|
|
|
+ hidden_states,
|
|
|
|
+ kv_caches[layer_idx],
|
|
attn_metadata,
|
|
attn_metadata,
|
|
)
|
|
)
|
|
|
|
|
|
# Apply final layer norm.
|
|
# Apply final layer norm.
|
|
# shape: (batch_size, seq_len or 1, d_model)
|
|
# shape: (batch_size, seq_len or 1, d_model)
|
|
- x = self.transformer.ln_f(x) # type: ignore
|
|
|
|
- return x
|
|
|
|
|
|
+ hidden_states = self.norm(hidden_states)
|
|
|
|
+ return hidden_states
|
|
|
|
|
|
|
|
|
|
-class OLMoForCausalLM(nn.Module):
|
|
|
|
|
|
+class OlmoForCausalLM(nn.Module):
|
|
"""
|
|
"""
|
|
Extremely barebones HF model wrapper.
|
|
Extremely barebones HF model wrapper.
|
|
"""
|
|
"""
|
|
|
|
|
|
def __init__(self,
|
|
def __init__(self,
|
|
- config: OLMoConfig,
|
|
|
|
- linear_method: Optional[LinearMethodBase] = None):
|
|
|
|
|
|
+ config: OlmoConfig,
|
|
|
|
+ quant_config: Optional[QuantizationConfig] = None):
|
|
super().__init__()
|
|
super().__init__()
|
|
self.config = config
|
|
self.config = config
|
|
- self.linear_method = linear_method
|
|
|
|
- self.model = OlmoModel(config, linear_method)
|
|
|
|
- self.lm_head_weight = (self.model.transformer.wte.weight
|
|
|
|
- if config.weight_tying else
|
|
|
|
- self.model.transformer.ff_out.weight)
|
|
|
|
|
|
+ self.model = OlmoModel(config, quant_config)
|
|
|
|
+ if config.tie_word_embeddings:
|
|
|
|
+ self.lm_head_weight = self.model.embed_tokens.weight
|
|
|
|
+ else:
|
|
|
|
+ self.unpadded_vocab_size = config.vocab_size
|
|
|
|
+ self.lm_head = ParallelLMHead(
|
|
|
|
+ self.unpadded_vocab_size,
|
|
|
|
+ config.hidden_size,
|
|
|
|
+ org_num_embeddings=config.vocab_size,
|
|
|
|
+ )
|
|
|
|
+ self.lm_head_weight = self.lm_head.weight
|
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
|
self.sampler = Sampler()
|
|
self.sampler = Sampler()
|
|
|
|
|
|
@@ -349,20 +317,39 @@ class OLMoForCausalLM(nn.Module):
|
|
return next_tokens
|
|
return next_tokens
|
|
|
|
|
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
|
|
|
+ stacked_params_mapping = [
|
|
|
|
+ # (param_name, shard_name, shard_id)
|
|
|
|
+ ("qkv_proj", "q_proj", "q"),
|
|
|
|
+ ("qkv_proj", "k_proj", "k"),
|
|
|
|
+ ("qkv_proj", "v_proj", "v"),
|
|
|
|
+ ("gate_up_proj", "gate_proj", 0),
|
|
|
|
+ ("gate_up_proj", "up_proj", 1),
|
|
|
|
+ ]
|
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
|
for name, loaded_weight in weights:
|
|
for name, loaded_weight in weights:
|
|
- # attention
|
|
|
|
- if ".att" in name:
|
|
|
|
- name = name.replace(".att", ".attn.att")
|
|
|
|
- # mlp
|
|
|
|
- if ".ff_proj" in name:
|
|
|
|
- name = name.replace(".ff_proj", ".mlp.ff_proj")
|
|
|
|
- # Reverse the weight for the MergeColumnParallelLinear
|
|
|
|
- loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1])
|
|
|
|
- if ".ff_out" in name and "transformer.ff_out" not in name:
|
|
|
|
- name = name.replace(".ff_out", ".mlp.ff_out")
|
|
|
|
- # there is no bias in olmo
|
|
|
|
- param = params_dict[name]
|
|
|
|
- weight_loader = getattr(param, "weight_loader",
|
|
|
|
- default_weight_loader)
|
|
|
|
- weight_loader(param, loaded_weight)
|
|
|
|
|
|
+ if "rotary_emb.inv_freq" in name:
|
|
|
|
+ continue
|
|
|
|
+ if ("rotary_emb.cos_cached" in name
|
|
|
|
+ or "rotary_emb.sin_cached" in name):
|
|
|
|
+ # Models trained using ColossalAI may include these tensors in
|
|
|
|
+ # the checkpoint. Skip them.
|
|
|
|
+ continue
|
|
|
|
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
|
|
|
+ if weight_name not in name:
|
|
|
|
+ continue
|
|
|
|
+ name = name.replace(weight_name, param_name)
|
|
|
|
+ # Skip loading extra bias for GPTQ models.
|
|
|
|
+ if name.endswith(".bias") and name not in params_dict:
|
|
|
|
+ continue
|
|
|
|
+ param = params_dict[name]
|
|
|
|
+ weight_loader = param.weight_loader
|
|
|
|
+ weight_loader(param, loaded_weight, shard_id)
|
|
|
|
+ break
|
|
|
|
+ else:
|
|
|
|
+ # Skip loading extra bias for GPTQ models.
|
|
|
|
+ if name.endswith(".bias") and name not in params_dict:
|
|
|
|
+ continue
|
|
|
|
+ param = params_dict[name]
|
|
|
|
+ weight_loader = getattr(param, "weight_loader",
|
|
|
|
+ default_weight_loader)
|
|
|
|
+ weight_loader(param, loaded_weight)
|