
feat: refactor modeling logic and support more models (#274)

* feat: refactor llama models and add new models

* add geglu kernels

* add geglu activation layer

* formatting
AlpinDale · 1 year ago · commit e31c6f0b45
38 changed files with 5141 additions and 363 deletions
  1. .pylintrc (+1 / -0)
  2. aphrodite/modeling/layers/activation.py (+21 / -0)
  3. aphrodite/modeling/layers/layernorm.py (+25 / -0)
  4. aphrodite/modeling/models/__init__.py (+36 / -10)
  5. aphrodite/modeling/models/baichuan.py (+412 / -0)
  6. aphrodite/modeling/models/bloom.py (+340 / -0)
  7. aphrodite/modeling/models/chatglm.py (+378 / -0)
  8. aphrodite/modeling/models/decilm.py (+6 / -2)
  9. aphrodite/modeling/models/deepseek.py (+23 / -72)
  10. aphrodite/modeling/models/falcon.py (+446 / -0)
  11. aphrodite/modeling/models/gemma.py (+65 / -62)
  12. aphrodite/modeling/models/gpt2.py (+286 / -0)
  13. aphrodite/modeling/models/gpt_bigcode.py (+292 / -0)
  14. aphrodite/modeling/models/gpt_j.py (+8 / -13)
  15. aphrodite/modeling/models/gpt_neox.py (+6 / -10)
  16. aphrodite/modeling/models/internlm2.py (+352 / -0)
  17. aphrodite/modeling/models/llama.py (+45 / -14)
  18. aphrodite/modeling/models/mixtral.py (+62 / -18)
  19. aphrodite/modeling/models/mixtral_quant.py (+7 / -6)
  20. aphrodite/modeling/models/mpt.py (+307 / -0)
  21. aphrodite/modeling/models/olmo.py (+377 / -0)
  22. aphrodite/modeling/models/phi.py (+0 / -1)
  23. aphrodite/modeling/models/qwen.py (+315 / -0)
  24. aphrodite/modeling/models/qwen2.py (+35 / -52)
  25. aphrodite/modeling/models/stablelm.py (+352 / -0)
  26. aphrodite/transformers_utils/config.py (+37 / -24)
  27. aphrodite/transformers_utils/configs/__init__.py (+13 / -2)
  28. aphrodite/transformers_utils/configs/baichuan.py (+62 / -0)
  29. aphrodite/transformers_utils/configs/chatglm.py (+68 / -0)
  30. aphrodite/transformers_utils/configs/falcon.py (+88 / -0)
  31. aphrodite/transformers_utils/configs/mpt.py (+233 / -0)
  32. aphrodite/transformers_utils/configs/olmo.py (+72 / -0)
  33. aphrodite/transformers_utils/tokenizer.py (+49 / -52)
  34. aphrodite/transformers_utils/tokenizers/__init__.py (+5 / -0)
  35. aphrodite/transformers_utils/tokenizers/baichuan.py (+261 / -0)
  36. kernels/activation_kernels.cu (+48 / -25)
  37. kernels/ops.h (+4 / -0)
  38. kernels/pybind.cpp (+4 / -0)

+ 1 - 0
.pylintrc

@@ -61,6 +61,7 @@ disable=abstract-method,
         c-extension-no-member,
         consider-using-enumerate,
         cmp-builtin,
+        inconsistent-quotes,
         cmp-method,
         coerce-builtin,
         coerce-method,

+ 21 - 0
aphrodite/modeling/layers/activation.py

@@ -37,6 +37,27 @@ class SiluAndMul(nn.Module):
         return out


+class GeluAndMul(nn.Module):
+    """An activation function for GeGLU.
+    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+    Shapes:
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
+    """
+
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d]) * x[..., d:]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.gelu_and_mul(out, x)
+        return out
+
+
 class NewGELU(nn.Module):

     def _forward(self, x: torch.Tensor) -> torch.Tensor:

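The new GeluAndMul layer implements GeGLU: the last dimension of the projected tensor is split in half, the first half goes through GELU, and the result gates the second half elementwise. Below is a minimal PyTorch-native sketch of that computation, mirroring the _forward fallback above; it does not call the custom gelu_and_mul kernel and is only a numerical reference.

import torch
import torch.nn.functional as F

def geglu_reference(x: torch.Tensor) -> torch.Tensor:
    # GELU(first half of the last dim) * second half, as in GeluAndMul._forward.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 16)      # (num_tokens, 2 * d)
out = geglu_reference(x)    # (num_tokens, d)
assert out.shape == (4, 8)
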
+ 25 - 0
aphrodite/modeling/layers/layernorm.py

@@ -7,6 +7,31 @@ import torch.nn as nn
 from aphrodite._C import ops


+class LayerNorm(nn.LayerNorm):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__(hidden_size, eps=eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """normalization."""
+        if residual is not None:
+            x = x + residual
+            residual = x
+        x = super().forward(x)
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
+
 class RMSNorm(nn.Module):
     """Root mean square normalization.


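The LayerNorm wrapper above folds the residual add into the normalization call so decoder layers can carry the running residual stream alongside the hidden states. A small sketch of the same pattern with plain torch.nn.LayerNorm, with illustrative shapes:

import torch
import torch.nn as nn

norm = nn.LayerNorm(8, eps=1e-6)
hidden = torch.randn(2, 8)
residual = torch.randn(2, 8)

# Equivalent to "hidden, residual = layer_norm(hidden, residual)" above:
residual = hidden + residual   # updated residual stream, handed back to the caller
hidden = norm(residual)        # normalized input for the next sublayer
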
+ 36 - 10
aphrodite/modeling/models/__init__.py

@@ -8,30 +8,53 @@ from aphrodite.common.utils import is_hip
 
 
 logger = init_logger(__name__)

-# Architecture -> (module, class)
+# Architecture -> (module, class).
 _MODELS = {
+    "AquilaModel": ("llama", "LlamaForCausalLM"),
+    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
+    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
+    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
+    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
+    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
+    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
+    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
+    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
+    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
+    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
+    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
+    # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
-    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
-    "YiForCausalLM": ("yi", "YiForCausalLM"),
+    # transformers's mpt class has lower case
+    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
+    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
+    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
+    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
+    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
+    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
 }

-# Models not supported by ROCm
+# Models not supported by ROCm.
 _ROCM_UNSUPPORTED_MODELS = []

 # Models partially supported by ROCm.
-# Architecture -> Reason
+# Architecture -> Reason.
 _ROCM_PARTIALLY_SUPPORTED_MODELS = {
+    "Qwen2ForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
     "MistralForCausalLM":
     "MistralForCausalLM":
-    "Sliding window attention is not yet supported in ROCM's flash attention.",
+    "Sliding window attention is not yet supported in ROCm's flash attention",
     "MixtralForCausalLM":
     "MixtralForCausalLM":
     "Sliding window attention is not yet supported in ROCm's flash attention",
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }
 }
@@ -45,8 +68,9 @@ class ModelRegistry:
             return None
         if is_hip():
             if model_arch in _ROCM_UNSUPPORTED_MODELS:
-                raise ValueError(f"Model architecture {model_arch} is not "
-                                 "supported in ROCm for now.")
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "ROCm for now.")
             if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                 logger.warning(
                     f"Model architecture {model_arch} is partially supported "
@@ -62,4 +86,6 @@ class ModelRegistry:
         return list(_MODELS.keys())


-__all__ = ["ModelRegistry"]
+__all__ = [
+    "ModelRegistry",
+]

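The expanded _MODELS table maps a HuggingFace architecture string to a (module, class) pair that the registry imports lazily. A simplified sketch of that lookup, assuming the aphrodite package is importable; the helper name resolve_model_cls is hypothetical, and only the shape of the mapping comes from the diff:

import importlib

def resolve_model_cls(model_arch: str):
    # (module, class) pairs, as in the _MODELS dict above.
    module_name, class_name = {
        "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
        "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    }[model_arch]
    module = importlib.import_module(
        f"aphrodite.modeling.models.{module_name}")
    return getattr(module, class_name)
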
+ 412 - 0
aphrodite/modeling/models/baichuan.py

@@ -0,0 +1,412 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only BaiChuan model compatible with HuggingFace weights."""
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear,
+                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs.baichuan import BaiChuanConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
+class BaiChuanMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.gate_proj = ColumnParallelLinear(hidden_size,
+                                                  intermediate_size,
+                                                  bias=False,
+                                                  linear_method=linear_method)
+            self.up_proj = ColumnParallelLinear(hidden_size,
+                                                intermediate_size,
+                                                bias=False,
+                                                linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.up_proj(x)
+            gate, _ = self.gate_proj(x)
+            gate_up = torch.cat([gate, up], dim=-1)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class BaiChuanAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        position_embedding: str,
+        rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
+        )
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.postion_embedding = position_embedding
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        # pylint: disable=invalid-name
+        self.W_pack = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+        # Create the alibi slopes and slice them.
+        if self.postion_embedding == "ALIBI":
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+            scaling = self.head_dim**-0.5
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       scaling,
+                                       alibi_slopes=alibi_slopes)
+        else:
+            is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+            ) is None else linear_method.quant_config.rope_style()
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=self.max_position_embeddings,
+                base=self.rope_theta,
+                is_neox_style=is_neox_style,
+            )
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttention(self.num_heads, self.head_dim,
+                                       self.scaling)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.W_pack(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.postion_embedding != "ALIBI":
+            q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class BaiChuanDecoderLayer(nn.Module):
+
+    def __init__(self,
+                 config: BaiChuanConfig,
+                 position_embedding: str,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.self_attn = BaiChuanAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            position_embedding=position_embedding,
+            rope_theta=rope_theta,
+            max_position_embeddings=max_position_embeddings,
+            linear_method=linear_method,
+        )
+        self.mlp = BaiChuanMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            linear_method=linear_method,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class BaiChuanModel(nn.Module):
+
+    def __init__(self,
+                 config: BaiChuanConfig,
+                 position_embedding: str,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size,
+                                                   linear_method=linear_method)
+        self.layers = nn.ModuleList([
+            BaiChuanDecoderLayer(config, position_embedding, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class BaiChuanBaseForCausalLM(nn.Module):
+
+    def __init__(self,
+                 config,
+                 position_embedding: str,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = BaiChuanModel(config, position_embedding, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if name == "lm_head.weight":
+                # Unlike Baichuan, Baichuan2 normalizes the head weights. Refer to:
+                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
+                # Distinguish between Baichuan and Baichuan2 by checking the
+                # vocab size.
+                is_baichuan2 = self.config.vocab_size == 125696
+                if is_baichuan2:
+                    loaded_weight = torch.nn.functional.normalize(
+                        loaded_weight)
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
+    """Baichuan 13B and Baichuan2 7B/13B."""
+
+    def __init__(self,
+                 config,
+                 linear_method: Optional[LinearMethodBase] = None):
+        if config.hidden_size == 4096:  # baichuan2 7b
+            super().__init__(config, "ROPE", linear_method)
+        else:  # baichuan 13b, baichuan2 13b
+            super().__init__(config, "ALIBI", linear_method)
+
+
+class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
+    """Baichuan 7B."""
+
+    def __init__(self,
+                 config,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__(config, "ROPE", linear_method)

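baichuan.py (like bloom.py below) builds ALiBi slopes with _get_alibi_slopes and slices them per tensor-parallel rank. A self-contained worked example of the same slope formula for 12 heads, which is not a power of two, so the last four heads receive interpolated "extra" slopes:

import math
import torch

total_num_heads = 12
closest = 2 ** math.floor(math.log2(total_num_heads))            # 8
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest) - 3))))    # 0.5
slopes = torch.pow(base, torch.arange(1, closest + 1))           # 0.5, 0.25, ...

extra_base = torch.tensor(
    2 ** (-(2 ** -(math.log2(2 * closest) - 3))))                # ~0.7071
num_extra = min(closest, total_num_heads - closest)              # 4
extra = torch.pow(extra_base, torch.arange(1, 1 + 2 * num_extra, 2))
slopes = torch.cat([slopes, extra])
assert slopes.shape == (total_num_heads,)
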
+ 340 - 0
aphrodite/modeling/models/bloom.py

@@ -0,0 +1,340 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The CacheFlow team.
+# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only BLOOM model compatible with HuggingFace weights."""
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import BloomConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
+class BloomAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: BloomConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.total_num_heads = config.n_head
+        self.head_dim = self.hidden_size // self.total_num_heads
+        assert self.head_dim * self.total_num_heads == self.hidden_size
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.dense = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+
+        # Create the alibi slopes and slice them.
+        tp_rank = get_tensor_model_parallel_rank()
+        head_start = tp_rank * self.num_heads
+        head_end = (tp_rank + 1) * self.num_heads
+        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+        alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+        scaling = self.head_dim**-0.5
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        del position_ids  # Unused.
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.dense(attn_output)
+        return output
+
+
+class BloomMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: BloomConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.dense_h_to_4h = ColumnParallelLinear(
+            hidden_size,
+            4 * hidden_size,
+            linear_method=linear_method,
+        )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.dense_4h_to_h = RowParallelLinear(
+            4 * hidden_size,
+            hidden_size,
+            linear_method=linear_method,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, _ = self.dense_h_to_4h(x)
+        x = self.gelu_impl(x)
+        x, _ = self.dense_4h_to_h(x)
+        return x
+
+
+class BloomBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: BloomConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+
+        self.input_layernorm = nn.LayerNorm(hidden_size,
+                                            eps=config.layer_norm_epsilon)
+        self.self_attention = BloomAttention(config, linear_method)
+        self.post_attention_layernorm = nn.LayerNorm(
+            hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = BloomMLP(config, linear_method)
+        self.apply_residual_connection_post_layernorm = (
+            config.apply_residual_connection_post_layernorm)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+
+        # Layer norm post the self attention.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = hidden_states
+
+        # Self attention.
+        attention_output = self.self_attention(
+            position_ids=position_ids,
+            hidden_states=layernorm_output,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        attention_output = attention_output + residual
+        layernorm_output = self.post_attention_layernorm(attention_output)
+
+        # Get residual
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = attention_output
+
+        # MLP.
+        output = self.mlp(layernorm_output) + residual
+        return output
+
+
+class BloomModel(nn.Module):
+
+    def __init__(
+        self,
+        config: BloomConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+
+        # Embedding + LN Embedding
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size, self.embed_dim, linear_method=linear_method)
+        self.word_embeddings_layernorm = nn.LayerNorm(
+            self.embed_dim, eps=config.layer_norm_epsilon)
+
+        # Transformer blocks
+        self.h = nn.ModuleList([
+            BloomBlock(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+
+        # Final Layer Norm
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.word_embeddings(input_ids)
+        hidden_states = self.word_embeddings_layernorm(hidden_states)
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(
+                position_ids,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+            )
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class BloomForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: BloomConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = BloomModel(config, linear_method)
+        # self.lm_head_weight = self.transformer.word_embeddings.weight
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if name == "lm_head.weight":
+                continue
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+            param = params_dict[name]
+
+            if "word_embeddings.weight" in name:
+                # Copy word embedding to lm_head
+                lm_head_param = params_dict["lm_head.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
+
+            if "query_key_value" in name:
+                # NOTE: BLOOM's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

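The note in BloomForCausalLM.load_weights above is worth spelling out: HuggingFace BLOOM checkpoints store the fused QKV rows grouped per head as (num_heads, 3, head_size), while the attention kernel expects them grouped per projection as (3, num_heads, head_size). A toy sketch of that reordering for a 2-D weight (output_dim = 0), using made-up sizes:

import torch

num_heads, head_size, hidden = 4, 8, 32
w = torch.randn(num_heads * 3 * head_size, hidden)   # checkpoint layout

w = w.view(num_heads, 3, head_size, hidden)          # split rows per head
w = w.transpose(0, 1)                                # regroup by q/k/v instead
w = w.reshape(num_heads * 3 * head_size, hidden)     # flatten back to 2-D
assert w.shape == (3 * num_heads * head_size, hidden)
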
+ 378 - 0
aphrodite/modeling/models/chatglm.py

@@ -0,0 +1,378 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/THUDM/ChatGLM2-6B
+"""Inference-only ChatGLM model compatible with THUDM weights."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import LayerNorm
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs import ChatGLMConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class GLMAttention(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.multi_query_attention = config.multi_query_attention
+        self.total_num_kv_heads = (config.multi_query_group_num
+                                   if config.multi_query_attention else
+                                   config.num_attention_heads)
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.add_bias_linear or config.add_qkv_bias,
+            linear_method=linear_method,
+        )
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=config.add_bias_linear,
+            linear_method=linear_method,
+        )
+
+        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
+        rope_ratio = getattr(config, "rope_ratio", 1.0)
+        max_positions = getattr(config, "seq_length", 8192)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim // 2,
+            max_position=max_positions,
+            base=10000 * rope_ratio,
+            is_neox_style=False,
+        )
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
+        key_cache, value_cache = kv_cache
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            key_cache,
+            value_cache,
+            input_metadata,
+        )
+        attn_output, _ = self.dense(context_layer)
+        return attn_output
+
+
+class GLMMLP(nn.Module):
+    """MLP.
+
+    MLP will take the input with h hidden state, project it to 4*h
+    hidden dimension, perform nonlinear transformation, and project the
+    state back into h hidden dimension.
+    """
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+
+        self.add_bias = config.add_bias_linear
+
+        # Project to 4h.
+        self.dense_h_to_4h = MergedColumnParallelLinear(
+            config.hidden_size,
+            [config.ffn_hidden_size] * 2,
+            bias=config.add_bias_linear,
+            linear_method=linear_method,
+        )
+
+        self.activation_func = SiluAndMul()
+
+        # Project back to h.
+        self.dense_4h_to_h = RowParallelLinear(
+            config.ffn_hidden_size,
+            config.hidden_size,
+            bias=config.add_bias_linear,
+            linear_method=linear_method,
+        )
+
+    def forward(self, hidden_states):
+        # [s, b, 4hp]
+        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
+        intermediate_parallel = self.activation_func(intermediate_parallel)
+        # [s, b, h]
+        output, _ = self.dense_4h_to_h(intermediate_parallel)
+        return output
+
+
+class GLMBlock(nn.Module):
+    """A single transformer layer.
+
+    Transformer layer takes input with size [s, b, h] and returns an
+    output of the same size.
+    """
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.apply_residual_connection_post_layernorm = (
+            config.apply_residual_connection_post_layernorm)
+
+        self.fp32_residual_connection = config.fp32_residual_connection
+
+        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+        # Layernorm on the input data.
+        self.input_layernorm = layer_norm_func(config.hidden_size,
+                                               eps=config.layernorm_epsilon)
+
+        # Self attention.
+        self.self_attention = GLMAttention(config, linear_method)
+        self.hidden_dropout = config.hidden_dropout
+
+        # Layernorm on the attention output
+        self.post_attention_layernorm = layer_norm_func(
+            config.hidden_size, eps=config.layernorm_epsilon)
+
+        # MLP
+        self.mlp = GLMMLP(config, linear_method)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        # hidden_states: [num_tokens, h]
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+        # Self attention.
+        attention_output = self.self_attention(
+            hidden_states=layernorm_output,
+            position_ids=position_ids,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+
+        # Residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = hidden_states
+
+        layernorm_input = residual + attention_output
+
+        # Layer norm post the self attention.
+        layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+        # Second residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = layernorm_input
+
+        output = self.mlp(layernorm_output) + residual
+
+        return output
+
+
+class GLMTransformer(nn.Module):
+    """Transformer class."""
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.post_layer_norm = config.post_layer_norm
+
+        # Number of layers.
+        self.num_layers = config.num_layers
+
+        # Transformer layers.
+        self.layers = nn.ModuleList(
+            [GLMBlock(config, linear_method) for i in range(self.num_layers)])
+
+        if self.post_layer_norm:
+            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+            # Final layer norm before output.
+            self.final_layernorm = layer_norm_func(
+                config.hidden_size, eps=config.layernorm_epsilon)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                position_ids=position_ids,
+                kv_cache=kv_caches[i],
+                input_metadata=input_metadata,
+            )
+        # Final layer norm.
+        if self.post_layer_norm:
+            hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states
+
+
+class ChatGLMModel(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+
+        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
+                                                config.hidden_size,
+                                                linear_method=linear_method)
+
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels
+        self.encoder = GLMTransformer(config, linear_method)
+
+        self.output_layer = ParallelLMHead(config.padded_vocab_size,
+                                           config.hidden_size,
+                                           linear_method=linear_method)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embedding(input_ids)
+
+        # Run encoder.
+        hidden_states = self.encoder(
+            hidden_states=inputs_embeds,
+            position_ids=position_ids,
+            kv_caches=kv_caches,
+            input_metadata=input_metadata,
+        )
+        return hidden_states
+
+
+class ChatGLMForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: ChatGLMConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config: ChatGLMConfig = config
+        self.linear_method = linear_method
+        self.transformer = ChatGLMModel(config, linear_method)
+        # self.lm_head_weight = self.transformer.output_layer.weight
+        self.sampler = Sampler(config.padded_vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(
+            self.transformer.output_layer(hidden_states), sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "rotary_pos_emb.inv_freq" in name:
+                continue
+            if "word_embeddings" in name:
+                name = name.replace(".word_embeddings", "")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

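GLMAttention.forward above splits the packed query_key_value output by q_size and kv_size, which differ when multi-query attention shrinks the number of KV heads. A toy illustration of that split on a single tensor-parallel rank, with made-up sizes:

import torch

num_heads, num_kv_heads, head_dim = 32, 2, 128       # multi-query layout
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim

qkv = torch.randn(10, q_size + 2 * kv_size)          # (num_tokens, packed dim)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape == (10, q_size) and k.shape == (10, kv_size)
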
+ 6 - 2
aphrodite/modeling/models/decilm.py

@@ -29,6 +29,7 @@ from typing import Optional
 import torch
 from transformers import PretrainedConfig

+from aphrodite.common.config import LoRAConfig
 from aphrodite.modeling.layers.linear import LinearMethodBase
 from aphrodite.modeling.models.llama import LlamaForCausalLM
 from aphrodite.modeling.hf_downloader import (default_weight_loader,
@@ -41,7 +42,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
     Based on the llama executor.

     The main difference is that DeciLM uses Variable Grouped Query Attention.
-    The constant number of GQA heads in the decoder is overriden with a value
+    The constant number of GQA heads in the decoder is overridden with a value
     per layer.

     Usually, in the HuggingFace implementation, instead of
@@ -57,10 +58,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
         self,
         config: Optional[PretrainedConfig] = None,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
         delattr(config, "num_key_value_heads_per_layer")
-        super().__init__(config=config, linear_method=linear_method)
+        super().__init__(config=config,
+                         linear_method=linear_method,
+                         lora_config=lora_config)
 
 
     def load_weights(self,
                      model_name_or_path: str,

+ 23 - 72
aphrodite/modeling/models/deepseek.py

@@ -28,14 +28,16 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig

-from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.megatron import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (
-    LinearMethodBase, MergedColumnParallelLinear, ReplicatedLinear,
-    QKVParallelLinear, RowParallelLinear, ColumnParallelLinear)
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              ReplicatedLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
@@ -63,23 +65,10 @@ class DeepseekMLP(nn.Module):
         reduce_results: bool = True,
         reduce_results: bool = True,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
-            self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(hidden_size,
-                                                  intermediate_size,
-                                                  bias=False,
-                                                  linear_method=linear_method)
-            self.up_proj = ColumnParallelLinear(hidden_size,
-                                                intermediate_size,
-                                                bias=False,
-                                                linear_method=linear_method)
-        else:
-            self.merge_weight = True
-            self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size, [intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
         self.down_proj = RowParallelLinear(intermediate_size,
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            hidden_size,
                                            bias=False,
                                            bias=False,
@@ -91,12 +80,7 @@ class DeepseekMLP(nn.Module):
         self.act_fn = SiluAndMul()
         self.act_fn = SiluAndMul()
 
 
     def forward(self, x):
     def forward(self, x):
-        if self.merge_weight:
-            gate_up, _ = self.gate_up_proj(x)
-        else:
-            up, _ = self.up_proj(x)
-            gate, _ = self.gate_proj(x)
-            gate_up = torch.cat([gate, up], dim=-1)
+        gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         x, _ = self.down_proj(x)
         return x
         return x
@@ -171,7 +155,6 @@ class DeepseekMoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (batch * sequence_length, n_experts)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
         router_logits, _ = self.gate(hidden_states)
-
         final_hidden_states = fused_moe(hidden_states,
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
                                         self.w1,
                                         self.w2,
                                         self.w2,
@@ -224,31 +207,14 @@ class DeepseekAttention(nn.Module):
         self.rope_theta = rope_theta
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
         self.max_position_embeddings = max_position_embeddings
 
 
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(hidden_size,
-                                               self.q_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                hidden_size,
-                self.head_dim,
-                self.total_num_heads,
-                self.total_num_kv_heads,
-                bias=False,
-                linear_method=linear_method,
-            )
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
 
 
         self.o_proj = RowParallelLinear(
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.total_num_heads * self.head_dim,
@@ -257,15 +223,12 @@ class DeepseekAttention(nn.Module):
             linear_method=linear_method,
             linear_method=linear_method,
         )
         )
 
 
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
         self.rotary_emb = get_rope(
             self.head_dim,
             self.head_dim,
             rotary_dim=self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             max_position=max_position_embeddings,
             base=rope_theta,
             base=rope_theta,
             rope_scaling=rope_scaling,
             rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
         )
         )
         self.attn = PagedAttention(self.num_heads,
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    self.head_dim,
@@ -279,14 +242,8 @@ class DeepseekAttention(nn.Module):
         kv_cache: KVCache,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
@@ -375,7 +332,6 @@ class DeepseekModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.vocab_size,
             config.hidden_size,
             config.hidden_size,
-            linear_method=linear_method,
         )
         )
         self.layers = nn.ModuleList([
         self.layers = nn.ModuleList([
             DeepseekDecoderLayer(config,
             DeepseekDecoderLayer(config,
@@ -414,9 +370,7 @@ class DeepseekForCausalLM(nn.Module):
         self.config = config
         self.config = config
         self.linear_method = linear_method
         self.linear_method = linear_method
         self.model = DeepseekModel(config, linear_method)
         self.model = DeepseekModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.sampler = Sampler(config.vocab_size)
         self.sampler = Sampler(config.vocab_size)
 
 
     def forward(
     def forward(
@@ -452,16 +406,13 @@ class DeepseekForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
             ("gate_up_proj", "up_proj", 1),
         ]
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
-            stacked_params_mapping = []
+
         params_dict = dict(self.named_parameters())
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
                 model_name_or_path,
                 cache_dir,
                 cache_dir,
                 load_format,
                 load_format,
                 revision,
                 revision,
-                self.config,
                 fall_back_to_pt=False):
                 fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
             if "rotary_emb.inv_freq" in name:
                 continue
                 continue

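The DeepseekMLP and DeepseekAttention changes above drop the unmerged gate_proj/up_proj and q/k/v branches in favour of a single merged projection whose output is split afterwards. A plain-PyTorch sketch of the gate/up case, with ordinary nn.Linear standing in for the parallel linear layers and toy tensor sizes:

import torch
import torch.nn.functional as F

hidden_size, intermediate_size, num_tokens = 16, 32, 4
x = torch.randn(num_tokens, hidden_size)

# Old unmerged path: two separate projections.
gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
reference = F.silu(gate_proj(x)) * up_proj(x)

# New merged path: stack the two weight matrices row-wise so one matmul
# produces [gate | up], then split and apply SiLU(gate) * up.
merged_weight = torch.cat([gate_proj.weight, up_proj.weight], dim=0)
gate_up = x @ merged_weight.t()
gate, up = gate_up.chunk(2, dim=-1)
merged = F.silu(gate) * up

torch.testing.assert_close(merged, reference)
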
+ 446 - 0
aphrodite/modeling/models/falcon.py

@@ -0,0 +1,446 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2023 the Falcon authors and HuggingFace Inc. team.  All rights
+# reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Falcon model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import LayerNorm
+from transformers import FalconConfig as HF_FalconConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.communication_op import (
+    tensor_model_parallel_all_reduce)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs import RWConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+FalconConfig = Union[HF_FalconConfig, RWConfig]
+
+
+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+                        dtype=torch.float32)
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32)
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1,
+                                    1 + 2 * num_remaining_heads,
+                                    2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+    return slopes
+
+
+class FalconAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.head_dim = self.hidden_size // self.total_num_heads
+        assert self.head_dim * self.total_num_heads == self.hidden_size
+
+        self.new_decoder_architecture = config.new_decoder_architecture
+        self.multi_query = config.multi_query
+
+        if self.new_decoder_architecture:
+            self.total_num_kv_heads = config.num_kv_heads
+        elif self.multi_query:
+            self.total_num_kv_heads = 1
+        else:
+            self.total_num_kv_heads = self.total_num_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.bias,
+            skip_bias_add=True,
+            linear_method=linear_method,
+        )
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+
+        # Layer-wise attention scaling
+        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self.reduce_row_parallel_results = not (config.new_decoder_architecture
+                                                or config.parallel_attn)
+        self.dense = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=config.bias,
+            skip_bias_add=True,
+            linear_method=linear_method,
+            reduce_results=self.reduce_row_parallel_results)
+
+        self.use_rotary = config.rotary
+        self.use_alibi = config.alibi
+        assert not (self.use_rotary and self.use_alibi), (
+            "Rotary and alibi are mutually exclusive.")
+
+        if self.use_rotary:
+            rope_theta = getattr(config, "rope_theta", 10000)
+            max_position_embeddings = getattr(config,
+                                              "max_position_embeddings", 8192)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+            )
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads)
+        elif self.use_alibi:
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
+                            self.inv_norm_factor)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads,
+                                       alibi_slopes=alibi_slopes)
+        else:
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       scale=self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, bias = self.query_key_value(hidden_states)
+        if bias is not None:
+            qkv += bias
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.use_rotary:
+            q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        attn_output, bias = self.dense(attn_output)
+        return attn_output, bias
+
+
+class FalconMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+
+        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
+                                                  4 * hidden_size,
+                                                  bias=config.bias,
+                                                  skip_bias_add=True,
+                                                  linear_method=linear_method)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.reduce_row_parallel_results = not (config.new_decoder_architecture
+                                                or config.parallel_attn)
+        self.dense_4h_to_h = RowParallelLinear(
+            4 * hidden_size,
+            hidden_size,
+            bias=config.bias,
+            skip_bias_add=True,
+            reduce_results=self.reduce_row_parallel_results,
+            linear_method=linear_method)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
+        x, bias = self.dense_h_to_4h(x)
+        if bias is not None:
+            x += bias
+        x = self.act(x)
+        x, bias = self.dense_4h_to_h(x)
+        return x, bias
+
+
+class FalconDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.self_attention = FalconAttention(config, linear_method)
+        self.mlp = FalconMLP(config, linear_method)
+        self.config = config
+
+        if config.new_decoder_architecture:
+            # The layer norm before self-attention
+            self.ln_attn = LayerNorm(hidden_size,
+                                     eps=config.layer_norm_epsilon)
+            # The layer norm before the MLP
+            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        else:
+            self.input_layernorm = LayerNorm(hidden_size,
+                                             eps=config.layer_norm_epsilon)
+            if not config.parallel_attn:
+                self.post_attention_layernorm = LayerNorm(
+                    hidden_size, eps=config.layer_norm_epsilon)
+
+        self.reduce_row_parallel_results = not (config.new_decoder_architecture
+                                                or config.parallel_attn)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        if self.config.new_decoder_architecture:
+            attention_layernorm_out = self.ln_attn(hidden_states)
+            mlp_layernorm_out = self.ln_mlp(hidden_states)
+        else:
+            attention_layernorm_out = self.input_layernorm(hidden_states)
+
+        # Self attention.
+        attention_output, attention_bias = self.self_attention(
+            positions=positions,
+            hidden_states=attention_layernorm_out,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        if self.reduce_row_parallel_results and attention_bias is not None:
+            attention_output += attention_bias
+
+        if not self.config.new_decoder_architecture:
+            if self.config.parallel_attn:
+                mlp_layernorm_out = attention_layernorm_out
+            else:
+                residual += attention_output
+                mlp_layernorm_out = self.post_attention_layernorm(residual)
+
+        # MLP.
+        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
+        if self.reduce_row_parallel_results and mlp_bias is not None:
+            mlp_output += mlp_bias
+
+        if not self.reduce_row_parallel_results:
+            # When MLP and Attention layers are parallel, we can use
+            # only one all-reduce operator to reduce the results from
+            # both MLP and Attention layers.
+            mlp_output += attention_output
+            mlp_output = tensor_model_parallel_all_reduce(mlp_output)
+            if attention_bias is not None:
+                mlp_output += attention_bias
+            if mlp_bias is not None:
+                mlp_output += mlp_bias
+
+        output = mlp_output + residual
+        return output
+
+
+class FalconModel(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.use_alibi = config.alibi
+
+        # Embedding + LN Embedding
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size, self.embed_dim, linear_method=linear_method)
+
+        # Transformer blocks
+        self.h = nn.ModuleList([
+            FalconDecoderLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+
+        # Final Layer Norm
+        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.word_embeddings(input_ids)
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+            )
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class FalconForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = FalconModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(
+            input_ids,
+            positions,
+            kv_caches,
+            input_metadata,
+        )
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        total_num_heads = self.config.num_attention_heads
+        if self.config.new_decoder_architecture:
+            total_num_kv_heads = self.config.num_kv_heads
+        elif self.config.multi_query:
+            total_num_kv_heads = 1
+        else:
+            total_num_kv_heads = total_num_heads
+        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            param = params_dict[name]
+            if "query_key_value" in name:
+                output_dim = getattr(param, "output_dim", None)
+                loaded_weight_shape = loaded_weight.shape
+                if output_dim is not None:
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] +
+                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
+                         -1) + loaded_weight_shape[output_dim + 1:])
+                    wq = loaded_weight.narrow(
+                        output_dim + 1, 0,
+                        num_query_heads_per_kv_head).reshape(
+                            *loaded_weight_shape[:output_dim], -1,
+                            *loaded_weight_shape[output_dim + 1:])
+                    wk = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    wv = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head + 1,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

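As a sanity check on the ALiBi helper added in falcon.py above, the standalone snippet below (the helper is restated so it runs on its own; only PyTorch is required) verifies that for a power-of-two head count the slopes reduce to the closed form 2**(-8*i/n):

import math
import torch

def alibi_slopes(total_num_heads: int) -> torch.Tensor:
    # Same computation as _get_alibi_slopes in the diff above.
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2,
                                    dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)],
                           dim=0)
    return slopes

num_heads = 8  # power of two, so only the first branch is exercised
expected = torch.tensor(
    [2**(-8 * i / num_heads) for i in range(1, num_heads + 1)])
torch.testing.assert_close(alibi_slopes(num_heads), expected)
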
+ 65 - 62
aphrodite/modeling/models/yi.py → aphrodite/modeling/models/gemma.py

@@ -1,14 +1,7 @@
 # coding=utf-8
 # coding=utf-8
-# Adapted from
-# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2023 The vLLM team.
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+# Copyright (c) Google Inc.
 #
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # you may not use this file except in compliance with the License.
@@ -21,15 +14,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
-"""Inference-only Yi model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional, Tuple
+"""Inference-only Gemma model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
 
 
 import torch
 import torch
 from torch import nn
 from torch import nn
-from aphrodite.transformers_utils.configs.yi import YiConfig
+from transformers import GemmaConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
@@ -51,13 +44,12 @@ from aphrodite.common.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
 
 
-class YiMLP(nn.Module):
+class GemmaMLP(nn.Module):
 
 
     def __init__(
     def __init__(
         self,
         self,
         hidden_size: int,
         hidden_size: int,
         intermediate_size: int,
         intermediate_size: int,
-        hidden_act: str,
         linear_method: Optional[LinearMethodBase] = None,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
@@ -82,10 +74,7 @@ class YiMLP(nn.Module):
                                            hidden_size,
                                            hidden_size,
                                            bias=False,
                                            bias=False,
                                            linear_method=linear_method)
                                            linear_method=linear_method)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
+        self.act_fn = GeluAndMul()
 
 
     def forward(self, x):
     def forward(self, x):
         if self.merge_weight:
         if self.merge_weight:
@@ -99,18 +88,16 @@ class YiMLP(nn.Module):
         return x
         return x
 
 
 
 
-class YiAttention(nn.Module):
+class GemmaAttention(nn.Module):
 
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 head_dim: int,
+                 max_position_embeddings: int = 8192,
+                 rope_theta: float = 10000,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
         super().__init__()
         super().__init__()
         self.hidden_size = hidden_size
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
         tp_size = get_tensor_model_parallel_world_size()
@@ -127,12 +114,11 @@ class YiAttention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        self.head_dim = head_dim
         self.q_size = self.num_heads * self.head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.rope_theta = rope_theta
-        self.max_position_embeddings = max_position_embeddings
 
 
         if linear_method is not None and not linear_method.quant_config.merge_weight(
         if linear_method is not None and not linear_method.quant_config.merge_weight(
         ):
         ):
@@ -165,15 +151,12 @@ class YiAttention(nn.Module):
             bias=False,
             bias=False,
             linear_method=linear_method,
             linear_method=linear_method,
         )
         )
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
         self.rotary_emb = get_rope(
             self.head_dim,
             self.head_dim,
             rotary_dim=self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             max_position=max_position_embeddings,
             base=self.rope_theta,
             base=self.rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
+            is_neox_style=True,
         )
         )
         self.attn = PagedAttention(self.num_heads,
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    self.head_dim,
@@ -202,36 +185,33 @@ class YiAttention(nn.Module):
         return output
         return output
 
 
 
 
-class YiDecoderLayer(nn.Module):
+class GemmaDecoderLayer(nn.Module):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        config: YiConfig,
+        config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
         self.hidden_size = config.hidden_size
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        self.self_attn = YiAttention(
+        self.self_attn = GemmaAttention(
             hidden_size=self.hidden_size,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
+            head_dim=config.head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            rope_theta=config.rope_theta,
             linear_method=linear_method,
             linear_method=linear_method,
         )
         )
-        self.mlp = YiMLP(
+        self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             intermediate_size=config.intermediate_size,
-            hidden_act=config.hidden_act,
             linear_method=linear_method,
             linear_method=linear_method,
         )
         )
-        self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
 
     def forward(
     def forward(
         self,
         self,
@@ -244,9 +224,10 @@ class YiDecoderLayer(nn.Module):
         # Self Attention
         # Self Attention
         if residual is None:
         if residual is None:
             residual = hidden_states
             residual = hidden_states
-            hidden_states = self.ln1(hidden_states)
+            hidden_states = self.input_layernorm(hidden_states)
         else:
         else:
-            hidden_states, residual = self.ln1(hidden_states, residual)
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
         hidden_states = self.self_attn(
             positions=positions,
             positions=positions,
             hidden_states=hidden_states,
             hidden_states=hidden_states,
@@ -255,27 +236,27 @@ class YiDecoderLayer(nn.Module):
         )
         )
 
 
         # Fully Connected
         # Fully Connected
-        hidden_states, residual = self.ln2(hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
         return hidden_states, residual
 
 
 
 
-class YiModel(nn.Module):
+class GemmaModel(nn.Module):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        config: YiConfig,
+        config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
         self.config = config
         self.config = config
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
+
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size,
                                                    config.hidden_size,
                                                    linear_method=linear_method)
                                                    linear_method=linear_method)
         self.layers = nn.ModuleList([
         self.layers = nn.ModuleList([
-            YiDecoderLayer(config, linear_method)
+            GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
             for _ in range(config.num_hidden_layers)
         ])
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -288,6 +269,9 @@ class YiModel(nn.Module):
         input_metadata: InputMetadata,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         hidden_states = self.embed_tokens(input_ids)
+        # Normalize the embedding by sqrt(hidden_size)
+        hidden_states *= self.config.hidden_size**0.5
+
         residual = None
         residual = None
         for i in range(len(self.layers)):
         for i in range(len(self.layers)):
             layer = self.layers[i]
             layer = self.layers[i]
@@ -302,22 +286,23 @@ class YiModel(nn.Module):
         return hidden_states
         return hidden_states
 
 
 
 
-class YiForCausalLM(nn.Module):
+class GemmaForCausalLM(nn.Module):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        config: YiConfig,
+        config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
         self.config = config
         self.config = config
         self.linear_method = linear_method
         self.linear_method = linear_method
-        self.model = YiModel(config, linear_method)
+        self.model = GemmaModel(config, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size,
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       config.hidden_size,
                                       linear_method=linear_method)
                                       linear_method=linear_method)
         self.sampler = Sampler(config.vocab_size)
         self.sampler = Sampler(config.vocab_size)
 
 
+    @torch.no_grad()
     def forward(
     def forward(
         self,
         self,
         input_ids: torch.Tensor,
         input_ids: torch.Tensor,
@@ -355,11 +340,19 @@ class YiForCausalLM(nn.Module):
         ):
         ):
             stacked_params_mapping = []
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         params_dict = dict(self.named_parameters())
+        loaded_params = set()
         for name, loaded_weight in hf_model_weights_iterator(
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision,
                 model_name_or_path, cache_dir, load_format, revision,
                 self.config):
                 self.config):
             if "rotary_emb.inv_freq" in name:
             if "rotary_emb.inv_freq" in name:
                 continue
                 continue
+            if "embed_tokens.weight" in name:
+                # Copy word embedding to lm_head
+                loaded_params.add("lm_head.weight")
+                lm_head_param = params_dict["lm_head.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                 if weight_name not in name:
                     continue
                     continue
@@ -372,10 +365,20 @@ class YiForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 weight_loader(param, loaded_weight, shard_id)
                 break
                 break
             else:
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                # Skip loading extra layer for lora models.
+                if "lm_head" in name:
                     continue
                     continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
                 param = params_dict[name]
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")

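Two Gemma-specific details in the hunk above are easy to miss: the forward pass scales embeddings by sqrt(hidden_size), and load_weights adds 1.0 to every norm weight because GemmaRMSNorm multiplies the normalized activation by (1 + weight) while the engine's RMSNorm multiplies by the weight alone. A minimal plain-PyTorch illustration with made-up tensors:

import torch

hidden_size, eps = 8, 1e-6
x = torch.randn(3, hidden_size)

def rms_normalize(t: torch.Tensor) -> torch.Tensor:
    return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)

checkpoint_w = torch.randn(hidden_size) * 0.1
reference = rms_normalize(x) * (1.0 + checkpoint_w)  # Gemma's definition

loaded_w = checkpoint_w + 1.0                  # the "+= 1.0" in load_weights
engine_out = rms_normalize(x) * loaded_w       # a plain RMSNorm forward
torch.testing.assert_close(engine_out, reference)

# Embedding normalization from GemmaModel.forward: scale by sqrt(hidden_size).
embeddings = torch.randn(3, hidden_size)
scaled = embeddings * hidden_size**0.5
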
+ 286 - 0
aphrodite/modeling/models/gpt2.py

@@ -0,0 +1,286 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GPT-2 model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GPT2Config
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class GPT2Attention(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        assert total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scale=self.scale)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        key_cache, value_cache = kv_cache
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
+                                input_metadata)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPT2MLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: GPT2Config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPT2Block(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPT2Attention(config, linear_method)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPT2MLP(inner_dim, config, linear_method)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+class GPT2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+        assert not config.scale_attn_by_inverse_layer_idx
+        assert not config.reorder_and_upcast_attn
+        self.embed_dim = config.hidden_size
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          linear_method=linear_method)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.h = nn.ModuleList([
+            GPT2Block(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPT2LMHeadModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = GPT2Model(config, linear_method)
+        # self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "lm_head.weight" in name:
+                # GPT-2 ties the weights of the embedding layer and the final
+                # linear layer.
+                continue
+            if "wte.weight" in name:
+                # Copy word embedding to lm_head
+                lm_head_param = params_dict["lm_head.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
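
A note on the transpose in load_weights above: HF's GPT-2 uses Conv1D modules whose weight is stored as (in_features, out_features), whereas nn.Linear stores (out_features, in_features). The following minimal sketch (plain PyTorch, with invented shapes; not part of the commit) shows why loaded_weight.t() produces the layout the parallel linear layers expect:

import torch

in_features, out_features = 4, 6          # toy sizes, illustration only
x = torch.randn(2, in_features)
conv1d_weight = torch.randn(in_features, out_features)   # HF Conv1D layout

# HF Conv1D computes x @ W with W of shape (in_features, out_features).
conv1d_out = x @ conv1d_weight

# nn.Linear keeps W as (out_features, in_features) and computes x @ W.T,
# so the checkpoint tensor has to be transposed before it is loaded.
linear = torch.nn.Linear(in_features, out_features, bias=False)
with torch.no_grad():
    linear.weight.copy_(conv1d_weight.t())
assert torch.allclose(conv1d_out, linear(x), atol=1e-6)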

+ 292 - 0
aphrodite/modeling/models/gpt_bigcode.py

@@ -0,0 +1,292 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2023 CTranslate2, and Michael Feil
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GPTBigCodeConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class GPTBigCodeAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        self.tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        assert total_num_heads % self.tensor_model_parallel_world_size == 0
+        self.num_heads = (total_num_heads //
+                          self.tensor_model_parallel_world_size)
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.multi_query = config.multi_query
+        if self.multi_query:
+            total_num_kv_heads = 1
+            self.num_kv_heads = 1
+        else:
+            total_num_kv_heads = total_num_heads
+            self.num_kv_heads = self.num_heads
+        self.kv_dim = self.head_dim * self.num_kv_heads
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            total_num_kv_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
+
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scale=self.scale,
+                                   num_kv_heads=self.num_kv_heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.split(
+            [
+                self.hidden_size // self.tensor_model_parallel_world_size,
+                self.kv_dim, self.kv_dim
+            ],
+            dim=-1,
+        )
+        key_cache, value_cache = kv_cache
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
+                                input_metadata)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPTBigMLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: GPTBigCodeConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPTBigCodeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPTBigCodeAttention(config, linear_method)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPTBigMLP(inner_dim, config, linear_method)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+class GPTBigCodeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+
+        self.embed_dim = config.hidden_size
+
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          linear_method=linear_method)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.h = nn.ModuleList([
+            GPTBigCodeBlock(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPTBigCodeForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = GPTBigCodeModel(config, linear_method)
+        # self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "lm_head.weight" in name:
+                continue
+            if "wte.weight" in name:
+                # Copy word embedding to lm_head
+                lm_head_param = params_dict["lm_head.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
+            if ".attn.bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
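
For GPT-BigCode, the split sizes used in GPTBigCodeAttention.forward come from multi-query attention: each tensor-parallel rank keeps its share of the query heads, while the single key/value head is replicated on every rank. A rough sketch of the arithmetic, using made-up numbers rather than a real config:

hidden_size = 6144                         # hypothetical values, illustration only
total_num_heads = 48
tp_size = 4                                # tensor-parallel world size
head_dim = hidden_size // total_num_heads  # 128

num_heads = total_num_heads // tp_size     # 12 query heads per rank
num_kv_heads = 1                           # multi_query: one shared K/V head
kv_dim = head_dim * num_kv_heads           # 128

q_width = hidden_size // tp_size           # 1536, the per-rank query slice
split_sizes = [q_width, kv_dim, kv_dim]    # mirrors qkv.split(...) above
assert q_width == num_heads * head_dim
print(split_sizes)                         # [1536, 128, 128]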

+ 8 - 13
aphrodite/modeling/models/gpt_j.py

@@ -166,8 +166,7 @@ class GPTJBlock(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        inner_dim = (4 * config.n_embd
-                     if config.n_inner is None else config.n_inner)
+        inner_dim = 4 * config.n_embd if config.n_inner is None else config.n_inner
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.attn = GPTJAttention(config, linear_method)
         self.mlp = GPTJMLP(inner_dim, config, linear_method)
@@ -202,11 +201,9 @@ class GPTJModel(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.n_embd
-        self.wte = VocabParallelEmbedding(
-            config.vocab_size,
-            self.embed_dim,
-            linear_method=linear_method,
-        )
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          linear_method=linear_method)
         self.h = nn.ModuleList(
             [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -243,12 +240,10 @@ class GPTJForCausalLM(nn.Module):
         self.linear_method = linear_method
         assert not config.tie_word_embeddings
         self.transformer = GPTJModel(config, linear_method)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.n_embd,
-            bias=True,
-            linear_method=linear_method,
-        )
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.n_embd,
+                                      bias=True,
+                                      linear_method=linear_method)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(

+ 6 - 10
aphrodite/modeling/models/gpt_neox.py

@@ -196,11 +196,9 @@ class GPTNeoXModel(nn.Module):
         super().__init__()
         self.config = config
 
-        self.embed_in = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            linear_method=linear_method,
-        )
+        self.embed_in = VocabParallelEmbedding(config.vocab_size,
+                                               config.hidden_size,
+                                               linear_method=linear_method)
         self.layers = nn.ModuleList([
             GPTNeoXLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
@@ -239,11 +237,9 @@ class GPTNeoXForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.gpt_neox = GPTNeoXModel(config, linear_method)
-        self.embed_out = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            linear_method=linear_method,
-        )
+        self.embed_out = ParallelLMHead(config.vocab_size,
+                                        config.hidden_size,
+                                        linear_method=linear_method)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(

+ 352 - 0
aphrodite/modeling/models/internlm2.py

@@ -0,0 +1,352 @@
+# -*- coding: utf-8 -*-
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              ColumnParallelLinear,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class InternLM2MLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.w1 = ColumnParallelLinear(hidden_size,
+                                           intermediate_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+            self.w3 = ColumnParallelLinear(hidden_size,
+                                           intermediate_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
+        self.w2 = RowParallelLinear(intermediate_size,
+                                    hidden_size,
+                                    bias=False,
+                                    linear_method=linear_method)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            gate, _ = self.w1(x)
+            up, _ = self.w3(x)
+            gate_up = torch.cat([gate, up], dim=-1)
+        x = self.act_fn(gate_up)
+        x, _ = self.w2(x)
+        return x
+
+
+class InternLM2Attention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.wqkv = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.wo = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.wqkv(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.wo(attn_output)
+        return output
+
+
+class InternLMDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.attention = InternLM2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            linear_method=linear_method,
+        )
+        self.feed_forward = InternLM2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            linear_method=linear_method,
+        )
+        self.attention_norm = RMSNorm(config.hidden_size,
+                                      eps=config.rms_norm_eps)
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.attention_norm(hidden_states)
+        else:
+            hidden_states, residual = self.attention_norm(
+                hidden_states, residual)
+        hidden_states = self.attention(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.ffn_norm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class InternLM2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.tok_embeddings = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            linear_method=linear_method,
+        )
+        self.layers = nn.ModuleList([
+            InternLMDecoderLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.tok_embeddings(input_ids)
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class InternLM2ForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = InternLM2Model(config, linear_method)
+        self.output = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            linear_method=linear_method,
+        )
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.output(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "w1", 0),
+            ("gate_up_proj", "w3", 1),
+        ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                if "wqkv" in name:
+                    config = self.config
+                    kv_groups = config.num_attention_heads // config.num_key_value_heads
+                    head_dim = config.hidden_size // config.num_attention_heads
+                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
+                                                       head_dim,
+                                                       loaded_weight.shape[-1])
+                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
+                                             dim=1)
+                    wq = wq.reshape(-1, wq.shape[-1])
+                    wk = wk.reshape(-1, wk.shape[-1])
+                    wv = wv.reshape(-1, wv.shape[-1])
+                    weight_loader = param.weight_loader
+                    weight_loader(param, wq, 'q')
+                    weight_loader(param, wk, 'k')
+                    weight_loader(param, wv, 'v')
+                else:
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
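
The wqkv branch of InternLM2ForCausalLM.load_weights assumes the fused checkpoint tensor is laid out as repeated groups of (kv_groups query heads, 1 key head, 1 value head), one group per KV head. A standalone sketch of that reshape-and-split, using small invented dimensions rather than a real InternLM2 config:

import torch

hidden_size = 16                 # toy dimensions, illustration only
num_attention_heads = 4
num_key_value_heads = 2
head_dim = hidden_size // num_attention_heads            # 4
kv_groups = num_attention_heads // num_key_value_heads   # 2 query heads per KV head

rows = num_key_value_heads * (kv_groups + 2) * head_dim  # fused q/k/v rows
wqkv = torch.randn(rows, hidden_size)

# Same view/split as in load_weights above.
w = wqkv.view(-1, 2 + kv_groups, head_dim, wqkv.shape[-1])
wq, wk, wv = torch.split(w, [kv_groups, 1, 1], dim=1)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
assert wq.shape == (num_attention_heads * head_dim, hidden_size)
assert wk.shape == (num_key_value_heads * head_dim, hidden_size)
assert wv.shape == wk.shape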

+ 45 - 14
aphrodite/modeling/models/llama.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -111,6 +110,8 @@ class LlamaAttention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         linear_method: Optional[LinearMethodBase] = None,
+        bias: bool = False,
+        sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -140,15 +141,15 @@
             self.merge_weight = False
             self.q_proj = ColumnParallelLinear(hidden_size,
                                                self.q_size,
-                                               bias=False,
+                                               bias=bias,
                                                linear_method=linear_method)
             self.k_proj = ColumnParallelLinear(hidden_size,
                                                self.kv_size,
-                                               bias=False,
+                                               bias=bias,
                                                linear_method=linear_method)
             self.v_proj = ColumnParallelLinear(hidden_size,
                                                self.kv_size,
-                                               bias=False,
+                                               bias=bias,
                                                linear_method=linear_method)
         else:
             self.merge_weight = True
@@ -157,13 +158,13 @@
                 self.head_dim,
                 self.total_num_heads,
                 self.total_num_kv_heads,
-                bias=False,
+                bias=bias,
                 linear_method=linear_method,
             )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
-            bias=False,
+            bias=bias,
             linear_method=linear_method,
         )
 
@@ -180,7 +181,8 @@
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    self.scaling,
-                                   num_kv_heads=self.num_kv_heads)
+                                   num_kv_heads=self.num_kv_heads,
+                                   sliding_window=sliding_window)
 
     def forward(
         self,
@@ -217,14 +219,18 @@
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
+        sliding_window = getattr(config, "sliding_window", None)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
+            num_kv_heads=getattr(config, "num_key_value_heads",
+                                 config.num_attention_heads),
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
+            bias=getattr(config, "bias", False),
+            sliding_window=sliding_window,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
@@ -316,7 +322,32 @@ class LlamaModel(nn.Module):
 
 
 class LlamaForCausalLM(nn.Module):
-    supports_lora = True
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
 
     def __init__(
         self,
@@ -328,20 +359,20 @@ class LlamaForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = LlamaModel(config, linear_method, lora_config=lora_config)
-        unpadded_vocab_size = config.vocab_size
+        self.unpadded_vocab_size = config.vocab_size
         if lora_config:
-            unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.lm_head = ParallelLMHead(
-            unpadded_vocab_size,
+            self.unpadded_vocab_size,
             config.hidden_size,
-            linear_method=linear_method,
             org_num_embeddings=config.vocab_size,
+            linear_method=linear_method,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
             # We need bigger padding if using lora for kernel
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
+        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
 
     def forward(
         self,
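
The getattr fallbacks added in LlamaDecoderLayer are what let Llama-family configs that predate some fields still build: a missing num_key_value_heads degrades to ordinary multi-head attention, bias defaults to off, and sliding_window to unlimited. A tiny illustration with a stand-in config object (SimpleNamespace here is only a hypothetical substitute for a HF PretrainedConfig):

from types import SimpleNamespace

old_style_config = SimpleNamespace(num_attention_heads=32, hidden_size=4096)

num_kv_heads = getattr(old_style_config, "num_key_value_heads",
                       old_style_config.num_attention_heads)
bias = getattr(old_style_config, "bias", False)
sliding_window = getattr(old_style_config, "sliding_window", None)

print(num_kv_heads, bias, sliding_window)   # 32 False None -> plain MHA, no bias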

+ 62 - 18
aphrodite/modeling/models/mixtral.py

@@ -25,23 +25,23 @@
 from typing import List, Optional, Tuple
 
 import torch
-
 from torch import nn
 from transformers import MixtralConfig
 
+from aphrodite.common.config import LoRAConfig
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              ReplicatedLinear,
                                               QKVParallelLinear,
+                                              ReplicatedLinear,
                                               RowParallelLinear,
                                               ColumnParallelLinear)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
 from aphrodite.modeling.megatron.communication_op import (
     tensor_model_parallel_all_reduce)
 from aphrodite.modeling.megatron.parallel_state import (
@@ -58,6 +58,7 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
 class MixtralMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
     across all ranks.
+
     Each expert's weights are sharded across all ranks and a fused MoE
     kernel is used for the forward pass, and finally we reduce the outputs
     across ranks.
@@ -70,13 +71,14 @@ class MixtralMoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__()
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
         self.num_total_experts = num_experts
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // tp_size
+        self.intermediate_size = intermediate_size // self.tp_size
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -127,7 +129,6 @@ class MixtralMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
         final_hidden_states = fused_moe(hidden_states,
                                         self.ws,
                                         self.w2s,
@@ -136,8 +137,9 @@ class MixtralMoE(nn.Module):
                                         renormalize=True,
                                         inplace=True)
 
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
         return final_hidden_states.view(batch_size, sequence_length,
                                         hidden_size)
@@ -310,15 +312,20 @@ class MixtralModel(nn.Module):
         self,
         config: MixtralConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
+            self.vocab_size,
             config.hidden_size,
             linear_method=linear_method,
+            org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
             MixtralDecoderLayer(config, linear_method=linear_method)
@@ -345,20 +352,53 @@ class MixtralModel(nn.Module):
 
 
 class MixtralForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
 
     def __init__(
         self,
         config: MixtralConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = MixtralModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.sampler = Sampler(config.vocab_size)
+        self.model = MixtralModel(config,
+                                  linear_method,
+                                  lora_config=lora_config)
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            linear_method=linear_method,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
 
     def forward(
         self,
@@ -391,6 +431,10 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+
         expert_params_mapping = [
             # (param_name, weight_name, expert_id)
             ("ws" if weight_name in ["w1", "w3"] else "w2s",
@@ -398,18 +442,18 @@ class MixtralForCausalLM(nn.Module):
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
-            stacked_params_mapping = []
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
                 cache_dir,
                 load_format,
                 revision,
+                self.config,
                 fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+ 7 - 6
aphrodite/modeling/models/mixtral_quant.py

@@ -226,14 +226,12 @@ class MixtralAttention(nn.Module):
             bias=False,
             bias=False,
             linear_method=linear_method,
         )
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
             base=int(self.rope_theta),
-            is_neox_style=is_neox_style,
+            is_neox_style=True,
         )
         self.attn = PagedAttention(
             self.num_heads,
@@ -371,9 +369,11 @@ class MixtralForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = MixtralModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            linear_method=linear_method,
+        )
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -410,6 +410,7 @@ class MixtralForCausalLM(nn.Module):
         if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
         ):
             stacked_params_mapping = []
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
+ 307 - 0
aphrodite/modeling/models/mpt.py

@@ -0,0 +1,307 @@
+# coding=utf-8
+# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs.mpt import MPTConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _get_alibi_slopes(
+    total_num_heads: int,
+    alibi_bias_max: int,
+) -> torch.Tensor:
+    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
+    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
+    m = m.mul(alibi_bias_max / next_power_of_2)
+    slopes = 1.0 / torch.pow(2, m)
+    if next_power_of_2 != total_num_heads:
+        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
+    return slopes
+
+
+class MPTAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: MPTConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.total_num_heads = config.n_heads
+        self.head_dim = self.d_model // self.total_num_heads
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        if "kv_n_heads" in config.attn_config:
+            self.total_num_kv_heads = config.attn_config['kv_n_heads']
+        else:
+            self.total_num_kv_heads = self.total_num_heads
+        assert not config.attn_config["prefix_lm"]
+        assert config.attn_config["alibi"]
+
+        # pylint: disable=invalid-name
+        self.Wqkv = QKVParallelLinear(
+            self.d_model,
+            self.d_model // self.total_num_heads,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=not config.no_bias,
+            linear_method=linear_method,
+        )
+        if self.qk_ln:
+            self.q_ln = nn.LayerNorm(self.d_model)
+            self.k_ln = nn.LayerNorm(self.d_model)
+        self.out_proj = RowParallelLinear(
+            self.d_model,
+            self.d_model,
+            bias=not config.no_bias,
+            linear_method=linear_method,
+        )
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        # Create the alibi slopes and slice them.
+        tp_rank = get_tensor_model_parallel_rank()
+        head_start = tp_rank * self.num_heads
+        head_end = (tp_rank + 1) * self.num_heads
+        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
+                                         self.alibi_bias_max)
+        alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+        self.head_dim = self.d_model // self.total_num_heads
+        scaling = self.head_dim**-0.5
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes,
+                                   num_kv_heads=self.num_kv_heads)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        del position_ids  # unused.
+        qkv, _ = self.Wqkv(hidden_states)
+        if self.clip_qkv is not None:
+            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.qk_ln:
+            q = self.q_ln(q)
+            k = self.k_ln(k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.out_proj(attn_output)
+        return output
+
+
+class MPTMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: MPTConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.d_model
+        expansion_ratio = config.expansion_ratio
+        intermediate_size = expansion_ratio * hidden_size
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=not config.no_bias,
+            linear_method=linear_method,
+        )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, intermediate_size)
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=not config.no_bias,
+            linear_method=linear_method,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, _ = self.up_proj(x)
+        x = self.act(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class MPTBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: MPTConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.d_model
+        self.norm_1 = nn.LayerNorm(hidden_size)
+        self.attn = MPTAttention(config, linear_method)
+        self.norm_2 = nn.LayerNorm(hidden_size)
+        self.ffn = MPTMLP(config, linear_method)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        x = self.norm_1(hidden_states)
+        x = self.attn(
+            position_ids=position_ids,
+            hidden_states=x,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        hidden_states = hidden_states + x
+        x = self.norm_2(hidden_states)
+        x = self.ffn(x)
+        hidden_states = hidden_states + x
+        return hidden_states
+
+
+class MPTModel(nn.Module):
+
+    def __init__(
+        self,
+        config: MPTConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        assert config.embedding_fraction == 1.0
+        assert config.norm_type == "low_precision_layernorm"
+
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          config.d_model,
+                                          linear_method=linear_method)
+        self.blocks = nn.ModuleList(
+            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
+        self.norm_f = nn.LayerNorm(config.d_model)
+        if config.no_bias:
+            for module in self.modules():
+                if hasattr(module, "bias") and isinstance(
+                        module.bias, nn.Parameter):
+                    # Remove the bias term in Linear and LayerNorm.
+                    module.register_parameter("bias", None)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.wte(input_ids)
+        for i in range(len(self.blocks)):
+            block = self.blocks[i]
+            hidden_states = block(
+                position_ids,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+            )
+        hidden_states = self.norm_f(hidden_states)
+        return hidden_states
+
+
+class MPTForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: MPTConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        assert config.tie_word_embeddings
+        self.linear_method = linear_method
+
+        self.transformer = MPTModel(config, linear_method)
+        # self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if "wte.weight" in name:
+                # Copy word embedding to lm_head
+                lm_head_param = params_dict["lm_head.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

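For reference, the ALiBi slope construction used by `_get_alibi_slopes` in the MPT attention code above can be checked in isolation. The snippet below is a minimal standalone sketch that mirrors the helper added in this diff; it only assumes PyTorch, and the function name is illustrative rather than part of the codebase.

import math

import torch


def alibi_slopes(total_num_heads: int, alibi_bias_max: int = 8) -> torch.Tensor:
    # Geometric slopes for the nearest power-of-two head count; for head
    # counts that are not a power of two, take an interleaved subset so the
    # slopes stay spread across the same range.
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


print(alibi_slopes(8))  # 1/2, 1/4, ..., 1/256 for a power-of-two head count
print(alibi_slopes(6))  # interleaved subset of the 8-head slopes

Each tensor-parallel rank then keeps only its own slice of these slopes before handing them to PagedAttention, which is what the `head_start:head_end` indexing in `MPTAttention.__init__` does.
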
+ 377 - 0
aphrodite/modeling/models/olmo.py

@@ -0,0 +1,377 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
+# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+#
+# BSD 3-Clause License
+#
+# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Inference-only OLMo model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (
+    ColumnParallelLinear,
+    LinearMethodBase,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size, )
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs.olmo import OLMoConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class SwiGLU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        return F.silu(gate) * x
+
+    @property
+    def output_multiplier(self) -> float:
+        return 0.5
+
+
+class OlmoAttention(nn.Module):
+    """
+    This is the attention block where the output is computed as
+    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(
+        self,
+        config: OLMoConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.d_model
+        assert config.d_model % config.n_heads == 0
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
+        )
+        self.total_num_heads = self.config.n_heads
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // self.total_num_heads
+
+        # Layer norms.
+        self.attn_norm = nn.LayerNorm(config.d_model,
+                                      elementwise_affine=False,
+                                      bias=False)
+        # Attention input projection. Projects x -> (q, k, v)
+        self.att_proj = QKVParallelLinear(
+            config.d_model,
+            self.head_dim,
+            self.total_num_heads,
+            bias=config.include_bias,
+            linear_method=linear_method,
+        )
+
+        # Rotary embeddings.
+        if self.config.rope:
+            rope_theta = getattr(config, "rope_theta", 10000)
+            max_position_embeddings = getattr(config,
+                                              "max_position_embeddings", 8192)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+            )
+        self.scaling = self.head_dim**-0.5
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scale=self.scaling)
+
+        # Attention output projection.
+        self.attn_out = RowParallelLinear(
+            config.d_model,
+            config.d_model,
+            bias=config.include_bias,
+            linear_method=linear_method,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.attn_norm(hidden_states)
+        qkv, _ = self.att_proj(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.config.rope:
+            q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.attn_out(attn_output)
+        return output
+
+
+class OlmoMLP(nn.Module):
+    """
+    This is the MLP block where the output is computed as ``MLP(LN(x))``
+    in ``MLP(LN(x + Attention(LN(x))))`` (plus another skip connection).
+    """
+
+    def __init__(
+        self,
+        config: OLMoConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
+                            is not None else config.mlp_ratio * config.d_model)
+
+        # Layer norms.
+        self.ff_norm = nn.LayerNorm(config.d_model,
+                                    elementwise_affine=False,
+                                    bias=False)
+
+        # Feed-forward input projection.
+        self.ff_proj = ColumnParallelLinear(
+            config.d_model,
+            self.hidden_size,
+            bias=config.include_bias,
+            linear_method=linear_method,
+        )
+
+        # Activation function.
+        # self.act = SiluAndMul()
+        # self.act.output_multiplier = 0.5
+        self.act = SwiGLU()
+        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+
+        # Feed-forward output projection.
+        self.ff_out = RowParallelLinear(
+            int(self.act.output_multiplier * self.hidden_size),
+            config.d_model,
+            bias=config.include_bias,
+            linear_method=linear_method,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        # Add feed-forward projection.
+        # shape: (batch_size, seq_len, d_model)
+        og_x = x
+        x = self.ff_norm(x)
+        x, _ = self.ff_proj(x)
+        x = self.act(x)
+        x, _ = self.ff_out(x)
+        x = og_x + x
+
+        return x
+
+
+class OlmoBlock(nn.Module):
+    """
+    This is a typical transformer block where the output is computed as
+    ``MLP(LN(x + Attention(LN(x))))`` (plus another skip connection).
+    """
+
+    def __init__(self,
+                 config: OLMoConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        # Attention block.
+        self.attn = OlmoAttention(config, linear_method)
+
+        # MLP block.
+        self.mlp = OlmoMLP(config, linear_method)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        # Attention block.
+        og_x = hidden_states
+        x = self.attn(positions, hidden_states, kv_cache, input_metadata)
+        x = x + og_x
+
+        # MLP block.
+        hidden_states = self.mlp(x)
+        return hidden_states
+
+
+class OlmoModel(nn.Module):
+
+    def __init__(self,
+                 config: OLMoConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+
+        self.transformer = nn.ModuleDict(
+            dict(wte=VocabParallelEmbedding(
+                config.embedding_size or config.vocab_size,
+                config.d_model,
+                linear_method=linear_method,
+            ),
+                 ln_f=nn.LayerNorm(config.d_model,
+                                   elementwise_affine=False,
+                                   bias=False),
+                 ff_out=ParallelLMHead(
+                     config.embedding_size or config.vocab_size,
+                     config.d_model,
+                     bias=config.include_bias,
+                     linear_method=linear_method,
+                 )))
+
+        blocks = [
+            OlmoBlock(config, linear_method) for i in range(config.n_layers)
+        ]
+        if self.config.block_group_size > 1:
+            raise NotImplementedError("Block group size > 1 not supported yet")
+        else:
+            self.transformer.update({"blocks": nn.ModuleList(blocks)})
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """
+        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+        """
+        # Get embeddings of input.
+        # shape: (batch_size, seq_len, d_model)
+        x = self.transformer.wte(input_ids)  # type: ignore
+
+        # Apply blocks one-by-one.
+        for block_idx, block in enumerate(self.transformer.blocks):
+            # shape: (batch_size, seq_len, d_model)
+            x = block(
+                positions,
+                x,
+                kv_caches[block_idx],
+                input_metadata,
+            )
+
+        # Apply final layer norm.
+        # shape: (batch_size, seq_len or 1, d_model)
+        x = self.transformer.ln_f(x)  # type: ignore
+        return x
+
+
+class OLMoForCausalLM(nn.Module):
+    """
+    Extremely barebones HF model wrapper.
+    """
+
+    def __init__(self,
+                 config: OLMoConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = OlmoModel(config, linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            input_metadata=input_metadata,
+        )
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(
+            self.model.transformer.ff_out(hidden_states), sampling_metadata)
+        return next_tokens
+
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision):
+            if "wte.weight" in name and self.config.weight_tying:
+                # Copy word embedding to lm_head
+                lm_head_param = params_dict["model.transformer.ff_out.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
+            # attention
+            if ".att" in name:
+                name = name.replace(".att", ".attn.att")
+            # mlp
+            if ".ff" in name and "transformer.ff_out" not in name:
+                name = name.replace(".ff", ".mlp.ff")
+            # there is no bias in olmo
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

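The `SwiGLU` module above halves the feature dimension (hence `output_multiplier = 0.5`), which is why `ff_out` in `OlmoMLP` takes `int(0.5 * hidden_size)` input features. The following is a minimal sketch of that shape bookkeeping with toy sizes, assuming only PyTorch; the layer sizes are made up for illustration.

import torch
import torch.nn.functional as F


def swiglu(x: torch.Tensor) -> torch.Tensor:
    # Split the fused projection into (value, gate) halves and gate the
    # value with SiLU, halving the last dimension.
    x, gate = x.chunk(2, dim=-1)
    return F.silu(gate) * x


d_model, mlp_hidden_size = 16, 64
ff_proj = torch.nn.Linear(d_model, mlp_hidden_size)      # analogous to ff_proj
ff_out = torch.nn.Linear(mlp_hidden_size // 2, d_model)  # 0.5 * hidden size

x = torch.randn(2, 5, d_model)
y = ff_out(swiglu(ff_proj(x)))
print(y.shape)  # torch.Size([2, 5, 16])
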
+ 0 - 1
aphrodite/modeling/models/phi.py

@@ -287,7 +287,6 @@ class PhiForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
     ) -> Optional[SamplerOutput]:
-        head = self.lm_head  # pylint: disable=unused-variable
         next_tokens = self.sampler(self.lm_head(hidden_states),
         next_tokens = self.sampler(self.lm_head(hidden_states),
                                    sampling_metadata)
                                    sampling_metadata)
         return next_tokens
         return next_tokens

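Both the MPT attention above and the StableLM attention later in this diff size their per-rank KV heads with the same rule: partition the KV heads across tensor-parallel ranks when there are at least as many KV heads as ranks, otherwise replicate each KV head so every rank still owns one. A minimal standalone sketch of that rule follows; the function name and example values are illustrative only.

def kv_heads_per_rank(total_num_kv_heads: int, tp_world_size: int) -> int:
    # Partition KV heads across ranks when possible, otherwise replicate
    # KV heads so that every rank still owns at least one.
    if total_num_kv_heads >= tp_world_size:
        assert total_num_kv_heads % tp_world_size == 0
    else:
        assert tp_world_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_world_size)


print(kv_heads_per_rank(8, 2))  # 4 KV heads per rank (partitioned)
print(kv_heads_per_rank(2, 8))  # 1 KV head per rank (replicated)
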
+ 315 - 0
aphrodite/modeling/models/qwen.py

@@ -0,0 +1,315 @@
+# coding=utf-8
+# Adapted from
+# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
+# Copyright (c) Alibaba Cloud.
+# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
+"""Inference-only QWen model compatible with HuggingFace weights."""
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear,
+                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs.qwen import QWenConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class QWenMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str = "silu",
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.w2 = ColumnParallelLinear(hidden_size,
+                                           intermediate_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+            self.w1 = ColumnParallelLinear(hidden_size,
+                                           intermediate_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
+        self.c_proj = RowParallelLinear(intermediate_size,
+                                        hidden_size,
+                                        bias=False,
+                                        linear_method=linear_method)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.w1(x)
+            gate, _ = self.w2(x)
+            gate_up = torch.cat([gate, up], dim=-1)
+        x = self.act_fn(gate_up)
+        x, _ = self.c_proj(x)
+        return x
+
+
+class QWenAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        max_position_embeddings: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
+        )
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.c_attn = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.c_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.scaling = self.head_dim**-0.5
+
+        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+        ) is None else linear_method.quant_config.rope_style()
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+
+        output, _ = self.c_proj(attn_output)
+        return output
+
+
+class QWenBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: QWenConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        self.attn = QWenAttention(config.hidden_size,
+                                  config.num_attention_heads,
+                                  config.max_position_embeddings,
+                                  rope_theta=rope_theta,
+                                  rope_scaling=rope_scaling,
+                                  linear_method=linear_method)
+
+        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+        self.mlp = QWenMLP(config.hidden_size,
+                           config.intermediate_size // 2,
+                           linear_method=linear_method)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.ln_1(hidden_states)
+        else:
+            hidden_states, residual = self.ln_1(hidden_states, residual)
+        hidden_states = self.attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.ln_2(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class QWenModel(nn.Module):
+
+    def __init__(
+        self,
+        config: QWenConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          config.hidden_size,
+                                          linear_method=linear_method)
+        self.h = nn.ModuleList([
+            QWenBlock(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.wte(input_ids)
+        residual = None
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.ln_f(hidden_states, residual)
+        return hidden_states
+
+
+class QWenLMHeadModel(nn.Module):
+
+    def __init__(
+        self,
+        config: QWenConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = QWenModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "w2", 0),
+            ("gate_up_proj", "w1", 1),
+        ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

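The `load_weights` loop above uses a for/else pattern: checkpoint tensors whose name matches an entry of `stacked_params_mapping` are loaded as one shard of the fused `gate_up_proj` parameter, while everything else falls through to the plain per-parameter loader. The shard order matters because `SiluAndMul` gates with the first half of its input, so `w2` (whose output goes through SiLU) maps to shard 0 and `w1` to shard 1, matching the unmerged path that concatenates `[gate, up]`. The sketch below shows just the name routing; the `route` helper and the example tensor names are hypothetical, for illustration only.

# Hypothetical helper mirroring the name routing in QWenLMHeadModel.load_weights.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "w2", 0),  # w2 output is passed through SiLU (the gate)
    ("gate_up_proj", "w1", 1),  # w1 output is the value being gated
]


def route(name: str) -> str:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            fused = name.replace(weight_name, param_name)
            return f"{fused}  <- loaded as shard {shard_id}"
    return f"{name}  <- loaded directly"


print(route("transformer.h.0.mlp.w2.weight"))
print(route("transformer.h.0.mlp.w1.weight"))
print(route("transformer.h.0.attn.c_attn.weight"))
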
+ 35 - 52
aphrodite/modeling/models/mistral.py → aphrodite/modeling/models/qwen2.py

@@ -1,7 +1,8 @@
 # coding=utf-8
 # coding=utf-8
 # Adapted from
 # Adapted from
-# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
 # Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The PygmalionAI team.
+# Copyright 2024 The Qwen team.
 # Copyright 2023 The vLLM team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
 #
@@ -21,38 +22,37 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
-"""Inference-only Mistral model compatible with HuggingFace weights."""
+"""Inference-only Qwen2 model compatible with HuggingFace weights."""
 from typing import List, Optional, Tuple
 from typing import List, Optional, Tuple
 
 
 import torch
 import torch
 from torch import nn
 from torch import nn
-from transformers import MistralConfig
+from transformers import Qwen2Config
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               QKVParallelLinear,
-                                              RowParallelLinear,
-                                              ColumnParallelLinear)
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
+    VocabParallelEmbedding, ParallelLMHead)
 from aphrodite.modeling.megatron.parallel_state import (
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_world_size)
     get_tensor_model_parallel_world_size)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.hf_downloader import (default_weight_loader,
 from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                               hf_model_weights_iterator)
                                               hf_model_weights_iterator)
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.common.sequence import SamplerOutput
-from aphrodite.common.config import LoRAConfig
 
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
 
 
-class MistralMLP(nn.Module):
+class Qwen2MLP(nn.Module):
 
 
     def __init__(
     def __init__(
         self,
         self,
@@ -100,7 +100,7 @@ class MistralMLP(nn.Module):
         return x
         return x
 
 
 
 
-class MistralAttention(nn.Module):
+class Qwen2Attention(nn.Module):
 
 
     def __init__(self,
     def __init__(self,
                  hidden_size: int,
                  hidden_size: int,
@@ -108,6 +108,7 @@ class MistralAttention(nn.Module):
                  num_kv_heads: int,
                  num_kv_heads: int,
                  max_position: int = 4096 * 32,
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
                  rope_theta: float = 10000,
+                 use_sliding_window: bool = False,
                  linear_method: Optional[LinearMethodBase] = None,
                  linear_method: Optional[LinearMethodBase] = None,
                  sliding_window: Optional[int] = None) -> None:
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
         super().__init__()
@@ -131,22 +132,22 @@ class MistralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
+        self.sliding_window = sliding_window if use_sliding_window else None
 
 
         if linear_method is not None and not linear_method.quant_config.merge_weight(
         if linear_method is not None and not linear_method.quant_config.merge_weight(
         ):
         ):
             self.merge_weight = False
             self.merge_weight = False
             self.q_proj = ColumnParallelLinear(hidden_size,
             self.q_proj = ColumnParallelLinear(hidden_size,
                                                self.q_size,
                                                self.q_size,
-                                               bias=False,
+                                               bias=True,
                                                linear_method=linear_method)
                                                linear_method=linear_method)
             self.k_proj = ColumnParallelLinear(hidden_size,
             self.k_proj = ColumnParallelLinear(hidden_size,
                                                self.kv_size,
                                                self.kv_size,
-                                               bias=False,
+                                               bias=True,
                                                linear_method=linear_method)
                                                linear_method=linear_method)
             self.v_proj = ColumnParallelLinear(hidden_size,
             self.v_proj = ColumnParallelLinear(hidden_size,
                                                self.kv_size,
                                                self.kv_size,
-                                               bias=False,
+                                               bias=True,
                                                linear_method=linear_method)
                                                linear_method=linear_method)
         else:
         else:
             self.merge_weight = True
             self.merge_weight = True
@@ -155,7 +156,7 @@ class MistralAttention(nn.Module):
                 self.head_dim,
                 self.head_dim,
                 self.total_num_heads,
                 self.total_num_heads,
                 self.total_num_kv_heads,
                 self.total_num_kv_heads,
-                bias=False,
+                bias=True,
                 linear_method=linear_method,
                 linear_method=linear_method,
             )
             )
         self.o_proj = RowParallelLinear(
         self.o_proj = RowParallelLinear(
@@ -165,14 +166,11 @@ class MistralAttention(nn.Module):
             linear_method=linear_method,
             linear_method=linear_method,
         )
         )
 
 
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
         self.rotary_emb = get_rope(
             self.head_dim,
             self.head_dim,
             rotary_dim=self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
             max_position=max_position,
             base=self.rope_theta,
             base=self.rope_theta,
-            is_neox_style=is_neox_style,
         )
         )
         self.attn = PagedAttention(self.num_heads,
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    self.head_dim,
@@ -202,26 +200,29 @@ class MistralAttention(nn.Module):
         return output
         return output
 
 
 
 
-class MistralDecoderLayer(nn.Module):
+class Qwen2DecoderLayer(nn.Module):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        config: MistralConfig,
+        config: Qwen2Config,
+        layer_idx: int,
         linear_method: Optional[LinearMethodBase] = None,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
         self.hidden_size = config.hidden_size
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         # Requires transformers > 4.32.0
-        rope_theta = getattr(config, "rope_theta", 10000)
-        self.self_attn = MistralAttention(
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        use_sliding_window = config.use_sliding_window and layer_idx < config.max_window_layers
+        self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             rope_theta=rope_theta,
+            use_sliding_window=use_sliding_window,
             linear_method=linear_method,
             linear_method=linear_method,
             sliding_window=config.sliding_window)
             sliding_window=config.sliding_window)
-        self.mlp = MistralMLP(
+        self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             hidden_act=config.hidden_act,
@@ -261,31 +262,26 @@ class MistralDecoderLayer(nn.Module):
         return hidden_states, residual
         return hidden_states, residual
 
 
 
 
-class MistralModel(nn.Module):
+class Qwen2Model(nn.Module):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        config: MistralConfig,
+        config: Qwen2Config,
         linear_method: Optional[LinearMethodBase] = None,
         linear_method: Optional[LinearMethodBase] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
         self.config = config
         self.config = config
         self.padding_idx = config.pad_token_id
         self.padding_idx = config.pad_token_id
-        lora_vocab = (lora_config.lora_extra_vocab_size *
-                      (lora_config.max_loras or 1)) if lora_config else 0
-        self.vocab_size = config.vocab_size + lora_vocab
-        self.org_vocab_size = config.vocab_size
+        self.vocab_size = config.vocab_size
 
 
         self.embed_tokens = VocabParallelEmbedding(
         self.embed_tokens = VocabParallelEmbedding(
-            self.vocab_size,
+            config.vocab_size,
             config.hidden_size,
             config.hidden_size,
             linear_method=linear_method,
             linear_method=linear_method,
-            org_num_embeddings=config.vocab_size,
         )
         )
         self.layers = nn.ModuleList([
         self.layers = nn.ModuleList([
-            MistralDecoderLayer(config, linear_method)
-            for _ in range(config.num_hidden_layers)
+            Qwen2DecoderLayer(config, layer_idx, linear_method)
+            for layer_idx in range(config.num_hidden_layers)
         ])
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
@@ -311,35 +307,23 @@ class MistralModel(nn.Module):
         return hidden_states
         return hidden_states
 
 
 
 
-class MistralForCausalLM(nn.Module):
-    supports_lora = True
+class Qwen2ForCausalLM(nn.Module):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        config: MistralConfig,
+        config: Qwen2Config,
         linear_method: Optional[LinearMethodBase] = None,
         linear_method: Optional[LinearMethodBase] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
         self.config = config
         self.config = config
         self.linear_method = linear_method
         self.linear_method = linear_method
-        self.model = MistralModel(config,
-                                  linear_method,
-                                  lora_config=lora_config)
-        unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.model = Qwen2Model(config, linear_method)
         self.lm_head = ParallelLMHead(
         self.lm_head = ParallelLMHead(
-            unpadded_vocab_size,
+            config.vocab_size,
             config.hidden_size,
             config.hidden_size,
             linear_method=linear_method,
             linear_method=linear_method,
-            org_num_embeddings=config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE
-            # We need bigger padding if using lora for kernel
-            # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
         )
         )
-        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
+        self.sampler = Sampler(config.vocab_size)
 
 
     def forward(
     def forward(
         self,
         self,
@@ -379,8 +363,7 @@ class MistralForCausalLM(nn.Module):
             stacked_params_mapping = []
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
             if "rotary_emb.inv_freq" in name:
                 continue
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:

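One behavioral change in the Mistral-to-Qwen2 port above is the per-layer sliding-window gating: a layer keeps the configured sliding window only when `config.use_sliding_window` is set and its index is below `config.max_window_layers`; otherwise `sliding_window` becomes `None` and the layer attends over the full context. A minimal sketch of that gating with made-up config values (the dataclass is illustrative, not the real `Qwen2Config`):

from dataclasses import dataclass


@dataclass
class FakeQwen2Config:
    # Illustrative values only; real checkpoints ship their own config.
    use_sliding_window: bool = True
    sliding_window: int = 4096
    max_window_layers: int = 2
    num_hidden_layers: int = 4


cfg = FakeQwen2Config()
for layer_idx in range(cfg.num_hidden_layers):
    use_sw = cfg.use_sliding_window and layer_idx < cfg.max_window_layers
    window = cfg.sliding_window if use_sw else None
    print(f"layer {layer_idx}: sliding_window={window}")
# layer 0: sliding_window=4096
# layer 1: sliding_window=4096
# layer 2: sliding_window=None
# layer 3: sliding_window=None
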
+ 352 - 0
aphrodite/modeling/models/stablelm.py

@@ -0,0 +1,352 @@
+# coding=utf-8
+# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This code is based off the following work:
+# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
+# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
+"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear,
+                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class StablelmMLP(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.gate_proj = ColumnParallelLinear(config.hidden_size,
+                                                  config.intermediate_size,
+                                                  bias=False,
+                                                  linear_method=linear_method)
+            self.up_proj = ColumnParallelLinear(config.hidden_size,
+                                                config.intermediate_size,
+                                                bias=False,
+                                                linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                config.hidden_size, [config.intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
+        self.down_proj = RowParallelLinear(config.intermediate_size,
+                                           config.hidden_size,
+                                           bias=False)
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.up_proj(x)
+            gate, _ = self.gate_proj(x)
+            gate_up = torch.cat([gate, up], dim=-1)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class StablelmAttention(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        self.num_heads = self.total_num_heads // tp_size
+
+        self.total_num_key_value_heads = config.num_key_value_heads
+        if self.total_num_key_value_heads >= tp_size:
+            # The number of KV heads is at least as large as the TP size,
+            # so we partition the KV heads across the tensor parallel GPUs.
+            assert self.total_num_key_value_heads % tp_size == 0
+        else:
+            # The number of KV heads is smaller than the TP size, so we
+            # replicate each KV head across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_key_value_heads == 0
+        self.num_key_value_heads = max(
+            1, self.total_num_key_value_heads // tp_size)
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
+        self.scaling = self.head_dim**-0.5
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_key_value_heads * self.head_dim
+        self.qkv_bias = getattr(config, "use_qkv_bias", False)
+        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads}).")
+
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.q_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.q_size,
+                                               bias=self.qkv_bias,
+                                               linear_method=linear_method)
+            self.k_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.kv_size,
+                                               bias=self.qkv_bias,
+                                               linear_method=linear_method)
+            self.v_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.kv_size,
+                                               bias=self.qkv_bias,
+                                               linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.qkv_proj = QKVParallelLinear(
+                self.hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                self.total_num_key_value_heads,
+                self.qkv_bias,
+                linear_method=linear_method,
+            )
+        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+                                        self.hidden_size,
+                                        bias=False,
+                                        linear_method=linear_method)
+        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+        ) is None else linear_method.quant_config.rope_style()
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_ndims,
+            max_position=self.config.max_position_embeddings,
+            base=self.config.rope_theta,
+            is_neox_style=is_neox_style,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_key_value_heads)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        if self.merge_weight:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
+                                dim=-1)
+        else:
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
+        q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class StablelmDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.self_attn = StablelmAttention(config, linear_method)
+        self.mlp = StablelmMLP(config, linear_method)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     eps=config.norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states, residual
+
+
+class StableLMEpochModel(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
+        super().__init__()
+        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size,
+                                                   linear_method=linear_method)
+        self.layers = nn.ModuleList([
+            StablelmDecoderLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            # pylint: disable=unused-variable
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class StablelmForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = StableLMEpochModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
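
The partition-vs-replicate branch in StablelmAttention above determines how many KV heads each tensor-parallel rank ends up holding: with at least as many KV heads as ranks the heads are split across ranks, otherwise each head is replicated on several ranks. A minimal, self-contained sketch of that arithmetic follows; the head counts and TP sizes are hypothetical and chosen only to show both branches.

def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    # Mirrors the assertions in StablelmAttention.__init__ above.
    if total_num_kv_heads >= tp_size:
        # Enough KV heads: partition them across the tensor-parallel ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: each head is shared by several ranks.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

print(kv_heads_per_rank(8, 4))  # 2: each rank holds two of the eight KV heads
print(kv_heads_per_rank(2, 8))  # 1: each of the two KV heads is replicated on four ranks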

+ 37 - 24
aphrodite/transformers_utils/config.py

@@ -3,37 +3,44 @@ from typing import Optional
 import gguf
 import gguf
 from transformers import AutoConfig, PretrainedConfig
 from transformers import AutoConfig, PretrainedConfig
 from transformers.models.auto.configuration_auto import CONFIG_MAPPING
 from transformers.models.auto.configuration_auto import CONFIG_MAPPING
-from aphrodite.transformers_utils.configs import YiConfig, QWenConfig
+
+from aphrodite.transformers_utils.configs import (BaiChuanConfig,
+                                                  ChatGLMConfig, MPTConfig,
+                                                  QWenConfig, RWConfig)
 
 
 _CONFIG_REGISTRY = {
 _CONFIG_REGISTRY = {
+    "baichuan": BaiChuanConfig,
+    "chatglm": ChatGLMConfig,
+    "mpt": MPTConfig,
     "qwen": QWenConfig,
     "qwen": QWenConfig,
-    "yi": YiConfig,
+    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
+    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
 }
 }
 
 
 
 
 def extract_gguf_config(checkpoint):
 def extract_gguf_config(checkpoint):
     result = gguf.GGUFReader(checkpoint)
     result = gguf.GGUFReader(checkpoint)
-    architecture = result.fields["general.architecture"]
+    architecture = result.fields['general.architecture']
     architecture = str(bytes(architecture.parts[architecture.data[0]]),
     architecture = str(bytes(architecture.parts[architecture.data[0]]),
-                       encoding="utf-8")
+                       encoding='utf-8')
     # Only support llama so far
     # Only support llama so far
     if architecture != "llama":
     if architecture != "llama":
         raise RuntimeError(f"Unsupported architecture {architecture}")
         raise RuntimeError(f"Unsupported architecture {architecture}")
 
 
     # write config
     # write config
-    vocab_size = len(result.fields["tokenizer.ggml.token_type"].data)
-    context_length = int(result.fields["llama.context_length"].parts[-1])
-    n_layer = int(result.fields["llama.block_count"].parts[-1])
-    n_head = int(result.fields["llama.attention.head_count"].parts[-1])
+    vocab_size = len(result.fields['tokenizer.ggml.token_type'].data)
+    context_length = int(result.fields['llama.context_length'].parts[-1])
+    n_layer = int(result.fields['llama.block_count'].parts[-1])
+    n_head = int(result.fields['llama.attention.head_count'].parts[-1])
     n_local_heads = int(
     n_local_heads = int(
-        result.fields["llama.attention.head_count_kv"].parts[-1])
+        result.fields['llama.attention.head_count_kv'].parts[-1])
     intermediate_size = int(
     intermediate_size = int(
-        result.fields["llama.feed_forward_length"].parts[-1])
+        result.fields['llama.feed_forward_length'].parts[-1])
     norm_eps = float(
     norm_eps = float(
-        result.fields["llama.attention.layer_norm_rms_epsilon"].parts[-1])
-    dim = int(result.fields["llama.embedding_length"].parts[-1])
+        result.fields['llama.attention.layer_norm_rms_epsilon'].parts[-1])
+    dim = int(result.fields['llama.embedding_length'].parts[-1])
     arch = "MixtralForCausalLM"
     arch = "MixtralForCausalLM"
-    if "llama.expert_count" in result.fields:
+    if 'llama.expert_count' in result.fields:
         arch = "MixtralForCausalLM"
         arch = "MixtralForCausalLM"
         name = "mixtral"
         name = "mixtral"
     else:
     else:
@@ -55,14 +62,14 @@ def extract_gguf_config(checkpoint):
         "torch_dtype": "float16",
         "torch_dtype": "float16",
         "vocab_size": vocab_size
         "vocab_size": vocab_size
     }
     }
-    if "llama.rope.freq_base" in result.fields:
-        model_config["rope_theta"] = float(
-            result.fields["llama.rope.freq_base"].parts[-1])
-    if "llama.expert_count" in result.fields:
-        model_config["num_local_experts"] = int(
-            result.fields["llama.expert_count"].parts[-1])
-        model_config["num_experts_per_tok"] = int(
-            result.fields["llama.expert_used_count"].parts[-1])
+    if 'llama.rope.freq_base' in result.fields:
+        model_config['rope_theta'] = float(
+            result.fields['llama.rope.freq_base'].parts[-1])
+    if 'llama.expert_count' in result.fields:
+        model_config['num_local_experts'] = int(
+            result.fields['llama.expert_count'].parts[-1])
+        model_config['num_experts_per_tok'] = int(
+            result.fields['llama.expert_used_count'].parts[-1])
     if name in _CONFIG_REGISTRY:
     if name in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[name]
         config_class = _CONFIG_REGISTRY[name]
     else:
     else:
@@ -73,12 +80,16 @@ def extract_gguf_config(checkpoint):
 
 
 def get_config(model: str,
 def get_config(model: str,
                trust_remote_code: bool,
                trust_remote_code: bool,
-               revision: Optional[str] = None) -> PretrainedConfig:
+               revision: Optional[str] = None,
+               code_revision: Optional[str] = None) -> PretrainedConfig:
     if model.endswith("gguf"):
     if model.endswith("gguf"):
         return extract_gguf_config(model)
         return extract_gguf_config(model)
     try:
     try:
         config = AutoConfig.from_pretrained(
         config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code, revision=revision)
+            model,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            code_revision=code_revision)
     except ValueError as e:
     except ValueError as e:
         if (not trust_remote_code and
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
                 "requires you to execute the configuration file" in str(e)):
@@ -92,5 +103,7 @@ def get_config(model: str,
             raise e
             raise e
     if config.model_type in _CONFIG_REGISTRY:
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model, revision=revision)
+        config = config_class.from_pretrained(model,
+                                              revision=revision,
+                                              code_revision=code_revision)
     return config
     return config
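
The new code_revision argument lets the remote modeling code be pinned to a revision independently of the weights: both values are forwarded to AutoConfig.from_pretrained, and again to the registered config class when model_type is found in _CONFIG_REGISTRY. A hedged usage sketch; the repo id and revision strings below are placeholders, not values taken from this diff.

from aphrodite.transformers_utils.config import get_config

config = get_config(
    "some-org/some-remote-code-model",  # hypothetical repo id
    trust_remote_code=True,
    revision="main",        # revision of the weights and config.json
    code_revision="main",   # revision of the trust_remote_code Python files
)
print(config.model_type)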

+ 13 - 2
aphrodite/transformers_utils/configs/__init__.py

@@ -1,7 +1,18 @@
+from aphrodite.transformers_utils.configs.baichuan import BaiChuanConfig
+from aphrodite.transformers_utils.configs.chatglm import ChatGLMConfig
+from aphrodite.transformers_utils.configs.mpt import MPTConfig
+from aphrodite.transformers_utils.configs.olmo import OLMoConfig
 from aphrodite.transformers_utils.configs.qwen import QWenConfig
 from aphrodite.transformers_utils.configs.qwen import QWenConfig
-from aphrodite.transformers_utils.configs.yi import YiConfig
+# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
+# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
+# `FalconConfig` class from the official HuggingFace transformers library.
+from aphrodite.transformers_utils.configs.falcon import RWConfig
 
 
 __all__ = [
 __all__ = [
+    "BaiChuanConfig",
+    "ChatGLMConfig",
+    "MPTConfig",
+    "OLMoConfig",
     "QWenConfig",
     "QWenConfig",
-    "YiConfig",
+    "RWConfig",
 ]
 ]
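
The comment above amounts to a dispatch rule: only checkpoints whose config.json declares the legacy model_type strings "RefinedWeb" or "RefinedWebModel" are routed to the bundled RWConfig, while checkpoints that already declare "falcon" keep using the FalconConfig shipped with transformers. A small standalone sketch of that rule, mirroring the registry keys in config.py above:

from aphrodite.transformers_utils.configs import RWConfig

# Mirrors the Falcon-related keys of _CONFIG_REGISTRY.
_LEGACY_FALCON_TYPES = {"RefinedWeb", "RefinedWebModel"}

def uses_bundled_rw_config(model_type: str) -> bool:
    return model_type in _LEGACY_FALCON_TYPES

assert uses_bundled_rw_config("RefinedWebModel")  # tiiuae/falcon-7b(-instruct)
assert not uses_bundled_rw_config("falcon")       # newer Falcon checkpoints
print(RWConfig.model_type)  # prints "falcon": the bundled class reports the new name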

+ 62 - 0
aphrodite/transformers_utils/configs/baichuan.py

@@ -0,0 +1,62 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class BaiChuanConfig(PretrainedConfig):
+    model_type = "baichuan"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=64000,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        hidden_act="silu",
+        max_position_embeddings=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
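
BaiChuanConfig only carries the fields the Baichuan model code reads (hidden sizes, head counts, RMSNorm epsilon); everything else is passed through to PretrainedConfig. A short sketch using the defaults declared above, which roughly match the 7B variant; real checkpoints override them through config.json:

from aphrodite.transformers_utils.configs import BaiChuanConfig

cfg = BaiChuanConfig()
assert cfg.model_type == "baichuan"
assert (cfg.hidden_size, cfg.num_attention_heads, cfg.num_hidden_layers) == (4096, 32, 32)
# Per-head dimension, assuming the usual hidden_size // num_heads split:
assert cfg.hidden_size // cfg.num_attention_heads == 128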

+ 68 - 0
aphrodite/transformers_utils/configs/chatglm.py

@@ -0,0 +1,68 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/THUDM/ChatGLM2-6B
+from transformers import PretrainedConfig
+
+
+class ChatGLMConfig(PretrainedConfig):
+    model_type = "chatglm"
+    attribute_map = {
+        "num_hidden_layers": "num_layers",
+        "n_head_kv": "multi_query_group_num",
+    }
+
+    def __init__(self,
+                 num_layers=28,
+                 padded_vocab_size=65024,
+                 hidden_size=4096,
+                 ffn_hidden_size=13696,
+                 kv_channels=128,
+                 num_attention_heads=32,
+                 seq_length=2048,
+                 hidden_dropout=0.0,
+                 attention_dropout=0.0,
+                 layernorm_epsilon=1e-5,
+                 rmsnorm=True,
+                 apply_residual_connection_post_layernorm=False,
+                 post_layer_norm=True,
+                 add_bias_linear=False,
+                 add_qkv_bias=False,
+                 interleaved_qkv=False,
+                 bias_dropout_fusion=True,
+                 multi_query_attention=False,
+                 multi_query_group_num=1,
+                 apply_query_key_layer_scaling=True,
+                 attention_softmax_in_fp32=True,
+                 fp32_residual_connection=False,
+                 quantization_bit=0,
+                 pre_seq_len=None,
+                 prefix_projection=False,
+                 **kwargs):
+        self.num_layers = num_layers
+        self.vocab_size = padded_vocab_size
+        self.padded_vocab_size = padded_vocab_size
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.kv_channels = kv_channels
+        self.num_attention_heads = num_attention_heads
+        self.seq_length = seq_length
+        self.hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+        self.layernorm_epsilon = layernorm_epsilon
+        self.rmsnorm = rmsnorm
+        self.apply_residual_connection_post_layernorm = (
+            apply_residual_connection_post_layernorm)
+        self.post_layer_norm = post_layer_norm
+        self.add_bias_linear = add_bias_linear
+        self.add_qkv_bias = add_qkv_bias
+        self.bias_dropout_fusion = bias_dropout_fusion
+        self.multi_query_attention = multi_query_attention
+        self.multi_query_group_num = multi_query_group_num
+        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+        self.fp32_residual_connection = fp32_residual_connection
+        self.quantization_bit = quantization_bit
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
+        self.interleaved_qkv = interleaved_qkv
+        super().__init__(**kwargs)
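
The attribute_map on ChatGLMConfig lets the rest of the engine use the generic HuggingFace attribute names while the checkpoint keeps its ChatGLM-native field names; PretrainedConfig transparently redirects attribute access through that map. A small sketch of the aliasing (the group count is overridden here only to make the redirection visible):

from aphrodite.transformers_utils.configs import ChatGLMConfig

cfg = ChatGLMConfig(num_layers=28, multi_query_group_num=2)
# The generic names resolve to the ChatGLM-native fields via attribute_map.
assert cfg.num_hidden_layers == cfg.num_layers == 28
assert cfg.n_head_kv == cfg.multi_query_group_num == 2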

+ 88 - 0
aphrodite/transformers_utils/configs/falcon.py

@@ -0,0 +1,88 @@
+# Adapted from
+# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Falcon configuration"""
+from transformers.configuration_utils import PretrainedConfig
+
+
+class RWConfig(PretrainedConfig):
+    model_type = "falcon"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {
+        "num_hidden_layers": "n_layer",
+        "num_attention_heads": "n_head",
+        "num_kv_heads": "n_head_kv",
+    }
+
+    def __init__(
+        self,
+        vocab_size=250880,
+        hidden_size=64,
+        n_layer=2,
+        n_head=8,
+        layer_norm_epsilon=1e-5,
+        initializer_range=0.02,
+        use_cache=True,
+        bos_token_id=1,
+        eos_token_id=2,
+        hidden_dropout=0.0,
+        attention_dropout=0.0,
+        multi_query=True,
+        n_head_kv=None,
+        alibi=False,
+        bias=False,
+        parallel_attn=False,
+        new_decoder_architecture=False,
+        **kwargs,
+    ) -> None:
+        self.vocab_size = vocab_size
+        # Backward compatibility with n_embed kwarg
+        n_embed = kwargs.pop("n_embed", None)
+        self.hidden_size = hidden_size if n_embed is None else n_embed
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.initializer_range = initializer_range
+        self.use_cache = use_cache
+        self.hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.multi_query = multi_query
+        self.n_head_kv = 1 if n_head_kv is None else n_head_kv
+        self.alibi = alibi
+        self.bias = bias
+        self.parallel_attn = parallel_attn
+        self.new_decoder_architecture = new_decoder_architecture
+
+        if self.hidden_size == 8192:
+            # Hack for falcon-40b
+            self.new_decoder_architecture = True
+
+        super().__init__(bos_token_id=bos_token_id,
+                         eos_token_id=eos_token_id,
+                         **kwargs)
+
+    @property
+    def head_dim(self):
+        return self.hidden_size // self.n_head
+
+    @property
+    def rotary(self):
+        return not self.alibi
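
The two properties at the end and the hidden_size == 8192 special case capture how RWConfig normalizes the legacy Falcon checkpoints. A short sketch of that behaviour; the numbers echo the public falcon-7b and falcon-40b configs but are used here purely for illustration:

from aphrodite.transformers_utils.configs import RWConfig

# falcon-7b style: multi-query attention with rotary embeddings.
cfg = RWConfig(hidden_size=4544, n_head=71, n_layer=32, parallel_attn=True)
assert cfg.head_dim == 4544 // 71      # 64
assert cfg.rotary and not cfg.alibi    # rotary is simply "not alibi"

# Any config with hidden_size == 8192 is forced onto the new decoder layout.
big = RWConfig(hidden_size=8192, n_head=128, n_head_kv=8)
assert big.new_decoder_architecture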

+ 233 - 0
aphrodite/transformers_utils/configs/mpt.py

@@ -0,0 +1,233 @@
+# coding=utf-8
+# Copied from
+# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
+"""A HuggingFace-style model configuration."""
+import warnings
+from typing import Any, Dict, Optional, Union
+
+from transformers import PretrainedConfig
+
+attn_config_defaults: Dict = {
+    'attn_type': 'multihead_attention',
+    'attn_pdrop': 0.0,
+    'attn_impl': 'triton',
+    'qk_ln': False,
+    'clip_qkv': None,
+    'softmax_scale': None,
+    'prefix_lm': False,
+    'attn_uses_sequence_id': False,
+    'alibi': False,
+    'alibi_bias_max': 8
+}
+ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
+init_config_defaults: Dict = {
+    'name': 'kaiming_normal_',
+    'fan_mode': 'fan_in',
+    'init_nonlinearity': 'relu',
+    'init_div_is_residual': True,
+    'emb_init_std': None,
+    'emb_init_uniform_lim': None,
+    'init_std': None,
+    'init_gain': 0.0
+}
+
+
+class MPTConfig(PretrainedConfig):
+    model_type = 'mpt'
+    attribute_map = {
+        'num_attention_heads': 'n_heads',
+        'hidden_size': 'd_model',
+        'num_hidden_layers': 'n_layers',
+    }
+
+    # pylint: disable=dangerous-default-value
+    def __init__(self,
+                 d_model: int = 2048,
+                 n_heads: int = 16,
+                 n_layers: int = 24,
+                 expansion_ratio: int = 4,
+                 max_seq_len: int = 2048,
+                 vocab_size: int = 50368,
+                 resid_pdrop: float = 0.0,
+                 emb_pdrop: float = 0.0,
+                 learned_pos_emb: bool = True,
+                 attn_config: Dict = attn_config_defaults,
+                 ffn_config: Dict = ffn_config_defaults,
+                 init_device: str = 'cpu',
+                 logit_scale: Optional[Union[float, str]] = None,
+                 no_bias: bool = False,
+                 embedding_fraction: float = 1.0,
+                 norm_type: str = 'low_precision_layernorm',
+                 use_cache: bool = False,
+                 init_config: Dict = init_config_defaults,
+                 fc_type: str = 'torch',
+                 verbose: Optional[int] = None,
+                 **kwargs: Any):
+        """The MPT configuration class.
+        Args:
+            d_model (int): The size of the embedding dimension of the model.
+            n_heads (int): The number of attention heads.
+            n_layers (int): The number of layers in the model.
+            expansion_ratio (int): The ratio of the up/down scale in the ffn.
+            max_seq_len (int): The maximum sequence length of the model.
+            vocab_size (int): The size of the vocabulary.
+            resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
+            emb_pdrop (float): The dropout probability for the embedding layer.
+            learned_pos_emb (bool): Whether to use learned positional embeddings
+            attn_config (Dict): A dictionary used to configure the model's attention module:
+                attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
+                attn_pdrop (float): The dropout probability for the attention layers.
+                attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
+                qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
+                clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
+                    this value.
+                softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
+                    use the default scale of ``1/sqrt(d_keys)``.
+                prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
+                    extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
+                    can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
+                attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
+                    When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
+                    which sub-sequence each token belongs to.
+                    Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
+                alibi (bool): Whether to use the alibi bias instead of position embeddings.
+                alibi_bias_max (int): The maximum value of the alibi bias.
+                kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
+            ffn_config (Dict): A dictionary used to configure the model's ffn module:
+                ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
+            init_device (str): The device to use for parameter initialization.
+            logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
+            no_bias (bool): Whether to use bias in all layers.
+            verbose (int): The verbosity level. 0 is silent.
+            embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
+            norm_type (str): choose type of norm to use
+            use_cache (bool): Whether or not the model should return the last key/values attentions
+            init_config (Dict): A dictionary used to configure the model initialization:
+                init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
+                    'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
+                    'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
+                init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
+                emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
+                emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
+                    used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
+                init_std (float): The standard deviation of the normal distribution used to initialize the model,
+                    if using the baseline_ parameter initialization scheme.
+                init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
+                fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
+                init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
+                ---
+                See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
+            fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
+        """
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.expansion_ratio = expansion_ratio
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.resid_pdrop = resid_pdrop
+        self.emb_pdrop = emb_pdrop
+        self.learned_pos_emb = learned_pos_emb
+        self.attn_config = attn_config
+        self.ffn_config = ffn_config
+        self.init_device = init_device
+        self.logit_scale = logit_scale
+        self.no_bias = no_bias
+        self.embedding_fraction = embedding_fraction
+        self.norm_type = norm_type
+        self.use_cache = use_cache
+        self.init_config = init_config
+        self.fc_type = fc_type
+        if verbose is not None:
+            warnings.warn(DeprecationWarning(
+                'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
+            ),
+                          stacklevel=2)
+        if 'name' in kwargs:
+            del kwargs['name']
+        if 'loss_fn' in kwargs:
+            del kwargs['loss_fn']
+        if self.attn_config.get('alibi', False):
+            self.learned_pos_emb = False
+            warnings.warn(
+                f'alibi is turned on, setting `learned_pos_emb` to {self.learned_pos_emb}',
+                stacklevel=2)
+        super().__init__(**kwargs)
+        self._validate_config()
+
+    def _set_config_defaults(
+            self, config: Dict[str, Any],
+            config_defaults: Dict[str, Any]) -> Dict[str, Any]:
+        for (k, v) in config_defaults.items():
+            if k not in config:
+                config[k] = v
+        return config
+
+    def _validate_config(self) -> None:
+        self.attn_config = self._set_config_defaults(self.attn_config,
+                                                     attn_config_defaults)
+        self.ffn_config = self._set_config_defaults(self.ffn_config,
+                                                    ffn_config_defaults)
+        self.init_config = self._set_config_defaults(self.init_config,
+                                                     init_config_defaults)
+        if self.d_model % self.n_heads != 0:
+            raise ValueError('d_model must be divisible by n_heads')
+        if any((
+                prob < 0 or prob > 1 for prob in
+            [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop]
+        )):
+            raise ValueError(
+                "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"  # pylint: disable=line-too-long
+            )
+        if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
+            raise ValueError(
+                f"Unknown attn_impl={self.attn_config['attn_impl']}")
+        if self.attn_config['prefix_lm'] and self.attn_config[
+                'attn_impl'] not in ['torch', 'triton']:
+            raise NotImplementedError(
+                'prefix_lm only implemented with torch and triton attention.')
+        if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [
+                'torch', 'triton'
+        ]:
+            raise NotImplementedError(
+                'alibi only implemented with torch and triton attention.')
+        if self.attn_config['attn_uses_sequence_id'] and self.attn_config[
+                'attn_impl'] not in ['torch', 'triton']:
+            raise NotImplementedError(
+                'attn_uses_sequence_id only implemented with torch and triton attention.'  # pylint: disable=line-too-long
+            )
+        if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
+            raise ValueError(
+                'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'  # pylint: disable=line-too-long
+            )
+        if isinstance(self.logit_scale,
+                      str) and self.logit_scale != 'inv_sqrt_d_model':
+            raise ValueError(
+                f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."  # pylint: disable=line-too-long
+            )
+        if self.init_config.get('name', None) is None:
+            raise ValueError(
+                f"self.init_config={self.init_config!r} 'name' needs to be set."
+            )
+        if not self.learned_pos_emb and (not self.attn_config['alibi']):
+            warnings.warn(
+                'Positional information not being provided to the model.',
+                stacklevel=2)
+        if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp':
+            try:
+                # pylint: disable=import-outside-toplevel
+                import transformer_engine.pytorch as te
+                del te
+            except Exception as exc:
+                raise ImportError(
+                    # pylint: disable=line-too-long
+                    'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. '
+                    +
+                    'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n'
+                    + 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
+                    'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
+                ) from exc
+        if self.ffn_config['ffn_type'] == 'mptmlp':
+            self.ffn_config['fc_type'] = self.fc_type
+        elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
+            self.ffn_config['bias'] = not self.no_bias
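
Partial attn_config, ffn_config, and init_config dictionaries are back-filled from the module-level defaults in _validate_config, so callers only need to override the keys they care about, and enabling alibi implicitly disables learned positional embeddings. A small sketch of that behaviour using only keys defined above:

from aphrodite.transformers_utils.configs import MPTConfig

cfg = MPTConfig(d_model=2048, n_heads=16, attn_config={'alibi': True})
# Unspecified attention options are filled in from attn_config_defaults...
assert cfg.attn_config['attn_impl'] == 'triton'
assert cfg.attn_config['attn_pdrop'] == 0.0
# ...and turning on alibi switches off learned positional embeddings
# (with a warning at construction time).
assert cfg.learned_pos_emb is False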

+ 72 - 0
aphrodite/transformers_utils/configs/olmo.py

@@ -0,0 +1,72 @@
+# coding=utf-8
+# adapted from https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/configuration_olmo.py
+"""OLMo configuration"""
+from transformers import PretrainedConfig
+
+
+class OLMoConfig(PretrainedConfig):
+    model_type = 'olmo'
+    attribute_map = {
+        'num_attention_heads': 'n_heads',
+        'hidden_size': 'd_model',
+        'num_hidden_layers': 'n_layers',
+    }
+
+    # Note that the defaults for these attributes are equivalent to the base GPT2 model.
+    def __init__(
+        self,
+        d_model=768,
+        n_heads=12,
+        n_layers=12,
+        mlp_ratio=4,
+        mlp_hidden_size=None,
+        activation_type="swiglu",
+        block_type="sequential",
+        block_group_size=1,
+        alibi=False,
+        alibi_bias_max=8.0,
+        rope=False,
+        rope_full_precision=True,
+        multi_query_attention=False,
+        attention_layer_norm=False,
+        layer_norm_type="default",
+        layer_norm_with_affine=True,
+        attention_layer_norm_with_affine=True,
+        max_sequence_length=1024,
+        include_bias=True,
+        bias_for_layer_norm=None,
+        scale_logits=False,
+        vocab_size=50257,
+        embedding_size=50304,
+        weight_tying=True,
+        eos_token_id=50256,
+        pad_token_id=50256,
+        **kwargs,
+    ):
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.mlp_ratio = mlp_ratio
+        self.mlp_hidden_size = mlp_hidden_size
+        self.activation_type = activation_type
+        self.block_type = block_type
+        self.block_group_size = block_group_size
+        self.alibi = alibi
+        self.alibi_bias_max = alibi_bias_max
+        self.rope = rope
+        self.rope_full_precision = rope_full_precision
+        self.multi_query_attention = multi_query_attention
+        self.attention_layer_norm = attention_layer_norm
+        self.layer_norm_type = layer_norm_type
+        self.layer_norm_with_affine = layer_norm_with_affine
+        self.attention_layer_norm_with_affine = attention_layer_norm_with_affine
+        self.max_sequence_length = max_sequence_length
+        self.include_bias = include_bias
+        self.bias_for_layer_norm = bias_for_layer_norm
+        self.scale_logits = scale_logits
+        self.vocab_size = vocab_size
+        self.embedding_size = embedding_size
+        self.weight_tying = weight_tying
+        self.eos_token_id = eos_token_id
+        self.pad_token_id = pad_token_id
+        super().__init__(**kwargs)

+ 49 - 52
aphrodite/transformers_utils/tokenizer.py

@@ -1,6 +1,6 @@
 import os
 import os
 import tempfile
 import tempfile
-from typing import List, Tuple, Union, Optional
+from typing import List, Optional, Tuple, Union
 
 
 import gguf
 import gguf
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
@@ -10,6 +10,7 @@ from transformers.convert_slow_tokenizer import import_protobuf
 from aphrodite.common.logger import init_logger
 from aphrodite.common.logger import init_logger
 from aphrodite.lora.request import LoRARequest
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.utils import make_async, LRUCache
 from aphrodite.common.utils import make_async, LRUCache
+from aphrodite.transformers_utils.tokenizers import BaichuanTokenizer
 
 
 logger = init_logger(__name__)
 logger = init_logger(__name__)
 
 
@@ -19,90 +20,75 @@ def convert_gguf_to_tokenizer(checkpoint):
     # write vocab
     # write vocab
     sentencepiece_model_pb2 = import_protobuf()
     sentencepiece_model_pb2 = import_protobuf()
     vocab = sentencepiece_model_pb2.ModelProto()
     vocab = sentencepiece_model_pb2.ModelProto()
-    vocab_size = len(result.fields["tokenizer.ggml.token_type"].data)
+    vocab_size = len(result.fields['tokenizer.ggml.token_type'].data)
     vocab.trainer_spec.model_type = 2  # BPE
     vocab.trainer_spec.model_type = 2  # BPE
     vocab.trainer_spec.vocab_size = vocab_size
     vocab.trainer_spec.vocab_size = vocab_size
     vocab.trainer_spec.byte_fallback = True
     vocab.trainer_spec.byte_fallback = True
     vocab.normalizer_spec.remove_extra_whitespaces = False
     vocab.normalizer_spec.remove_extra_whitespaces = False
-    tokens = result.fields["tokenizer.ggml.tokens"]
-    scores = result.fields["tokenizer.ggml.scores"]
-    types = result.fields["tokenizer.ggml.token_type"]
+    tokens = result.fields['tokenizer.ggml.tokens']
+    scores = result.fields['tokenizer.ggml.scores']
+    types = result.fields['tokenizer.ggml.token_type']
     for i in range(vocab_size):
     for i in range(vocab_size):
         new_token = vocab.SentencePiece()
         new_token = vocab.SentencePiece()
         new_token.piece = str(bytes(tokens.parts[tokens.data[i]]),
         new_token.piece = str(bytes(tokens.parts[tokens.data[i]]),
-                              encoding="utf-8")
+                              encoding='utf-8')
         new_token.score = scores.parts[scores.data[i]]
         new_token.score = scores.parts[scores.data[i]]
         # llama.cpp tokentype is the same with sentencepiece token type
         # llama.cpp tokentype is the same with sentencepiece token type
         new_token.type = int(types.parts[types.data[i]])
         new_token.type = int(types.parts[types.data[i]])
         vocab.pieces.append(new_token)
         vocab.pieces.append(new_token)
-    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file:
+    with tempfile.NamedTemporaryFile(mode='wb', delete=False) as temp_file:
         temp_file.write(vocab.SerializeToString())
         temp_file.write(vocab.SerializeToString())
         temp_file_filename = temp_file.name
         temp_file_filename = temp_file.name
     tokenizer_args = {"vocab_file": temp_file_filename}
     tokenizer_args = {"vocab_file": temp_file_filename}
 
 
-    if "tokenizer.ggml.bos_token_id" in result.fields:
+    if 'tokenizer.ggml.bos_token_id' in result.fields:
         tokenizer_args["bos_token"] = vocab.pieces[int(
         tokenizer_args["bos_token"] = vocab.pieces[int(
-            result.fields["tokenizer.ggml.bos_token_id"].parts[-1])].piece
-    if "tokenizer.ggml.eos_token_id" in result.fields:
+            result.fields['tokenizer.ggml.bos_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.eos_token_id' in result.fields:
         tokenizer_args["eos_token"] = vocab.pieces[int(
         tokenizer_args["eos_token"] = vocab.pieces[int(
-            result.fields["tokenizer.ggml.eos_token_id"].parts[-1])].piece
-    if "tokenizer.ggml.padding_token_id" in result.fields:
+            result.fields['tokenizer.ggml.eos_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.padding_token_id' in result.fields:
         tokenizer_args["pad_token"] = vocab.pieces[int(
         tokenizer_args["pad_token"] = vocab.pieces[int(
-            result.fields["tokenizer.ggml.padding_token_id"].parts[-1])].piece
-    if "tokenizer.ggml.unknown_token_id" in result.fields:
+            result.fields['tokenizer.ggml.padding_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.unknown_token_id' in result.fields:
         tokenizer_args["unk_token"] = vocab.pieces[int(
         tokenizer_args["unk_token"] = vocab.pieces[int(
-            result.fields["tokenizer.ggml.unknown_token_id"].parts[-1])].piece
-    if "tokenizer.ggml.add_bos_token" in result.fields:
+            result.fields['tokenizer.ggml.unknown_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.add_bos_token' in result.fields:
         tokenizer_args["add_bos_token"] = bool(
         tokenizer_args["add_bos_token"] = bool(
-            result.fields["tokenizer.ggml.add_bos_token"].parts[-1])
-    if "tokenizer.ggml.add_eos_token" in result.fields:
+            result.fields['tokenizer.ggml.add_bos_token'].parts[-1])
+    if 'tokenizer.ggml.add_eos_token' in result.fields:
         tokenizer_args["add_eos_token"] = bool(
         tokenizer_args["add_eos_token"] = bool(
-            result.fields["tokenizer.ggml.add_eos_token"].parts[-1])
-    tokenizer = LlamaTokenizer(**tokenizer_args, legacy=False)
+            result.fields['tokenizer.ggml.add_eos_token'].parts[-1])
+    tokenizer = LlamaTokenizer(**tokenizer_args)
     os.unlink(temp_file_filename)
     os.unlink(temp_file_filename)
     return tokenizer
     return tokenizer
 
 
 
 
-# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
-_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
-
-
 def get_tokenizer(
 def get_tokenizer(
     tokenizer_name: str,
     tokenizer_name: str,
     *args,
     *args,
     tokenizer_mode: str = "auto",
     tokenizer_mode: str = "auto",
     trust_remote_code: bool = False,
     trust_remote_code: bool = False,
+    tokenizer_revision: Optional[str] = None,
     **kwargs,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
     """Gets a tokenizer for the given model name via Huggingface."""
     if tokenizer_name.endswith("gguf"):
     if tokenizer_name.endswith("gguf"):
         return convert_gguf_to_tokenizer(tokenizer_name)
         return convert_gguf_to_tokenizer(tokenizer_name)
+
     if tokenizer_mode == "slow":
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
         if kwargs.get("use_fast", False):
             raise ValueError(
             raise ValueError(
                 "Cannot use the fast tokenizer in slow tokenizer mode.")
                 "Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False
         kwargs["use_fast"] = False
 
 
-    if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
-            and tokenizer_name != _FAST_LLAMA_TOKENIZER):
-        logger.info(
-            "For some LLaMA V1 models, initializing the fast tokenizer may "
-            "take a long time. To reduce the initialization time, consider "
-            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-            "tokenizer.")
     try:
     try:
         tokenizer = AutoTokenizer.from_pretrained(
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
             tokenizer_name,
             *args,
             *args,
             trust_remote_code=trust_remote_code,
             trust_remote_code=trust_remote_code,
+            tokenizer_revision=tokenizer_revision,
             **kwargs)
             **kwargs)
-    except TypeError as e:
-        # The LLaMA tokenizer causes a protobuf error in some environments.
-        err_msg = (
-            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
-            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
-            "original tokenizer.")
-        raise RuntimeError(err_msg) from e
     except ValueError as e:
     except ValueError as e:
         # If the error pertains to the tokenizer class not existing or not
         # If the error pertains to the tokenizer class not existing or not
         # currently being imported, suggest using the --trust-remote-code flag.
         # currently being imported, suggest using the --trust-remote-code flag.
@@ -117,6 +103,18 @@ def get_tokenizer(
             raise RuntimeError(err_msg) from e
             raise RuntimeError(err_msg) from e
         else:
         else:
             raise e
             raise e
+    except AttributeError as e:
+        if "BaichuanTokenizer" in str(e):
+            # This is for the error "'BaichuanTokenizer' object has no
+            # attribute 'sp_model'".
+            tokenizer = BaichuanTokenizer.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                tokenizer_revision=tokenizer_revision,
+                **kwargs)
+        else:
+            raise e
 
 
     if not isinstance(tokenizer, PreTrainedTokenizerFast):
     if not isinstance(tokenizer, PreTrainedTokenizerFast):
         logger.warning(
         logger.warning(
@@ -161,21 +159,18 @@ class TokenizerGroup:
         else:
         else:
             self.lora_tokenizers = None
             self.lora_tokenizers = None
 
 
-    def encode(
-        self,
-        prompt: str,
-        request_id: Optional[str] = None,  # pylint: disable=unused-argument
-        lora_request: Optional[LoRARequest] = None
-    ) -> List[int]:
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
         tokenizer = self.get_lora_tokenizer(lora_request)
         tokenizer = self.get_lora_tokenizer(lora_request)
         return tokenizer.encode(prompt)
         return tokenizer.encode(prompt)
 
 
     async def encode_async(
     async def encode_async(
-        self,
-        prompt: str,
-        request_id: Optional[str] = None,  # pylint: disable=unused-argument
-        lora_request: Optional[LoRARequest] = None
-    ) -> List[int]:
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
         tokenizer = await self.get_lora_tokenizer_async(lora_request)
         tokenizer = await self.get_lora_tokenizer_async(lora_request)
         return tokenizer.encode(prompt)
         return tokenizer.encode(prompt)
 
 
@@ -262,7 +257,7 @@ def detokenize_incrementally(
         # tokenizers (bigger = more conservative).
         # tokenizers (bigger = more conservative).
         # Subtract 1 extra to account for the generated token.
         # Subtract 1 extra to account for the generated token.
         prefix_offset = max(len(output_tokens) - 6, 0)
         prefix_offset = max(len(output_tokens) - 6, 0)
-        # If the first new token is a special token we can't skip 1 extra token
+        # If the first new token is a special token, we can't skip 1 extra token
         if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
         if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
             read_offset = max(len(output_tokens), 0)
             read_offset = max(len(output_tokens), 0)
         else:
         else:
@@ -286,12 +281,14 @@ def detokenize_incrementally(
             tokenizer,
             tokenizer,
             output_tokens[prefix_offset:read_offset],
             output_tokens[prefix_offset:read_offset],
             skip_special_tokens=skip_special_tokens,
             skip_special_tokens=skip_special_tokens,
-            spaces_between_special_tokens=spaces_between_special_tokens)
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
         new_text = _convert_tokens_to_string_with_added_encoders(
         new_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             tokenizer,
             output_tokens[prefix_offset:],
             output_tokens[prefix_offset:],
             skip_special_tokens=skip_special_tokens,
             skip_special_tokens=skip_special_tokens,
-            spaces_between_special_tokens=spaces_between_special_tokens)
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
 
 
     if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
     if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
         # utf-8 char at the end means it's a potential unfinished byte sequence
         # utf-8 char at the end means it's a potential unfinished byte sequence

+ 5 - 0
aphrodite/transformers_utils/tokenizers/__init__.py

@@ -0,0 +1,5 @@
+from aphrodite.transformers_utils.tokenizers.baichuan import BaichuanTokenizer
+
+__all__ = [
+    "BaichuanTokenizer",
+]

+ 261 - 0
aphrodite/transformers_utils/tokenizers/baichuan.py

@@ -0,0 +1,261 @@
+# yapf: disable
+# Adapted from
+# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/8f6e343d545c503b91429582231d1d354dac2740/tokenization_baichuan.py
+# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
+
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {},
+    "tokenizer_file": {},
+}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
+
+
+class BaichuanTokenizer(PreTrainedTokenizer):
+    """
+    Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",
+        bos_token="<s>",
+        eos_token="</s>",
+        pad_token=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        **kwargs,
+    ):
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = (
+            AddedToken(bos_token, lstrip=False, rstrip=False)
+            if isinstance(bos_token, str)
+            else bos_token
+        )
+        eos_token = (
+            AddedToken(eos_token, lstrip=False, rstrip=False)
+            if isinstance(eos_token, str)
+            else eos_token
+        )
+        unk_token = (
+            AddedToken(unk_token, lstrip=False, rstrip=False)
+            if isinstance(unk_token, str)
+            else unk_token
+        )
+        pad_token = (
+            AddedToken(pad_token, lstrip=False, rstrip=False)
+            if isinstance(pad_token, str)
+            else pad_token
+        )
+        self.vocab_file = vocab_file
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    @property
+    def vocab_size(self):
+        """Returns vocab size"""
+        return self.sp_model.get_piece_size()
+
+    def get_vocab(self):
+        """Returns vocab as a dict"""
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text):
+        """Returns a tokenized string."""
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.sp_model.piece_to_id(token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        token = self.sp_model.IdToPiece(index)
+        return token
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for i, token in enumerate(tokens):
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special and i != 0:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string
+
+    def save_vocabulary(
+        self, save_directory, filename_prefix: Optional[str] = None
+    ) -> Tuple[str]:
+        """
+        Save the vocabulary and special tokens file to a directory.
+
+        Args:
+            save_directory (`str`):
+                The directory in which to save the vocabulary.
+
+        Returns:
+            `Tuple(str)`: Paths to the files saved.
+        """
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory,
+            (filename_prefix + "-" if filename_prefix else "")
+            + VOCAB_FILES_NAMES["vocab_file"],
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(
+            out_vocab_file
+        ) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = bos_token_id + token_ids_0 + eos_token_id
+
+        if token_ids_1 is not None:
+            output = output + bos_token_id + token_ids_1 + eos_token_id
+
+        return output
+
+    def get_special_tokens_mask(
+        self,
+        token_ids_0: List[int],
+        token_ids_1: Optional[List[int]] = None,
+        already_has_special_tokens: bool = False,
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0,
+                token_ids_1=token_ids_1,
+                already_has_special_tokens=True,
+            )
+
+        bos_token_id = [1] if self.add_bos_token else []
+        eos_token_id = [1] if self.add_eos_token else []
+
+        if token_ids_1 is None:
+            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+        return (
+            bos_token_id
+            + ([0] * len(token_ids_0))
+            + eos_token_id
+            + bos_token_id
+            + ([0] * len(token_ids_1))
+            + eos_token_id
+        )
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+        sequence pair mask has the following format:
+
+        ```
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        ```
+
+        If token_ids_1 is None, only returns the first portion of the mask (0s).
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+        """
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+        if token_ids_1 is not None:
+            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+        return output
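
For orientation, a minimal usage sketch of the new tokenizer class once a SentencePiece model file is available locally; the "tokenizer.model" path below is a placeholder, not something shipped by this change:

    from aphrodite.transformers_utils.tokenizers import BaichuanTokenizer

    # Placeholder path to a Baichuan SentencePiece model file.
    tokenizer = BaichuanTokenizer(vocab_file="tokenizer.model")

    ids = tokenizer.encode("hello world")                    # BOS prepended (add_bos_token=True)
    text = tokenizer.decode(ids, skip_special_tokens=True)
    print(ids, text)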

+ 48 - 25
kernels/activation_kernels.cu

@@ -2,19 +2,16 @@
 #include <torch/extension.h>
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAGuard.h>
 
 
+#include <cmath>
+
 #include "cuda_compat.h"
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 #include "dispatch_utils.h"
 
 
 namespace aphrodite {
 namespace aphrodite {
 
 
-template<typename T>
-__device__ __forceinline__ T silu(const T& x) {
-  // x * sigmoid(x)
-  return (T) (((float) x) / (1.0f + expf((float) -x)));
-}
-
-template<typename scalar_t>
-__global__ void silu_and_mul_kernel(
+// Activation and gating kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void act_and_mul_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
@@ -22,32 +19,58 @@ __global__ void silu_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = APHRODITE_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = APHRODITE_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = silu(x) * y;
+    out[token_idx * d + idx] = ACT_FN(x) * y;
   }
 }
 
+template<typename T>
+__device__ __forceinline__ T silu_kernel(const T& x) {
+  // x * sigmoid(x)
+  return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'none' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
+  const float f = (float) x;
+  constexpr float ALPHA = M_SQRT1_2;
+  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+}
+
 } // namespace aphrodite
 
+// Launch activation and gating kernel.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                                 \
+  int d = input.size(-1) / 2;                                                                 \
+  int64_t num_tokens = input.numel() / input.size(-1);                                        \
+  dim3 grid(num_tokens);                                                                      \
+  dim3 block(std::min(d, 1024));                                                              \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                           \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                               \
+  APHRODITE_DISPATCH_FLOATING_TYPES(                                                          \
+    input.scalar_type(),                                                                      \
+    "act_and_mul_kernel",                                                                     \
+    [&] {                                                                                     \
+      aphrodite::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(  \
+        out.data_ptr<scalar_t>(),                                                             \
+        input.data_ptr<scalar_t>(),                                                           \
+        d);                                                                                   \
+    });
+
 void silu_and_mul(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input)    // [..., 2 * d]
 {
-  int64_t num_tokens = input.numel() / input.size(-1);
-  int d = input.size(-1) / 2;
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(d, 1024));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  APHRODITE_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(),
-    "silu_and_mul_kernel",
-    [&] {
-      aphrodite::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(),
-        input.data_ptr<scalar_t>(),
-        d);
-    });
+  LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::silu_kernel);
+}
+
+void gelu_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_kernel);
 }
 
 namespace aphrodite {
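
As a sanity check on the new kernel's semantics (not part of the patch): act_and_mul_kernel gates the first half of the last dimension with ACT_FN and multiplies it by the second half, and gelu_kernel is the exact erf-based GELU. A PyTorch-native reference under those assumptions:

    import math
    import torch

    def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
        # Mirror act_and_mul_kernel<scalar_t, gelu_kernel>: apply erf-based GELU to
        # the first half of the last dimension, then multiply by the second half.
        d = x.shape[-1] // 2
        a, b = x[..., :d], x[..., d:]
        return a * 0.5 * (1.0 + torch.erf(a * math.sqrt(0.5))) * b

    x = torch.randn(4, 16)
    print(gelu_and_mul_ref(x).shape)  # torch.Size([4, 8])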

+ 4 - 0
kernels/ops.h

@@ -57,6 +57,10 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& out,
   torch::Tensor& input);
   torch::Tensor& input);
 
 
+void gelu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 void gelu_new(
   torch::Tensor& out,
   torch::Tensor& input);

+ 4 - 0
kernels/pybind.cpp

@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "silu_and_mul",
     "silu_and_mul",
     &silu_and_mul,
     &silu_and_mul,
     "Activation function used in SwiGLU.");
     "Activation function used in SwiGLU.");
+  ops.def(
+    "gelu_and_mul",
+    &gelu_and_mul,
+    "Activation function used in GeGLU.");
   ops.def(
     "gelu_new",
     &gelu_new,