Ver Fonte

feat: refactor modeling logic and support more models (#274)

* feat: refactor llama models and add new models

* add geglu kernels

* add geglu activation layer

* formatting
AlpinDale há 1 ano atrás
pai
commit
e31c6f0b45
38 ficheiros alterados com 5141 adições e 363 exclusões
  1. 1 0
      .pylintrc
  2. 21 0
      aphrodite/modeling/layers/activation.py
  3. 25 0
      aphrodite/modeling/layers/layernorm.py
  4. 36 10
      aphrodite/modeling/models/__init__.py
  5. 412 0
      aphrodite/modeling/models/baichuan.py
  6. 340 0
      aphrodite/modeling/models/bloom.py
  7. 378 0
      aphrodite/modeling/models/chatglm.py
  8. 6 2
      aphrodite/modeling/models/decilm.py
  9. 23 72
      aphrodite/modeling/models/deepseek.py
  10. 446 0
      aphrodite/modeling/models/falcon.py
  11. 65 62
      aphrodite/modeling/models/gemma.py
  12. 286 0
      aphrodite/modeling/models/gpt2.py
  13. 292 0
      aphrodite/modeling/models/gpt_bigcode.py
  14. 8 13
      aphrodite/modeling/models/gpt_j.py
  15. 6 10
      aphrodite/modeling/models/gpt_neox.py
  16. 352 0
      aphrodite/modeling/models/internlm2.py
  17. 45 14
      aphrodite/modeling/models/llama.py
  18. 62 18
      aphrodite/modeling/models/mixtral.py
  19. 7 6
      aphrodite/modeling/models/mixtral_quant.py
  20. 307 0
      aphrodite/modeling/models/mpt.py
  21. 377 0
      aphrodite/modeling/models/olmo.py
  22. 0 1
      aphrodite/modeling/models/phi.py
  23. 315 0
      aphrodite/modeling/models/qwen.py
  24. 35 52
      aphrodite/modeling/models/qwen2.py
  25. 352 0
      aphrodite/modeling/models/stablelm.py
  26. 37 24
      aphrodite/transformers_utils/config.py
  27. 13 2
      aphrodite/transformers_utils/configs/__init__.py
  28. 62 0
      aphrodite/transformers_utils/configs/baichuan.py
  29. 68 0
      aphrodite/transformers_utils/configs/chatglm.py
  30. 88 0
      aphrodite/transformers_utils/configs/falcon.py
  31. 233 0
      aphrodite/transformers_utils/configs/mpt.py
  32. 72 0
      aphrodite/transformers_utils/configs/olmo.py
  33. 49 52
      aphrodite/transformers_utils/tokenizer.py
  34. 5 0
      aphrodite/transformers_utils/tokenizers/__init__.py
  35. 261 0
      aphrodite/transformers_utils/tokenizers/baichuan.py
  36. 48 25
      kernels/activation_kernels.cu
  37. 4 0
      kernels/ops.h
  38. 4 0
      kernels/pybind.cpp

+ 1 - 0
.pylintrc

@@ -61,6 +61,7 @@ disable=abstract-method,
         c-extension-no-member,
         consider-using-enumerate,
         cmp-builtin,
+        inconsistent-quotes,
         cmp-method,
         coerce-builtin,
         coerce-method,

+ 21 - 0
aphrodite/modeling/layers/activation.py

@@ -37,6 +37,27 @@ class SiluAndMul(nn.Module):
         return out
 
 
+class GeluAndMul(nn.Module):
+    """An activation function for GeGLU.
+    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+    Shapes:
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
+    """
+
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d]) * x[..., d:]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.gelu_and_mul(out, x)
+        return out
+
+
 class NewGELU(nn.Module):
 
     def _forward(self, x: torch.Tensor) -> torch.Tensor:

+ 25 - 0
aphrodite/modeling/layers/layernorm.py

@@ -7,6 +7,31 @@ import torch.nn as nn
 from aphrodite._C import ops
 
 
+class LayerNorm(nn.LayerNorm):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__(hidden_size, eps=eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """normalization."""
+        if residual is not None:
+            x = x + residual
+            residual = x
+        x = super().forward(x)
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
+
 class RMSNorm(nn.Module):
     """Root mean square normalization.
 

+ 36 - 10
aphrodite/modeling/models/__init__.py

@@ -8,30 +8,53 @@ from aphrodite.common.utils import is_hip
 
 logger = init_logger(__name__)
 
-# Architecture -> (module, class)
+# Architecture -> (module, class).
 _MODELS = {
+    "AquilaModel": ("llama", "LlamaForCausalLM"),
+    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
+    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
+    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
+    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
+    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
+    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
+    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
+    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
+    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
+    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
+    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
+    # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
-    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
-    "YiForCausalLM": ("yi", "YiForCausalLM"),
+    # transformers's mpt class has lower case
+    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
+    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
+    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
+    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
+    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
+    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
 }
 
-# Models not supported by ROCm
+# Models not supported by ROCm.
 _ROCM_UNSUPPORTED_MODELS = []
 
 # Models partially supported by ROCm.
-# Architecture -> Reason
+# Architecture -> Reason.
 _ROCM_PARTIALLY_SUPPORTED_MODELS = {
+    "Qwen2ForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
     "MistralForCausalLM":
-    "Sliding window attention is not yet supported in ROCM's flash attention.",
+    "Sliding window attention is not yet supported in ROCm's flash attention",
     "MixtralForCausalLM":
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }
@@ -45,8 +68,9 @@ class ModelRegistry:
             return None
         if is_hip():
             if model_arch in _ROCM_UNSUPPORTED_MODELS:
-                raise ValueError(f"Model architecture {model_arch} is not "
-                                 "supported in ROCm for now.")
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "ROCm for now.")
             if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                 logger.warning(
                     f"Model architecture {model_arch} is partially supported "
@@ -62,4 +86,6 @@ class ModelRegistry:
         return list(_MODELS.keys())
 
 
-__all__ = ["ModelRegistry"]
+__all__ = [
+    "ModelRegistry",
+]

+ 412 - 0
aphrodite/modeling/models/baichuan.py

@@ -0,0 +1,412 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only BaiChuan model compatible with HuggingFace weights."""
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear,
+                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs.baichuan import BaiChuanConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
+class BaiChuanMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.gate_proj = ColumnParallelLinear(hidden_size,
+                                                  intermediate_size,
+                                                  bias=False,
+                                                  linear_method=linear_method)
+            self.up_proj = ColumnParallelLinear(hidden_size,
+                                                intermediate_size,
+                                                bias=False,
+                                                linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.up_proj(x)
+            gate, _ = self.gate_proj(x)
+            gate_up = torch.cat([gate, up], dim=-1)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class BaiChuanAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        position_embedding: str,
+        rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
+        )
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.postion_embedding = position_embedding
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        # pylint: disable=invalid-name
+        self.W_pack = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+        # Create the alibi slopes and slice them.
+        if self.postion_embedding == "ALIBI":
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+            scaling = self.head_dim**-0.5
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       scaling,
+                                       alibi_slopes=alibi_slopes)
+        else:
+            is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+            ) is None else linear_method.quant_config.rope_style()
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=self.max_position_embeddings,
+                base=self.rope_theta,
+                is_neox_style=is_neox_style,
+            )
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttention(self.num_heads, self.head_dim,
+                                       self.scaling)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.W_pack(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.postion_embedding != "ALIBI":
+            q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class BaiChuanDecoderLayer(nn.Module):
+
+    def __init__(self,
+                 config: BaiChuanConfig,
+                 position_embedding: str,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.self_attn = BaiChuanAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            position_embedding=position_embedding,
+            rope_theta=rope_theta,
+            max_position_embeddings=max_position_embeddings,
+            linear_method=linear_method,
+        )
+        self.mlp = BaiChuanMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            linear_method=linear_method,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class BaiChuanModel(nn.Module):
+
+    def __init__(self,
+                 config: BaiChuanConfig,
+                 position_embedding: str,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size,
+                                                   linear_method=linear_method)
+        self.layers = nn.ModuleList([
+            BaiChuanDecoderLayer(config, position_embedding, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class BaiChuanBaseForCausalLM(nn.Module):
+
+    def __init__(self,
+                 config,
+                 position_embedding: str,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = BaiChuanModel(config, position_embedding, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if name == "lm_head.weight":
+                # Unlike Baichuan, Baichuan2 normalizes the head weights. Refer to:
+                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
+                # Distinguish between Baichuan and Baichuan2 by checking the
+                # vocab size.
+                is_baichuan2 = self.config.vocab_size == 125696
+                if is_baichuan2:
+                    loaded_weight = torch.nn.functional.normalize(
+                        loaded_weight)
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
+    """Baichuan 13B and Baichuan2 7B/13B."""
+
+    def __init__(self,
+                 config,
+                 linear_method: Optional[LinearMethodBase] = None):
+        if config.hidden_size == 4096:  # baichuan2 7b
+            super().__init__(config, "ROPE", linear_method)
+        else:  # baichuan 13b, baichuan2 13b
+            super().__init__(config, "ALIBI", linear_method)
+
+
+class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
+    """Baichuan 7B."""
+
+    def __init__(self,
+                 config,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__(config, "ROPE", linear_method)

+ 340 - 0
aphrodite/modeling/models/bloom.py

@@ -0,0 +1,340 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The CacheFlow team.
+# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only BLOOM model compatible with HuggingFace weights."""
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import BloomConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
+class BloomAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: BloomConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.total_num_heads = config.n_head
+        self.head_dim = self.hidden_size // self.total_num_heads
+        assert self.head_dim * self.total_num_heads == self.hidden_size
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.dense = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+
+        # Create the alibi slopes and slice them.
+        tp_rank = get_tensor_model_parallel_rank()
+        head_start = tp_rank * self.num_heads
+        head_end = (tp_rank + 1) * self.num_heads
+        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+        alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+        scaling = self.head_dim**-0.5
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        del position_ids  # Unused.
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.dense(attn_output)
+        return output
+
+
+class BloomMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: BloomConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.dense_h_to_4h = ColumnParallelLinear(
+            hidden_size,
+            4 * hidden_size,
+            linear_method=linear_method,
+        )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.dense_4h_to_h = RowParallelLinear(
+            4 * hidden_size,
+            hidden_size,
+            linear_method=linear_method,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, _ = self.dense_h_to_4h(x)
+        x = self.gelu_impl(x)
+        x, _ = self.dense_4h_to_h(x)
+        return x
+
+
+class BloomBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: BloomConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+
+        self.input_layernorm = nn.LayerNorm(hidden_size,
+                                            eps=config.layer_norm_epsilon)
+        self.self_attention = BloomAttention(config, linear_method)
+        self.post_attention_layernorm = nn.LayerNorm(
+            hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = BloomMLP(config, linear_method)
+        self.apply_residual_connection_post_layernorm = (
+            config.apply_residual_connection_post_layernorm)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+
+        # Layer norm post the self attention.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = hidden_states
+
+        # Self attention.
+        attention_output = self.self_attention(
+            position_ids=position_ids,
+            hidden_states=layernorm_output,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        attention_output = attention_output + residual
+        layernorm_output = self.post_attention_layernorm(attention_output)
+
+        # Get residual
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = attention_output
+
+        # MLP.
+        output = self.mlp(layernorm_output) + residual
+        return output
+
+
+class BloomModel(nn.Module):
+
+    def __init__(
+        self,
+        config: BloomConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+
+        # Embedding + LN Embedding
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size, self.embed_dim, linear_method=linear_method)
+        self.word_embeddings_layernorm = nn.LayerNorm(
+            self.embed_dim, eps=config.layer_norm_epsilon)
+
+        # Transformer blocks
+        self.h = nn.ModuleList([
+            BloomBlock(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+
+        # Final Layer Norm
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.word_embeddings(input_ids)
+        hidden_states = self.word_embeddings_layernorm(hidden_states)
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(
+                position_ids,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+            )
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class BloomForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: BloomConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = BloomModel(config, linear_method)
+        # self.lm_head_weight = self.transformer.word_embeddings.weight
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if name == "lm_head.weight":
+                continue
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+            param = params_dict[name]
+
+            if "word_embeddings.weight" in name:
+                # Copy word embedding to lm_head
+                lm_head_param = params_dict["lm_head.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
+
+            if "query_key_value" in name:
+                # NOTE: BLOOM's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

+ 378 - 0
aphrodite/modeling/models/chatglm.py

@@ -0,0 +1,378 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/THUDM/ChatGLM2-6B
+"""Inference-only ChatGLM model compatible with THUDM weights."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import LayerNorm
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs import ChatGLMConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class GLMAttention(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.multi_query_attention = config.multi_query_attention
+        self.total_num_kv_heads = (config.multi_query_group_num
+                                   if config.multi_query_attention else
+                                   config.num_attention_heads)
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.add_bias_linear or config.add_qkv_bias,
+            linear_method=linear_method,
+        )
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=config.add_bias_linear,
+            linear_method=linear_method,
+        )
+
+        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
+        rope_ratio = getattr(config, "rope_ratio", 1.0)
+        max_positions = getattr(config, "seq_length", 8192)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim // 2,
+            max_position=max_positions,
+            base=10000 * rope_ratio,
+            is_neox_style=False,
+        )
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
+        key_cache, value_cache = kv_cache
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            key_cache,
+            value_cache,
+            input_metadata,
+        )
+        attn_output, _ = self.dense(context_layer)
+        return attn_output
+
+
+class GLMMLP(nn.Module):
+    """MLP.
+
+    MLP will take the input with h hidden state, project it to 4*h
+    hidden dimension, perform nonlinear transformation, and project the
+    state back into h hidden dimension.
+    """
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+
+        self.add_bias = config.add_bias_linear
+
+        # Project to 4h.
+        self.dense_h_to_4h = MergedColumnParallelLinear(
+            config.hidden_size,
+            [config.ffn_hidden_size] * 2,
+            bias=config.add_bias_linear,
+            linear_method=linear_method,
+        )
+
+        self.activation_func = SiluAndMul()
+
+        # Project back to h.
+        self.dense_4h_to_h = RowParallelLinear(
+            config.ffn_hidden_size,
+            config.hidden_size,
+            bias=config.add_bias_linear,
+            linear_method=linear_method,
+        )
+
+    def forward(self, hidden_states):
+        # [s, b, 4hp]
+        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
+        intermediate_parallel = self.activation_func(intermediate_parallel)
+        # [s, b, h]
+        output, _ = self.dense_4h_to_h(intermediate_parallel)
+        return output
+
+
+class GLMBlock(nn.Module):
+    """A single transformer layer.
+
+    Transformer layer takes input with size [s, b, h] and returns an
+    output of the same size.
+    """
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.apply_residual_connection_post_layernorm = (
+            config.apply_residual_connection_post_layernorm)
+
+        self.fp32_residual_connection = config.fp32_residual_connection
+
+        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+        # Layernorm on the input data.
+        self.input_layernorm = layer_norm_func(config.hidden_size,
+                                               eps=config.layernorm_epsilon)
+
+        # Self attention.
+        self.self_attention = GLMAttention(config, linear_method)
+        self.hidden_dropout = config.hidden_dropout
+
+        # Layernorm on the attention output
+        self.post_attention_layernorm = layer_norm_func(
+            config.hidden_size, eps=config.layernorm_epsilon)
+
+        # MLP
+        self.mlp = GLMMLP(config, linear_method)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        # hidden_states: [num_tokens, h]
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+        # Self attention.
+        attention_output = self.self_attention(
+            hidden_states=layernorm_output,
+            position_ids=position_ids,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+
+        # Residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = hidden_states
+
+        layernorm_input = residual + attention_output
+
+        # Layer norm post the self attention.
+        layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+        # Second residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = layernorm_input
+
+        output = self.mlp(layernorm_output) + residual
+
+        return output
+
+
+class GLMTransformer(nn.Module):
+    """Transformer class."""
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.post_layer_norm = config.post_layer_norm
+
+        # Number of layers.
+        self.num_layers = config.num_layers
+
+        # Transformer layers.
+        self.layers = nn.ModuleList(
+            [GLMBlock(config, linear_method) for i in range(self.num_layers)])
+
+        if self.post_layer_norm:
+            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+            # Final layer norm before output.
+            self.final_layernorm = layer_norm_func(
+                config.hidden_size, eps=config.layernorm_epsilon)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                position_ids=position_ids,
+                kv_cache=kv_caches[i],
+                input_metadata=input_metadata,
+            )
+        # Final layer norm.
+        if self.post_layer_norm:
+            hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states
+
+
+class ChatGLMModel(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+
+        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
+                                                config.hidden_size,
+                                                linear_method=linear_method)
+
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels
+        self.encoder = GLMTransformer(config, linear_method)
+
+        self.output_layer = ParallelLMHead(config.padded_vocab_size,
+                                           config.hidden_size,
+                                           linear_method=linear_method)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embedding(input_ids)
+
+        # Run encoder.
+        hidden_states = self.encoder(
+            hidden_states=inputs_embeds,
+            position_ids=position_ids,
+            kv_caches=kv_caches,
+            input_metadata=input_metadata,
+        )
+        return hidden_states
+
+
+class ChatGLMForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: ChatGLMConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config: ChatGLMConfig = config
+        self.linear_method = linear_method
+        self.transformer = ChatGLMModel(config, linear_method)
+        # self.lm_head_weight = self.transformer.output_layer.weight
+        self.sampler = Sampler(config.padded_vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(
+            self.transformer.output_layer(hidden_states), sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "rotary_pos_emb.inv_freq" in name:
+                continue
+            if "word_embeddings" in name:
+                name = name.replace(".word_embeddings", "")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

+ 6 - 2
aphrodite/modeling/models/decilm.py

@@ -29,6 +29,7 @@ from typing import Optional
 import torch
 from transformers import PretrainedConfig
 
+from aphrodite.common.config import LoRAConfig
 from aphrodite.modeling.layers.linear import LinearMethodBase
 from aphrodite.modeling.models.llama import LlamaForCausalLM
 from aphrodite.modeling.hf_downloader import (default_weight_loader,
@@ -41,7 +42,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
     Based on the llama executor.
 
     The main difference is that DeciLM uses Variable Grouped Query Attention.
-    The constant number of GQA heads in the decoder is overriden with a value
+    The constant number of GQA heads in the decoder is overridden with a value
     per layer.
 
     Usually, in the HuggingFace implementation, instead of
@@ -57,10 +58,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
         self,
         config: Optional[PretrainedConfig] = None,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
         delattr(config, "num_key_value_heads_per_layer")
-        super().__init__(config=config, linear_method=linear_method)
+        super().__init__(config=config,
+                         linear_method=linear_method,
+                         lora_config=lora_config)
 
     def load_weights(self,
                      model_name_or_path: str,

+ 23 - 72
aphrodite/modeling/models/deepseek.py

@@ -28,14 +28,16 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.megatron import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (
-    LinearMethodBase, MergedColumnParallelLinear, ReplicatedLinear,
-    QKVParallelLinear, RowParallelLinear, ColumnParallelLinear)
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              ReplicatedLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
@@ -63,23 +65,10 @@ class DeepseekMLP(nn.Module):
         reduce_results: bool = True,
     ) -> None:
         super().__init__()
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
-            self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(hidden_size,
-                                                  intermediate_size,
-                                                  bias=False,
-                                                  linear_method=linear_method)
-            self.up_proj = ColumnParallelLinear(hidden_size,
-                                                intermediate_size,
-                                                bias=False,
-                                                linear_method=linear_method)
-        else:
-            self.merge_weight = True
-            self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size, [intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
@@ -91,12 +80,7 @@ class DeepseekMLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        if self.merge_weight:
-            gate_up, _ = self.gate_up_proj(x)
-        else:
-            up, _ = self.up_proj(x)
-            gate, _ = self.gate_proj(x)
-            gate_up = torch.cat([gate, up], dim=-1)
+        gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -171,7 +155,6 @@ class DeepseekMoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
                                         self.w2,
@@ -224,31 +207,14 @@ class DeepseekAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(hidden_size,
-                                               self.q_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                hidden_size,
-                self.head_dim,
-                self.total_num_heads,
-                self.total_num_kv_heads,
-                bias=False,
-                linear_method=linear_method,
-            )
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
 
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -257,15 +223,12 @@ class DeepseekAttention(nn.Module):
             linear_method=linear_method,
         )
 
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
         )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
@@ -279,14 +242,8 @@ class DeepseekAttention(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
@@ -375,7 +332,6 @@ class DeepseekModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            linear_method=linear_method,
         )
         self.layers = nn.ModuleList([
             DeepseekDecoderLayer(config,
@@ -414,9 +370,7 @@ class DeepseekForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = DeepseekModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -452,16 +406,13 @@ class DeepseekForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
-            stacked_params_mapping = []
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
                 cache_dir,
                 load_format,
                 revision,
-                self.config,
                 fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue

+ 446 - 0
aphrodite/modeling/models/falcon.py

@@ -0,0 +1,446 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2023 the Falcon authors and HuggingFace Inc. team.  All rights
+# reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Falcon model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import LayerNorm
+from transformers import FalconConfig as HF_FalconConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.communication_op import (
+    tensor_model_parallel_all_reduce)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs import RWConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+FalconConfig = Union[HF_FalconConfig, RWConfig]
+
+
+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+                        dtype=torch.float32)
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32)
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1,
+                                    1 + 2 * num_remaining_heads,
+                                    2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+    return slopes
+
+
+class FalconAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.head_dim = self.hidden_size // self.total_num_heads
+        assert self.head_dim * self.total_num_heads == self.hidden_size
+
+        self.new_decoder_architecture = config.new_decoder_architecture
+        self.multi_query = config.multi_query
+
+        if self.new_decoder_architecture:
+            self.total_num_kv_heads = config.num_kv_heads
+        elif self.multi_query:
+            self.total_num_kv_heads = 1
+        else:
+            self.total_num_kv_heads = self.total_num_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.bias,
+            skip_bias_add=True,
+            linear_method=linear_method,
+        )
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+
+        # Layer-wise attention scaling
+        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self.reduce_row_parallel_results = not (config.new_decoder_architecture
+                                                or config.parallel_attn)
+        self.dense = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=config.bias,
+            skip_bias_add=True,
+            linear_method=linear_method,
+            reduce_results=self.reduce_row_parallel_results)
+
+        self.use_rotary = config.rotary
+        self.use_alibi = config.alibi
+        assert not (self.use_rotary and self.use_alibi), (
+            "Rotary and alibi are mutually exclusive.")
+
+        if self.use_rotary:
+            rope_theta = getattr(config, "rope_theta", 10000)
+            max_position_embeddings = getattr(config,
+                                              "max_position_embeddings", 8192)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+            )
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads)
+        elif self.use_alibi:
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
+                            self.inv_norm_factor)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads,
+                                       alibi_slopes=alibi_slopes)
+        else:
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       scale=self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, bias = self.query_key_value(hidden_states)
+        if bias is not None:
+            qkv += bias
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.use_rotary:
+            q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        attn_output, bias = self.dense(attn_output)
+        return attn_output, bias
+
+
+class FalconMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+
+        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
+                                                  4 * hidden_size,
+                                                  bias=config.bias,
+                                                  skip_bias_add=True,
+                                                  linear_method=linear_method)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.reduce_row_parallel_results = not (config.new_decoder_architecture
+                                                or config.parallel_attn)
+        self.dense_4h_to_h = RowParallelLinear(
+            4 * hidden_size,
+            hidden_size,
+            bias=config.bias,
+            skip_bias_add=True,
+            reduce_results=self.reduce_row_parallel_results,
+            linear_method=linear_method)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
+        x, bias = self.dense_h_to_4h(x)
+        if bias is not None:
+            x += bias
+        x = self.act(x)
+        x, bias = self.dense_4h_to_h(x)
+        return x, bias
+
+
+class FalconDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.self_attention = FalconAttention(config, linear_method)
+        self.mlp = FalconMLP(config, linear_method)
+        self.config = config
+
+        if config.new_decoder_architecture:
+            # The layer norm before self-attention
+            self.ln_attn = LayerNorm(hidden_size,
+                                     eps=config.layer_norm_epsilon)
+            # The layer norm before the MLP
+            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        else:
+            self.input_layernorm = LayerNorm(hidden_size,
+                                             eps=config.layer_norm_epsilon)
+            if not config.parallel_attn:
+                self.post_attention_layernorm = LayerNorm(
+                    hidden_size, eps=config.layer_norm_epsilon)
+
+        self.reduce_row_parallel_results = not (config.new_decoder_architecture
+                                                or config.parallel_attn)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        if self.config.new_decoder_architecture:
+            attention_layernorm_out = self.ln_attn(hidden_states)
+            mlp_layernorm_out = self.ln_mlp(hidden_states)
+        else:
+            attention_layernorm_out = self.input_layernorm(hidden_states)
+
+        # Self attention.
+        attention_output, attention_bias = self.self_attention(
+            positions=positions,
+            hidden_states=attention_layernorm_out,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        if self.reduce_row_parallel_results and attention_bias is not None:
+            attention_output += attention_bias
+
+        if not self.config.new_decoder_architecture:
+            if self.config.parallel_attn:
+                mlp_layernorm_out = attention_layernorm_out
+            else:
+                residual += attention_output
+                mlp_layernorm_out = self.post_attention_layernorm(residual)
+
+        # MLP.
+        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
+        if self.reduce_row_parallel_results and mlp_bias is not None:
+            mlp_output += mlp_bias
+
+        if not self.reduce_row_parallel_results:
+            # When MLP and Attention layers are parallel, we can use
+            # only one all-reduce operator to reduce the results from
+            # both MLP and Attention layers.
+            mlp_output += attention_output
+            mlp_output = tensor_model_parallel_all_reduce(mlp_output)
+            if attention_bias is not None:
+                mlp_output += attention_bias
+            if mlp_bias is not None:
+                mlp_output += mlp_bias
+
+        output = mlp_output + residual
+        return output
+
+
+class FalconModel(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.use_alibi = config.alibi
+
+        # Embedding + LN Embedding
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size, self.embed_dim, linear_method=linear_method)
+
+        # Transformer blocks
+        self.h = nn.ModuleList([
+            FalconDecoderLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+
+        # Final Layer Norm
+        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.word_embeddings(input_ids)
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+            )
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class FalconForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = FalconModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(
+            input_ids,
+            positions,
+            kv_caches,
+            input_metadata,
+        )
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        total_num_heads = self.config.num_attention_heads
+        if self.config.new_decoder_architecture:
+            total_num_kv_heads = self.config.num_kv_heads
+        elif self.config.multi_query:
+            total_num_kv_heads = 1
+        else:
+            total_num_kv_heads = total_num_heads
+        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            param = params_dict[name]
+            if "query_key_value" in name:
+                output_dim = getattr(param, "output_dim", None)
+                loaded_weight_shape = loaded_weight.shape
+                if output_dim is not None:
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] +
+                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
+                         -1) + loaded_weight_shape[output_dim + 1:])
+                    wq = loaded_weight.narrow(
+                        output_dim + 1, 0,
+                        num_query_heads_per_kv_head).reshape(
+                            *loaded_weight_shape[:output_dim], -1,
+                            *loaded_weight_shape[output_dim + 1:])
+                    wk = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    wv = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head + 1,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

+ 65 - 62
aphrodite/modeling/models/yi.py → aphrodite/modeling/models/gemma.py

@@ -1,14 +1,7 @@
 # coding=utf-8
-# Adapted from
-# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+# Copyright (c) Google Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,15 +14,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only Yi model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional, Tuple
+"""Inference-only Gemma model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
-from aphrodite.transformers_utils.configs.yi import YiConfig
+from transformers import GemmaConfig
 
 from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
@@ -51,13 +44,12 @@ from aphrodite.common.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class YiMLP(nn.Module):
+class GemmaMLP(nn.Module):
 
     def __init__(
         self,
         hidden_size: int,
         intermediate_size: int,
-        hidden_act: str,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
@@ -82,10 +74,7 @@ class YiMLP(nn.Module):
                                            hidden_size,
                                            bias=False,
                                            linear_method=linear_method)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
+        self.act_fn = GeluAndMul()
 
     def forward(self, x):
         if self.merge_weight:
@@ -99,18 +88,16 @@ class YiMLP(nn.Module):
         return x
 
 
-class YiAttention(nn.Module):
+class GemmaAttention(nn.Module):
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 head_dim: int,
+                 max_position_embeddings: int = 8192,
+                 rope_theta: float = 10000,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -127,12 +114,11 @@ class YiAttention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        self.head_dim = head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.max_position_embeddings = max_position_embeddings
 
         if linear_method is not None and not linear_method.quant_config.merge_weight(
         ):
@@ -165,15 +151,12 @@ class YiAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             base=self.rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
+            is_neox_style=True,
         )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
@@ -202,36 +185,33 @@ class YiAttention(nn.Module):
         return output
 
 
-class YiDecoderLayer(nn.Module):
+class GemmaDecoderLayer(nn.Module):
 
     def __init__(
         self,
-        config: YiConfig,
+        config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        self.self_attn = YiAttention(
+        self.self_attn = GemmaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
+            head_dim=config.head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            rope_theta=config.rope_theta,
             linear_method=linear_method,
         )
-        self.mlp = YiMLP(
+        self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
-            hidden_act=config.hidden_act,
             linear_method=linear_method,
         )
-        self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -244,9 +224,10 @@ class YiDecoderLayer(nn.Module):
         # Self Attention
         if residual is None:
             residual = hidden_states
-            hidden_states = self.ln1(hidden_states)
+            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.ln1(hidden_states, residual)
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -255,27 +236,27 @@ class YiDecoderLayer(nn.Module):
         )
 
         # Fully Connected
-        hidden_states, residual = self.ln2(hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
 
-class YiModel(nn.Module):
+class GemmaModel(nn.Module):
 
     def __init__(
         self,
-        config: YiConfig,
+        config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
+
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size,
                                                    linear_method=linear_method)
         self.layers = nn.ModuleList([
-            YiDecoderLayer(config, linear_method)
+            GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -288,6 +269,9 @@ class YiModel(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        # Normalize the embedding by sqrt(hidden_size)
+        hidden_states *= self.config.hidden_size**0.5
+
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -302,22 +286,23 @@ class YiModel(nn.Module):
         return hidden_states
 
 
-class YiForCausalLM(nn.Module):
+class GemmaForCausalLM(nn.Module):
 
     def __init__(
         self,
-        config: YiConfig,
+        config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = YiModel(config, linear_method)
+        self.model = GemmaModel(config, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       linear_method=linear_method)
         self.sampler = Sampler(config.vocab_size)
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -355,11 +340,19 @@ class YiForCausalLM(nn.Module):
         ):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
+        loaded_params = set()
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision,
                 self.config):
             if "rotary_emb.inv_freq" in name:
                 continue
+            if "embed_tokens.weight" in name:
+                # Copy word embedding to lm_head
+                loaded_params.add("lm_head.weight")
+                lm_head_param = params_dict["lm_head.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -372,10 +365,20 @@ class YiForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                # Skip loading extra layer for lora models.
+                if "lm_head" in name:
                     continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")

+ 286 - 0
aphrodite/modeling/models/gpt2.py

@@ -0,0 +1,286 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GPT-2 model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GPT2Config
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class GPT2Attention(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        assert total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scale=self.scale)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        key_cache, value_cache = kv_cache
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
+                                input_metadata)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPT2MLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: GPT2Config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPT2Block(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPT2Attention(config, linear_method)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPT2MLP(inner_dim, config, linear_method)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+class GPT2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+        assert not config.scale_attn_by_inverse_layer_idx
+        assert not config.reorder_and_upcast_attn
+        self.embed_dim = config.hidden_size
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          linear_method=linear_method)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.h = nn.ModuleList([
+            GPT2Block(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPT2LMHeadModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = GPT2Model(config, linear_method)
+        # self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "lm_head.weight" in name:
+                # GPT-2 ties the weights of the embedding layer and the final
+                # linear layer.
+                continue
+            if "wte.weight" in name:
+                # Copy word embedding to lm_head
+                lm_head_param = params_dict["lm_head.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

+ 292 - 0
aphrodite/modeling/models/gpt_bigcode.py

@@ -0,0 +1,292 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2023 CTranslate2, and Michael Feil
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GPTBigCodeConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class GPTBigCodeAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        self.tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        assert total_num_heads % self.tensor_model_parallel_world_size == 0
+        self.num_heads = (total_num_heads //
+                          self.tensor_model_parallel_world_size)
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.multi_query = config.multi_query
+        if self.multi_query:
+            total_num_kv_heads = 1
+            self.num_kv_heads = 1
+        else:
+            total_num_kv_heads = total_num_heads
+            self.num_kv_heads = self.num_heads
+        self.kv_dim = self.head_dim * self.num_kv_heads
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            total_num_kv_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
+
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scale=self.scale,
+                                   num_kv_heads=self.num_kv_heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.split(
+            [
+                self.hidden_size // self.tensor_model_parallel_world_size,
+                self.kv_dim, self.kv_dim
+            ],
+            dim=-1,
+        )
+        key_cache, value_cache = kv_cache
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
+                                input_metadata)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPTBigMLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: GPTBigCodeConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPTBigCodeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPTBigCodeAttention(config, linear_method)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPTBigMLP(inner_dim, config, linear_method)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+class GPTBigCodeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+
+        self.embed_dim = config.hidden_size
+
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          linear_method=linear_method)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.h = nn.ModuleList([
+            GPTBigCodeBlock(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPTBigCodeForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = GPTBigCodeModel(config, linear_method)
+        # self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "lm_head.weight" in name:
+                continue
+            if "wte.weight" in name:
+                # Copy word embedding to lm_head
+                lm_head_param = params_dict["lm_head.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
+            if ".attn.bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

+ 8 - 13
aphrodite/modeling/models/gpt_j.py

@@ -166,8 +166,7 @@ class GPTJBlock(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        inner_dim = (4 * config.n_embd
-                     if config.n_inner is None else config.n_inner)
+        inner_dim = 4 * config.n_embd if config.n_inner is None else config.n_inner
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.attn = GPTJAttention(config, linear_method)
         self.mlp = GPTJMLP(inner_dim, config, linear_method)
@@ -202,11 +201,9 @@ class GPTJModel(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.n_embd
-        self.wte = VocabParallelEmbedding(
-            config.vocab_size,
-            self.embed_dim,
-            linear_method=linear_method,
-        )
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          linear_method=linear_method)
         self.h = nn.ModuleList(
             [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -243,12 +240,10 @@ class GPTJForCausalLM(nn.Module):
         self.linear_method = linear_method
         assert not config.tie_word_embeddings
         self.transformer = GPTJModel(config, linear_method)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.n_embd,
-            bias=True,
-            linear_method=linear_method,
-        )
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.n_embd,
+                                      bias=True,
+                                      linear_method=linear_method)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(

+ 6 - 10
aphrodite/modeling/models/gpt_neox.py

@@ -196,11 +196,9 @@ class GPTNeoXModel(nn.Module):
         super().__init__()
         self.config = config
 
-        self.embed_in = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            linear_method=linear_method,
-        )
+        self.embed_in = VocabParallelEmbedding(config.vocab_size,
+                                               config.hidden_size,
+                                               linear_method=linear_method)
         self.layers = nn.ModuleList([
             GPTNeoXLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
@@ -239,11 +237,9 @@ class GPTNeoXForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.gpt_neox = GPTNeoXModel(config, linear_method)
-        self.embed_out = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            linear_method=linear_method,
-        )
+        self.embed_out = ParallelLMHead(config.vocab_size,
+                                        config.hidden_size,
+                                        linear_method=linear_method)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(

+ 352 - 0
aphrodite/modeling/models/internlm2.py

@@ -0,0 +1,352 @@
+# -*- coding: utf-8 -*-
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              ColumnParallelLinear,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class InternLM2MLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.w1 = ColumnParallelLinear(hidden_size,
+                                           intermediate_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+            self.w3 = ColumnParallelLinear(hidden_size,
+                                           intermediate_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
+        self.w2 = RowParallelLinear(intermediate_size,
+                                    hidden_size,
+                                    bias=False,
+                                    linear_method=linear_method)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.up_proj(x)
+            gate, _ = self.gate_proj(x)
+            gate_up = torch.cat([gate, up], dim=-1)
+        x = self.act_fn(gate_up)
+        x, _ = self.w2(x)
+        return x
+
+
+class InternLM2Attention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.wqkv = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.wo = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.wqkv(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.wo(attn_output)
+        return output
+
+
+class InternLMDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.attention = InternLM2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            linear_method=linear_method,
+        )
+        self.feed_forward = InternLM2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            linear_method=linear_method,
+        )
+        self.attention_norm = RMSNorm(config.hidden_size,
+                                      eps=config.rms_norm_eps)
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.attention_norm(hidden_states)
+        else:
+            hidden_states, residual = self.attention_norm(
+                hidden_states, residual)
+        hidden_states = self.attention(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.ffn_norm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class InternLM2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.tok_embeddings = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            linear_method=linear_method,
+        )
+        self.layers = nn.ModuleList([
+            InternLMDecoderLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.tok_embeddings(input_ids)
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class InternLM2ForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = InternLM2Model(config, linear_method)
+        self.output = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            linear_method=linear_method,
+        )
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.output(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "w1", 0),
+            ("gate_up_proj", "w3", 1),
+        ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                if "wqkv" in name:
+                    config = self.config
+                    kv_groups = config.num_attention_heads // config.num_key_value_heads
+                    head_dim = config.hidden_size // config.num_attention_heads
+                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
+                                                       head_dim,
+                                                       loaded_weight.shape[-1])
+                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
+                                             dim=1)
+                    wq = wq.reshape(-1, wq.shape[-1])
+                    wk = wk.reshape(-1, wk.shape[-1])
+                    wv = wv.reshape(-1, wv.shape[-1])
+                    weight_loader = param.weight_loader
+                    weight_loader(param, wq, 'q')
+                    weight_loader(param, wk, 'k')
+                    weight_loader(param, wv, 'v')
+                else:
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)

+ 45 - 14
aphrodite/modeling/models/llama.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -111,6 +110,8 @@ class LlamaAttention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         linear_method: Optional[LinearMethodBase] = None,
+        bias: bool = False,
+        sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -140,15 +141,15 @@ class LlamaAttention(nn.Module):
             self.merge_weight = False
             self.q_proj = ColumnParallelLinear(hidden_size,
                                                self.q_size,
-                                               bias=False,
+                                               bias=bias,
                                                linear_method=linear_method)
             self.k_proj = ColumnParallelLinear(hidden_size,
                                                self.kv_size,
-                                               bias=False,
+                                               bias=bias,
                                                linear_method=linear_method)
             self.v_proj = ColumnParallelLinear(hidden_size,
                                                self.kv_size,
-                                               bias=False,
+                                               bias=bias,
                                                linear_method=linear_method)
         else:
             self.merge_weight = True
@@ -157,13 +158,13 @@ class LlamaAttention(nn.Module):
                 self.head_dim,
                 self.total_num_heads,
                 self.total_num_kv_heads,
-                bias=False,
+                bias=bias,
                 linear_method=linear_method,
             )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
-            bias=False,
+            bias=bias,
             linear_method=linear_method,
         )
 
@@ -180,7 +181,8 @@ class LlamaAttention(nn.Module):
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    self.scaling,
-                                   num_kv_heads=self.num_kv_heads)
+                                   num_kv_heads=self.num_kv_heads,
+                                   sliding_window=sliding_window)
 
     def forward(
         self,
@@ -217,14 +219,18 @@ class LlamaDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
+        sliding_window = getattr(config, "sliding_window", None)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
+            num_kv_heads=getattr(config, "num_key_value_heads",
+                                 config.num_attention_heads),
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
+            bias=getattr(config, "bias", False),
+            sliding_window=sliding_window,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
@@ -316,7 +322,32 @@ class LlamaModel(nn.Module):
 
 
 class LlamaForCausalLM(nn.Module):
-    supports_lora = True
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
 
     def __init__(
         self,
@@ -328,20 +359,20 @@ class LlamaForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = LlamaModel(config, linear_method, lora_config=lora_config)
-        unpadded_vocab_size = config.vocab_size
+        self.unpadded_vocab_size = config.vocab_size
         if lora_config:
-            unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.lm_head = ParallelLMHead(
-            unpadded_vocab_size,
+            self.unpadded_vocab_size,
             config.hidden_size,
-            linear_method=linear_method,
             org_num_embeddings=config.vocab_size,
+            linear_method=linear_method,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE
             # We need bigger padding if using lora for kernel
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
+        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
 
     def forward(
         self,

+ 62 - 18
aphrodite/modeling/models/mixtral.py

@@ -25,23 +25,23 @@
 from typing import List, Optional, Tuple
 
 import torch
-
 from torch import nn
 from transformers import MixtralConfig
 
+from aphrodite.common.config import LoRAConfig
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              ReplicatedLinear,
                                               QKVParallelLinear,
+                                              ReplicatedLinear,
                                               RowParallelLinear,
                                               ColumnParallelLinear)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
 from aphrodite.modeling.megatron.communication_op import (
     tensor_model_parallel_all_reduce)
 from aphrodite.modeling.megatron.parallel_state import (
@@ -58,6 +58,7 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
 class MixtralMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
     across all ranks.
+
     Each expert's weights are sharded across all ranks and a fused MoE
     kernel is used for the forward pass, and finally we reduce the outputs
     across ranks.
@@ -70,13 +71,14 @@ class MixtralMoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__()
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
         self.num_total_experts = num_experts
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // tp_size
+        self.intermediate_size = intermediate_size // self.tp_size
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -127,7 +129,6 @@ class MixtralMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
         final_hidden_states = fused_moe(hidden_states,
                                         self.ws,
                                         self.w2s,
@@ -136,8 +137,9 @@ class MixtralMoE(nn.Module):
                                         renormalize=True,
                                         inplace=True)
 
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
         return final_hidden_states.view(batch_size, sequence_length,
                                         hidden_size)
@@ -310,15 +312,20 @@ class MixtralModel(nn.Module):
         self,
         config: MixtralConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
+            self.vocab_size,
             config.hidden_size,
             linear_method=linear_method,
+            org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
             MixtralDecoderLayer(config, linear_method=linear_method)
@@ -345,20 +352,53 @@ class MixtralModel(nn.Module):
 
 
 class MixtralForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
 
     def __init__(
         self,
         config: MixtralConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = MixtralModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.sampler = Sampler(config.vocab_size)
+        self.model = MixtralModel(config,
+                                  linear_method,
+                                  lora_config=lora_config)
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            linear_method=linear_method,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
 
     def forward(
         self,
@@ -391,6 +431,10 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+
         expert_params_mapping = [
             # (param_name, weight_name, expert_id)
             ("ws" if weight_name in ["w1", "w3"] else "w2s",
@@ -398,18 +442,18 @@ class MixtralForCausalLM(nn.Module):
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
-            stacked_params_mapping = []
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
                 cache_dir,
                 load_format,
                 revision,
+                self.config,
                 fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue

+ 7 - 6
aphrodite/modeling/models/mixtral_quant.py

@@ -226,14 +226,12 @@ class MixtralAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
             base=int(self.rope_theta),
-            is_neox_style=is_neox_style,
+            is_neox_style=True,
         )
         self.attn = PagedAttention(
             self.num_heads,
@@ -371,9 +369,11 @@ class MixtralForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = MixtralModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            linear_method=linear_method,
+        )
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -410,6 +410,7 @@ class MixtralForCausalLM(nn.Module):
         if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
         ):
             stacked_params_mapping = []
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,

+ 307 - 0
aphrodite/modeling/models/mpt.py

@@ -0,0 +1,307 @@
+# coding=utf-8
+# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs.mpt import MPTConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _get_alibi_slopes(
+    total_num_heads: int,
+    alibi_bias_max: int,
+) -> torch.Tensor:
+    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
+    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
+    m = m.mul(alibi_bias_max / next_power_of_2)
+    slopes = 1.0 / torch.pow(2, m)
+    if next_power_of_2 != total_num_heads:
+        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
+    return slopes
+
+
+class MPTAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: MPTConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.total_num_heads = config.n_heads
+        self.head_dim = self.d_model // self.total_num_heads
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        if "kv_n_heads" in config.attn_config:
+            self.total_num_kv_heads = config.attn_config['kv_n_heads']
+        else:
+            self.total_num_kv_heads = self.total_num_heads
+        assert not config.attn_config["prefix_lm"]
+        assert config.attn_config["alibi"]
+
+        # pylint: disable=invalid-name
+        self.Wqkv = QKVParallelLinear(
+            self.d_model,
+            self.d_model // self.total_num_heads,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=not config.no_bias,
+            linear_method=linear_method,
+        )
+        if self.qk_ln:
+            self.q_ln = nn.LayerNorm(self.d_model)
+            self.k_ln = nn.LayerNorm(self.d_model)
+        self.out_proj = RowParallelLinear(
+            self.d_model,
+            self.d_model,
+            bias=not config.no_bias,
+            linear_method=linear_method,
+        )
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        # Create the alibi slopes and slice them.
+        tp_rank = get_tensor_model_parallel_rank()
+        head_start = tp_rank * self.num_heads
+        head_end = (tp_rank + 1) * self.num_heads
+        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
+                                         self.alibi_bias_max)
+        alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+        self.head_dim = self.d_model // self.total_num_heads
+        scaling = self.head_dim**-0.5
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes,
+                                   num_kv_heads=self.num_kv_heads)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        del position_ids  # unused.
+        qkv, _ = self.Wqkv(hidden_states)
+        if self.clip_qkv is not None:
+            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.qk_ln:
+            q = self.q_ln(q)
+            k = self.k_ln(k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.out_proj(attn_output)
+        return output
+
+
+class MPTMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: MPTConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.d_model
+        expansion_ratio = config.expansion_ratio
+        intermediate_size = expansion_ratio * hidden_size
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=not config.no_bias,
+            linear_method=linear_method,
+        )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, intermediate_size)
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=not config.no_bias,
+            linear_method=linear_method,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, _ = self.up_proj(x)
+        x = self.act(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class MPTBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: MPTConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.d_model
+        self.norm_1 = nn.LayerNorm(hidden_size)
+        self.attn = MPTAttention(config, linear_method)
+        self.norm_2 = nn.LayerNorm(hidden_size)
+        self.ffn = MPTMLP(config, linear_method)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        x = self.norm_1(hidden_states)
+        x = self.attn(
+            position_ids=position_ids,
+            hidden_states=x,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        hidden_states = hidden_states + x
+        x = self.norm_2(hidden_states)
+        x = self.ffn(x)
+        hidden_states = hidden_states + x
+        return hidden_states
+
+
+class MPTModel(nn.Module):
+
+    def __init__(
+        self,
+        config: MPTConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        assert config.embedding_fraction == 1.0
+        assert config.norm_type == "low_precision_layernorm"
+
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          config.d_model,
+                                          linear_method=linear_method)
+        self.blocks = nn.ModuleList(
+            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
+        self.norm_f = nn.LayerNorm(config.d_model)
+        if config.no_bias:
+            for module in self.modules():
+                if hasattr(module, "bias") and isinstance(
+                        module.bias, nn.Parameter):
+                    # Remove the bias term in Linear and LayerNorm.
+                    module.register_parameter("bias", None)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.wte(input_ids)
+        for i in range(len(self.blocks)):
+            block = self.blocks[i]
+            hidden_states = block(
+                position_ids,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+            )
+        hidden_states = self.norm_f(hidden_states)
+        return hidden_states
+
+
+class MPTForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: MPTConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        assert config.tie_word_embeddings
+        self.linear_method = linear_method
+
+        self.transformer = MPTModel(config, linear_method)
+        # self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if "wte.weight" in name:
+                # Copy word embedding to lm_head
+                lm_head_param = params_dict["lm_head.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

+ 377 - 0
aphrodite/modeling/models/olmo.py

@@ -0,0 +1,377 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
+# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+#
+# BSD 3-Clause License
+#
+# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Inference-only OLMo model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (
+    ColumnParallelLinear,
+    LinearMethodBase,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size, )
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs.olmo import OLMoConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class SwiGLU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        return F.silu(gate) * x
+
+    @property
+    def output_multiplier(self) -> float:
+        return 0.5
+
+
+class OlmoAttention(nn.Module):
+    """
+    This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(
+        self,
+        config: OLMoConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.d_model
+        assert config.d_model % config.n_heads == 0
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
+        )
+        self.total_num_heads = self.config.n_heads
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // self.total_num_heads
+
+        # Layer norms.
+        self.attn_norm = nn.LayerNorm(config.d_model,
+                                      elementwise_affine=False,
+                                      bias=False)
+        # Attention input projection. Projects x -> (q, k, v)
+        self.att_proj = QKVParallelLinear(
+            config.d_model,
+            self.head_dim,
+            self.total_num_heads,
+            bias=config.include_bias,
+            linear_method=linear_method,
+        )
+
+        # Rotary embeddings.
+        if self.config.rope:
+            rope_theta = getattr(config, "rope_theta", 10000)
+            max_position_embeddings = getattr(config,
+                                              "max_position_embeddings", 8192)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+            )
+        self.scaling = self.head_dim**-0.5
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scale=self.scaling)
+
+        # Attention output projection.
+        self.attn_out = RowParallelLinear(
+            config.d_model,
+            config.d_model,
+            bias=config.include_bias,
+            linear_method=linear_method,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.attn_norm(hidden_states)
+        qkv, _ = self.att_proj(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.config.rope:
+            q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.attn_out(attn_output)
+        return output
+
+
+class OlmoMLP(nn.Module):
+    """
+    This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(
+        self,
+        config: OLMoConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
+                            is not None else config.mlp_ratio * config.d_model)
+
+        # Layer norms.
+        self.ff_norm = nn.LayerNorm(config.d_model,
+                                    elementwise_affine=False,
+                                    bias=False)
+
+        # Feed-forward input projection.
+        self.ff_proj = ColumnParallelLinear(
+            config.d_model,
+            self.hidden_size,
+            bias=config.include_bias,
+            linear_method=linear_method,
+        )
+
+        # Activation function.
+        # self.act = SiluAndMul()
+        # self.act.output_multiplier = 0.5
+        self.act = SwiGLU()
+        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+
+        # Feed-forward output projection.
+        self.ff_out = RowParallelLinear(
+            int(self.act.output_multiplier * self.hidden_size),
+            config.d_model,
+            bias=config.include_bias,
+            linear_method=linear_method,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        # Add feed-forward projection.
+        # shape: (batch_size, seq_len, d_model)
+        og_x = x
+        x = self.ff_norm(x)
+        x, _ = self.ff_proj(x)
+        x = self.act(x)
+        x, _ = self.ff_out(x)
+        x = og_x + x
+
+        return x
+
+
+class OlmoBlock(nn.Module):
+    """
+    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(self,
+                 config: OLMoConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        # Attention block.
+        self.attn = OlmoAttention(config, linear_method)
+
+        # MLP block.
+        self.mlp = OlmoMLP(config, linear_method)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        # Attention block.
+        og_x = hidden_states
+        x = self.attn(positions, hidden_states, kv_cache, input_metadata)
+        x = x + og_x
+
+        # MLP block.
+        hidden_states = self.mlp(x)
+        return hidden_states
+
+
+class OlmoModel(nn.Module):
+
+    def __init__(self,
+                 config: OLMoConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+
+        self.transformer = nn.ModuleDict(
+            dict(wte=VocabParallelEmbedding(
+                config.embedding_size or config.vocab_size,
+                config.d_model,
+                linear_method=linear_method,
+            ),
+                 ln_f=nn.LayerNorm(config.d_model,
+                                   elementwise_affine=False,
+                                   bias=False),
+                 ff_out=ParallelLMHead(
+                     config.embedding_size or config.vocab_size,
+                     config.d_model,
+                     bias=config.include_bias,
+                     linear_method=linear_method,
+                 )))
+
+        blocks = [
+            OlmoBlock(config, linear_method) for i in range(config.n_layers)
+        ]
+        if self.config.block_group_size > 1:
+            raise NotImplementedError("Block group size > 1 not supported yet")
+        else:
+            self.transformer.update({"blocks": nn.ModuleList(blocks)})
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """
+        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+        """
+        # Get embeddings of input.
+        # shape: (batch_size, seq_len, d_model)
+        x = self.transformer.wte(input_ids)  # type: ignore
+
+        # Apply blocks one-by-one.
+        for block_idx, block in enumerate(self.transformer.blocks):
+            # shape: (batch_size, seq_len, d_model)
+            x = block(
+                positions,
+                x,
+                kv_caches[block_idx],
+                input_metadata,
+            )
+
+        # Apply final layer norm.
+        # shape: (batch_size, seq_len or 1, d_model)
+        x = self.transformer.ln_f(x)  # type: ignore
+        return x
+
+
+class OLMoForCausalLM(nn.Module):
+    """
+    Extremely barebones HF model wrapper.
+    """
+
+    def __init__(self,
+                 config: OLMoConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = OlmoModel(config, linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            input_metadata=input_metadata,
+        )
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(
+            self.model.transformer.ff_out(hidden_states), sampling_metadata)
+        return next_tokens
+
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision):
+            if "wte.weight" in name and self.config.weight_tying:
+                # Copy word embedding to lm_head
+                lm_head_param = params_dict["model.transformer.ff_out.weight"]
+                weight_loader = getattr(lm_head_param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(lm_head_param, loaded_weight)
+            # attention
+            if ".att" in name:
+                name = name.replace(".att", ".attn.att")
+            # mlp
+            if ".ff" in name and "transformer.ff_out" not in name:
+                name = name.replace(".ff", ".mlp.ff")
+            # there is no bias in olmo
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

+ 0 - 1
aphrodite/modeling/models/phi.py

@@ -287,7 +287,6 @@ class PhiForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        head = self.lm_head  # pylint: disable=unused-variable
         next_tokens = self.sampler(self.lm_head(hidden_states),
                                    sampling_metadata)
         return next_tokens

+ 315 - 0
aphrodite/modeling/models/qwen.py

@@ -0,0 +1,315 @@
+# coding=utf-8
+# Adapted from
+# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
+# Copyright (c) Alibaba Cloud.
+# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
+"""Inference-only QWen model compatible with HuggingFace weights."""
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear,
+                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs.qwen import QWenConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class QWenMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str = "silu",
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.w2 = ColumnParallelLinear(hidden_size,
+                                           intermediate_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+            self.w1 = ColumnParallelLinear(hidden_size,
+                                           intermediate_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
+        self.c_proj = RowParallelLinear(intermediate_size,
+                                        hidden_size,
+                                        bias=False,
+                                        linear_method=linear_method)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.w1(x)
+            gate, _ = self.w2(x)
+            gate_up = torch.cat([gate, up], dim=-1)
+        x = self.act_fn(gate_up)
+        x, _ = self.c_proj(x)
+        return x
+
+
+class QWenAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        max_position_embeddings: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
+        )
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.c_attn = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.c_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.scaling = self.head_dim**-0.5
+
+        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+        ) is None else linear_method.quant_config.rope_style()
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+
+        output, _ = self.c_proj(attn_output)
+        return output
+
+
+class QWenBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: QWenConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        self.attn = QWenAttention(config.hidden_size,
+                                  config.num_attention_heads,
+                                  config.max_position_embeddings,
+                                  rope_theta=rope_theta,
+                                  rope_scaling=rope_scaling,
+                                  linear_method=linear_method)
+
+        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+        self.mlp = QWenMLP(config.hidden_size,
+                           config.intermediate_size // 2,
+                           linear_method=linear_method)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.ln_1(hidden_states)
+        else:
+            hidden_states, residual = self.ln_1(hidden_states, residual)
+        hidden_states = self.attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.ln_2(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class QWenModel(nn.Module):
+
+    def __init__(
+        self,
+        config: QWenConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          config.hidden_size,
+                                          linear_method=linear_method)
+        self.h = nn.ModuleList([
+            QWenBlock(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.wte(input_ids)
+        residual = None
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.ln_f(hidden_states, residual)
+        return hidden_states
+
+
+class QWenLMHeadModel(nn.Module):
+
+    def __init__(
+        self,
+        config: QWenConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = QWenModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "w2", 0),
+            ("gate_up_proj", "w1", 1),
+        ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

+ 35 - 52
aphrodite/modeling/models/mistral.py → aphrodite/modeling/models/qwen2.py

@@ -1,7 +1,8 @@
 # coding=utf-8
 # Adapted from
-# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
 # Copyright 2023 The PygmalionAI team.
+# Copyright 2024 The Qwen team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -21,38 +22,37 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only Mistral model compatible with HuggingFace weights."""
+"""Inference-only Qwen2 model compatible with HuggingFace weights."""
 from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
-from transformers import MistralConfig
+from transformers import Qwen2Config
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
-                                              RowParallelLinear,
-                                              ColumnParallelLinear)
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
+    VocabParallelEmbedding, ParallelLMHead)
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_world_size)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                               hf_model_weights_iterator)
 from aphrodite.common.sequence import SamplerOutput
-from aphrodite.common.config import LoRAConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class MistralMLP(nn.Module):
+class Qwen2MLP(nn.Module):
 
     def __init__(
         self,
@@ -100,7 +100,7 @@ class MistralMLP(nn.Module):
         return x
 
 
-class MistralAttention(nn.Module):
+class Qwen2Attention(nn.Module):
 
     def __init__(self,
                  hidden_size: int,
@@ -108,6 +108,7 @@ class MistralAttention(nn.Module):
                  num_kv_heads: int,
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
+                 use_sliding_window: bool = False,
                  linear_method: Optional[LinearMethodBase] = None,
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
@@ -131,22 +132,22 @@ class MistralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
+        self.sliding_window = sliding_window if use_sliding_window else None
 
         if linear_method is not None and not linear_method.quant_config.merge_weight(
         ):
             self.merge_weight = False
             self.q_proj = ColumnParallelLinear(hidden_size,
                                                self.q_size,
-                                               bias=False,
+                                               bias=True,
                                                linear_method=linear_method)
             self.k_proj = ColumnParallelLinear(hidden_size,
                                                self.kv_size,
-                                               bias=False,
+                                               bias=True,
                                                linear_method=linear_method)
             self.v_proj = ColumnParallelLinear(hidden_size,
                                                self.kv_size,
-                                               bias=False,
+                                               bias=True,
                                                linear_method=linear_method)
         else:
             self.merge_weight = True
@@ -155,7 +156,7 @@ class MistralAttention(nn.Module):
                 self.head_dim,
                 self.total_num_heads,
                 self.total_num_kv_heads,
-                bias=False,
+                bias=True,
                 linear_method=linear_method,
             )
         self.o_proj = RowParallelLinear(
@@ -165,14 +166,11 @@ class MistralAttention(nn.Module):
             linear_method=linear_method,
         )
 
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
             base=self.rope_theta,
-            is_neox_style=is_neox_style,
         )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
@@ -202,26 +200,29 @@ class MistralAttention(nn.Module):
         return output
 
 
-class MistralDecoderLayer(nn.Module):
+class Qwen2DecoderLayer(nn.Module):
 
     def __init__(
         self,
-        config: MistralConfig,
+        config: Qwen2Config,
+        layer_idx: int,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
-        rope_theta = getattr(config, "rope_theta", 10000)
-        self.self_attn = MistralAttention(
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        use_sliding_window = config.use_sliding_window and layer_idx < config.max_window_layers
+        self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            use_sliding_window=use_sliding_window,
             linear_method=linear_method,
             sliding_window=config.sliding_window)
-        self.mlp = MistralMLP(
+        self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
@@ -261,31 +262,26 @@ class MistralDecoderLayer(nn.Module):
         return hidden_states, residual
 
 
-class MistralModel(nn.Module):
+class Qwen2Model(nn.Module):
 
     def __init__(
         self,
-        config: MistralConfig,
+        config: Qwen2Config,
         linear_method: Optional[LinearMethodBase] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
-        lora_vocab = (lora_config.lora_extra_vocab_size *
-                      (lora_config.max_loras or 1)) if lora_config else 0
-        self.vocab_size = config.vocab_size + lora_vocab
-        self.org_vocab_size = config.vocab_size
+        self.vocab_size = config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
-            self.vocab_size,
+            config.vocab_size,
             config.hidden_size,
             linear_method=linear_method,
-            org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            MistralDecoderLayer(config, linear_method)
-            for _ in range(config.num_hidden_layers)
+            Qwen2DecoderLayer(config, layer_idx, linear_method)
+            for layer_idx in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -311,35 +307,23 @@ class MistralModel(nn.Module):
         return hidden_states
 
 
-class MistralForCausalLM(nn.Module):
-    supports_lora = True
+class Qwen2ForCausalLM(nn.Module):
 
     def __init__(
         self,
-        config: MistralConfig,
+        config: Qwen2Config,
         linear_method: Optional[LinearMethodBase] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = MistralModel(config,
-                                  linear_method,
-                                  lora_config=lora_config)
-        unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.model = Qwen2Model(config, linear_method)
         self.lm_head = ParallelLMHead(
-            unpadded_vocab_size,
+            config.vocab_size,
             config.hidden_size,
             linear_method=linear_method,
-            org_num_embeddings=config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE
-            # We need bigger padding if using lora for kernel
-            # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
+        self.sampler = Sampler(config.vocab_size)
 
     def forward(
         self,
@@ -379,8 +363,7 @@ class MistralForCausalLM(nn.Module):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:

+ 352 - 0
aphrodite/modeling/models/stablelm.py

@@ -0,0 +1,352 @@
+# coding=utf-8
+# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This code is based off the following work:
+# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
+# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
+"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear,
+                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class StablelmMLP(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.gate_proj = ColumnParallelLinear(config.hidden_size,
+                                                  config.intermediate_size,
+                                                  bias=False,
+                                                  linear_method=linear_method)
+            self.up_proj = ColumnParallelLinear(config.hidden_size,
+                                                config.intermediate_size,
+                                                bias=False,
+                                                linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                config.hidden_size, [config.intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
+        self.down_proj = RowParallelLinear(config.intermediate_size,
+                                           config.hidden_size,
+                                           bias=False)
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.up_proj(x)
+            gate, _ = self.gate_proj(x)
+            gate_up = torch.cat([gate, up], dim=-1)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class StablelmAttention(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        self.num_heads = self.total_num_heads // tp_size
+
+        self.total_num_key_value_heads = config.num_key_value_heads
+        if self.total_num_key_value_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_key_value_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_key_value_heads == 0
+        self.num_key_value_heads = max(
+            1, self.total_num_key_value_heads // tp_size)
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
+        self.scaling = self.head_dim**-0.5
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_key_value_heads * self.head_dim
+        self.qkv_bias = getattr(config, "use_qkv_bias", False)
+        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads}).")
+
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.q_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.q_size,
+                                               bias=self.qkv_bias,
+                                               linear_method=linear_method)
+            self.k_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.kv_size,
+                                               bias=self.qkv_bias,
+                                               linear_method=linear_method)
+            self.v_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.kv_size,
+                                               bias=self.qkv_bias,
+                                               linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.qkv_proj = QKVParallelLinear(
+                self.hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                self.total_num_key_value_heads,
+                self.qkv_bias,
+                linear_method=linear_method,
+            )
+        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+                                        self.hidden_size,
+                                        bias=False,
+                                        linear_method=linear_method)
+        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
+        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+        ) is None else linear_method.quant_config.rope_style()
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_ndims,
+            max_position=self.config.max_position_embeddings,
+            base=self.config.rope_theta,
+            is_neox_style=is_neox_style,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_key_value_heads)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        if self.merge_weight:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
+                                dim=-1)
+        else:
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
+        q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class StablelmDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.self_attn = StablelmAttention(config)
+        self.mlp = StablelmMLP(config, linear_method)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     eps=config.norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states, residual
+
+
+class StableLMEpochModel(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
+        super().__init__()
+        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size,
+                                                   linear_method=linear_method)
+        self.layers = nn.ModuleList([
+            StablelmDecoderLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            # pylint: disable=unused-variable
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class StablelmForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = StableLMEpochModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

+ 37 - 24
aphrodite/transformers_utils/config.py

@@ -3,37 +3,44 @@ from typing import Optional
 import gguf
 from transformers import AutoConfig, PretrainedConfig
 from transformers.models.auto.configuration_auto import CONFIG_MAPPING
-from aphrodite.transformers_utils.configs import YiConfig, QWenConfig
+
+from aphrodite.transformers_utils.configs import (BaiChuanConfig,
+                                                  ChatGLMConfig, MPTConfig,
+                                                  QWenConfig, RWConfig)
 
 _CONFIG_REGISTRY = {
+    "baichuan": BaiChuanConfig,
+    "chatglm": ChatGLMConfig,
+    "mpt": MPTConfig,
     "qwen": QWenConfig,
-    "yi": YiConfig,
+    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
+    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
 }
 
 
 def extract_gguf_config(checkpoint):
     result = gguf.GGUFReader(checkpoint)
-    architecture = result.fields["general.architecture"]
+    architecture = result.fields['general.architecture']
     architecture = str(bytes(architecture.parts[architecture.data[0]]),
-                       encoding="utf-8")
+                       encoding='utf-8')
     # Only support llama so far
     if architecture != "llama":
         raise RuntimeError(f"Unsupported architecture {architecture}")
 
     # write config
-    vocab_size = len(result.fields["tokenizer.ggml.token_type"].data)
-    context_length = int(result.fields["llama.context_length"].parts[-1])
-    n_layer = int(result.fields["llama.block_count"].parts[-1])
-    n_head = int(result.fields["llama.attention.head_count"].parts[-1])
+    vocab_size = len(result.fields['tokenizer.ggml.token_type'].data)
+    context_length = int(result.fields['llama.context_length'].parts[-1])
+    n_layer = int(result.fields['llama.block_count'].parts[-1])
+    n_head = int(result.fields['llama.attention.head_count'].parts[-1])
     n_local_heads = int(
-        result.fields["llama.attention.head_count_kv"].parts[-1])
+        result.fields['llama.attention.head_count_kv'].parts[-1])
     intermediate_size = int(
-        result.fields["llama.feed_forward_length"].parts[-1])
+        result.fields['llama.feed_forward_length'].parts[-1])
     norm_eps = float(
-        result.fields["llama.attention.layer_norm_rms_epsilon"].parts[-1])
-    dim = int(result.fields["llama.embedding_length"].parts[-1])
+        result.fields['llama.attention.layer_norm_rms_epsilon'].parts[-1])
+    dim = int(result.fields['llama.embedding_length'].parts[-1])
     arch = "MixtralForCausalLM"
-    if "llama.expert_count" in result.fields:
+    if 'llama.expert_count' in result.fields:
         arch = "MixtralForCausalLM"
         name = "mixtral"
     else:
@@ -55,14 +62,14 @@ def extract_gguf_config(checkpoint):
         "torch_dtype": "float16",
         "vocab_size": vocab_size
     }
-    if "llama.rope.freq_base" in result.fields:
-        model_config["rope_theta"] = float(
-            result.fields["llama.rope.freq_base"].parts[-1])
-    if "llama.expert_count" in result.fields:
-        model_config["num_local_experts"] = int(
-            result.fields["llama.expert_count"].parts[-1])
-        model_config["num_experts_per_tok"] = int(
-            result.fields["llama.expert_used_count"].parts[-1])
+    if 'llama.rope.freq_base' in result.fields:
+        model_config['rope_theta'] = float(
+            result.fields['llama.rope.freq_base'].parts[-1])
+    if 'llama.expert_count' in result.fields:
+        model_config['num_local_experts'] = int(
+            result.fields['llama.expert_count'].parts[-1])
+        model_config['num_experts_per_tok'] = int(
+            result.fields['llama.expert_used_count'].parts[-1])
     if name in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[name]
     else:
@@ -73,12 +80,16 @@ def extract_gguf_config(checkpoint):
 
 def get_config(model: str,
                trust_remote_code: bool,
-               revision: Optional[str] = None) -> PretrainedConfig:
+               revision: Optional[str] = None,
+               code_revision: Optional[str] = None) -> PretrainedConfig:
     if model.endswith("gguf"):
         return extract_gguf_config(model)
     try:
         config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code, revision=revision)
+            model,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            code_revision=code_revision)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
@@ -92,5 +103,7 @@ def get_config(model: str,
             raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model, revision=revision)
+        config = config_class.from_pretrained(model,
+                                              revision=revision,
+                                              code_revision=code_revision)
     return config

+ 13 - 2
aphrodite/transformers_utils/configs/__init__.py

@@ -1,7 +1,18 @@
+from aphrodite.transformers_utils.configs.baichuan import BaiChuanConfig
+from aphrodite.transformers_utils.configs.chatglm import ChatGLMConfig
+from aphrodite.transformers_utils.configs.mpt import MPTConfig
+from aphrodite.transformers_utils.configs.olmo import OLMoConfig
 from aphrodite.transformers_utils.configs.qwen import QWenConfig
-from aphrodite.transformers_utils.configs.yi import YiConfig
+# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
+# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
+# `FalconConfig` class from the official HuggingFace transformers library.
+from aphrodite.transformers_utils.configs.falcon import RWConfig
 
 __all__ = [
+    "BaiChuanConfig",
+    "ChatGLMConfig",
+    "MPTConfig",
+    "OLMoConfig",
     "QWenConfig",
-    "YiConfig",
+    "RWConfig",
 ]

+ 62 - 0
aphrodite/transformers_utils/configs/baichuan.py

@@ -0,0 +1,62 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class BaiChuanConfig(PretrainedConfig):
+    model_type = "baichuan"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=64000,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        hidden_act="silu",
+        max_position_embeddings=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )

+ 68 - 0
aphrodite/transformers_utils/configs/chatglm.py

@@ -0,0 +1,68 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/THUDM/ChatGLM2-6B
+from transformers import PretrainedConfig
+
+
+class ChatGLMConfig(PretrainedConfig):
+    model_type = "chatglm"
+    attribute_map = {
+        "num_hidden_layers": "num_layers",
+        "n_head_kv": "multi_query_group_num",
+    }
+
+    def __init__(self,
+                 num_layers=28,
+                 padded_vocab_size=65024,
+                 hidden_size=4096,
+                 ffn_hidden_size=13696,
+                 kv_channels=128,
+                 num_attention_heads=32,
+                 seq_length=2048,
+                 hidden_dropout=0.0,
+                 attention_dropout=0.0,
+                 layernorm_epsilon=1e-5,
+                 rmsnorm=True,
+                 apply_residual_connection_post_layernorm=False,
+                 post_layer_norm=True,
+                 add_bias_linear=False,
+                 add_qkv_bias=False,
+                 interleaved_qkv=False,
+                 bias_dropout_fusion=True,
+                 multi_query_attention=False,
+                 multi_query_group_num=1,
+                 apply_query_key_layer_scaling=True,
+                 attention_softmax_in_fp32=True,
+                 fp32_residual_connection=False,
+                 quantization_bit=0,
+                 pre_seq_len=None,
+                 prefix_projection=False,
+                 **kwargs):
+        self.num_layers = num_layers
+        self.vocab_size = padded_vocab_size
+        self.padded_vocab_size = padded_vocab_size
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.kv_channels = kv_channels
+        self.num_attention_heads = num_attention_heads
+        self.seq_length = seq_length
+        self.hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+        self.layernorm_epsilon = layernorm_epsilon
+        self.rmsnorm = rmsnorm
+        self.apply_residual_connection_post_layernorm = (
+            apply_residual_connection_post_layernorm)
+        self.post_layer_norm = post_layer_norm
+        self.add_bias_linear = add_bias_linear
+        self.add_qkv_bias = add_qkv_bias
+        self.bias_dropout_fusion = bias_dropout_fusion
+        self.multi_query_attention = multi_query_attention
+        self.multi_query_group_num = multi_query_group_num
+        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+        self.fp32_residual_connection = fp32_residual_connection
+        self.quantization_bit = quantization_bit
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
+        self.interleaved_qkv = interleaved_qkv
+        super().__init__(**kwargs)

+ 88 - 0
aphrodite/transformers_utils/configs/falcon.py

@@ -0,0 +1,88 @@
+# Adapted from
+# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Falcon configuration"""
+from transformers.configuration_utils import PretrainedConfig
+
+
+class RWConfig(PretrainedConfig):
+    model_type = "falcon"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {
+        "num_hidden_layers": "n_layer",
+        "num_attention_heads": "n_head",
+        "num_kv_heads": "n_head_kv",
+    }
+
+    def __init__(
+        self,
+        vocab_size=250880,
+        hidden_size=64,
+        n_layer=2,
+        n_head=8,
+        layer_norm_epsilon=1e-5,
+        initializer_range=0.02,
+        use_cache=True,
+        bos_token_id=1,
+        eos_token_id=2,
+        hidden_dropout=0.0,
+        attention_dropout=0.0,
+        multi_query=True,
+        n_head_kv=None,
+        alibi=False,
+        bias=False,
+        parallel_attn=False,
+        new_decoder_architecture=False,
+        **kwargs,
+    ) -> None:
+        self.vocab_size = vocab_size
+        # Backward compatibility with n_embed kwarg
+        n_embed = kwargs.pop("n_embed", None)
+        self.hidden_size = hidden_size if n_embed is None else n_embed
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.initializer_range = initializer_range
+        self.use_cache = use_cache
+        self.hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.multi_query = multi_query
+        self.n_head_kv = 1 if n_head_kv is None else n_head_kv
+        self.alibi = alibi
+        self.bias = bias
+        self.parallel_attn = parallel_attn
+        self.new_decoder_architecture = new_decoder_architecture
+
+        if self.hidden_size == 8192:
+            # Hack for falcon-40b
+            self.new_decoder_architecture = True
+
+        super().__init__(bos_token_id=bos_token_id,
+                         eos_token_id=eos_token_id,
+                         **kwargs)
+
+    @property
+    def head_dim(self):
+        return self.hidden_size // self.n_head
+
+    @property
+    def rotary(self):
+        return not self.alibi

+ 233 - 0
aphrodite/transformers_utils/configs/mpt.py

@@ -0,0 +1,233 @@
+# coding=utf-8
+# Copied from
+# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
+"""A HuggingFace-style model configuration."""
+import warnings
+from typing import Any, Dict, Optional, Union
+
+from transformers import PretrainedConfig
+
+attn_config_defaults: Dict = {
+    'attn_type': 'multihead_attention',
+    'attn_pdrop': 0.0,
+    'attn_impl': 'triton',
+    'qk_ln': False,
+    'clip_qkv': None,
+    'softmax_scale': None,
+    'prefix_lm': False,
+    'attn_uses_sequence_id': False,
+    'alibi': False,
+    'alibi_bias_max': 8
+}
+ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
+init_config_defaults: Dict = {
+    'name': 'kaiming_normal_',
+    'fan_mode': 'fan_in',
+    'init_nonlinearity': 'relu',
+    'init_div_is_residual': True,
+    'emb_init_std': None,
+    'emb_init_uniform_lim': None,
+    'init_std': None,
+    'init_gain': 0.0
+}
+
+
+class MPTConfig(PretrainedConfig):
+    model_type = 'mpt'
+    attribute_map = {
+        'num_attention_heads': 'n_heads',
+        'hidden_size': 'd_model',
+        'num_hidden_layers': 'n_layers',
+    }
+
+    # pylint: disable=dangerous-default-value
+    def __init__(self,
+                 d_model: int = 2048,
+                 n_heads: int = 16,
+                 n_layers: int = 24,
+                 expansion_ratio: int = 4,
+                 max_seq_len: int = 2048,
+                 vocab_size: int = 50368,
+                 resid_pdrop: float = 0.0,
+                 emb_pdrop: float = 0.0,
+                 learned_pos_emb: bool = True,
+                 attn_config: Dict = attn_config_defaults,
+                 ffn_config: Dict = ffn_config_defaults,
+                 init_device: str = 'cpu',
+                 logit_scale: Optional[Union[float, str]] = None,
+                 no_bias: bool = False,
+                 embedding_fraction: float = 1.0,
+                 norm_type: str = 'low_precision_layernorm',
+                 use_cache: bool = False,
+                 init_config: Dict = init_config_defaults,
+                 fc_type: str = 'torch',
+                 verbose: Optional[int] = None,
+                 **kwargs: Any):
+        """The MPT configuration class.
+        Args:
+            d_model (int): The size of the embedding dimension of the model.
+            n_heads (int): The number of attention heads.
+            n_layers (int): The number of layers in the model.
+            expansion_ratio (int): The ratio of the up/down scale in the ffn.
+            max_seq_len (int): The maximum sequence length of the model.
+            vocab_size (int): The size of the vocabulary.
+            resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
+            emb_pdrop (float): The dropout probability for the embedding layer.
+            learned_pos_emb (bool): Whether to use learned positional embeddings
+            attn_config (Dict): A dictionary used to configure the model's attention module:
+                attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
+                attn_pdrop (float): The dropout probability for the attention layers.
+                attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
+                qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
+                clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
+                    this value.
+                softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
+                    use the default scale of ``1/sqrt(d_keys)``.
+                prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
+                    extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
+                    can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
+                attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
+                    When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
+                    which sub-sequence each token belongs to.
+                    Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
+                alibi (bool): Whether to use the alibi bias instead of position embeddings.
+                alibi_bias_max (int): The maximum value of the alibi bias.
+                kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
+            ffn_config (Dict): A dictionary used to configure the model's ffn module:
+                ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
+            init_device (str): The device to use for parameter initialization.
+            logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
+            no_bias (bool): Whether to use bias in all layers.
+            verbose (int): The verbosity level. 0 is silent.
+            embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
+            norm_type (str): choose type of norm to use
+            use_cache (bool): Whether or not the model should return the last key/values attentions
+            init_config (Dict): A dictionary used to configure the model initialization:
+                init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
+                    'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
+                    'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
+                init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
+                emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
+                emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
+                    used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
+                init_std (float): The standard deviation of the normal distribution used to initialize the model,
+                    if using the baseline_ parameter initialization scheme.
+                init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
+                fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
+                init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
+                ---
+                See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
+            fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
+        """
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.expansion_ratio = expansion_ratio
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.resid_pdrop = resid_pdrop
+        self.emb_pdrop = emb_pdrop
+        self.learned_pos_emb = learned_pos_emb
+        self.attn_config = attn_config
+        self.ffn_config = ffn_config
+        self.init_device = init_device
+        self.logit_scale = logit_scale
+        self.no_bias = no_bias
+        self.embedding_fraction = embedding_fraction
+        self.norm_type = norm_type
+        self.use_cache = use_cache
+        self.init_config = init_config
+        self.fc_type = fc_type
+        if verbose is not None:
+            warnings.warn(DeprecationWarning(
+                'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
+            ),
+                          stacklevel=2)
+        if 'name' in kwargs:
+            del kwargs['name']
+        if 'loss_fn' in kwargs:
+            del kwargs['loss_fn']
+        if self.attn_config.get('alibi', False):
+            self.learned_pos_emb = False
+            warnings.warn(
+                f'alibi is turned on, setting `learned_pos_emb` to {self.learned_pos_emb}`',
+                stacklevel=2)
+        super().__init__(**kwargs)
+        self._validate_config()
+
+    def _set_config_defaults(
+            self, config: Dict[str, Any],
+            config_defaults: Dict[str, Any]) -> Dict[str, Any]:
+        for (k, v) in config_defaults.items():
+            if k not in config:
+                config[k] = v
+        return config
+
+    def _validate_config(self) -> None:
+        self.attn_config = self._set_config_defaults(self.attn_config,
+                                                     attn_config_defaults)
+        self.ffn_config = self._set_config_defaults(self.ffn_config,
+                                                    ffn_config_defaults)
+        self.init_config = self._set_config_defaults(self.init_config,
+                                                     init_config_defaults)
+        if self.d_model % self.n_heads != 0:
+            raise ValueError('d_model must be divisible by n_heads')
+        if any((
+                prob < 0 or prob > 1 for prob in
+            [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop]
+        )):
+            raise ValueError(
+                "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"  # pylint: disable=line-too-long
+            )
+        if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
+            raise ValueError(
+                f"Unknown attn_impl={self.attn_config['attn_impl']}")
+        if self.attn_config['prefix_lm'] and self.attn_config[
+                'attn_impl'] not in ['torch', 'triton']:
+            raise NotImplementedError(
+                'prefix_lm only implemented with torch and triton attention.')
+        if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [
+                'torch', 'triton'
+        ]:
+            raise NotImplementedError(
+                'alibi only implemented with torch and triton attention.')
+        if self.attn_config['attn_uses_sequence_id'] and self.attn_config[
+                'attn_impl'] not in ['torch', 'triton']:
+            raise NotImplementedError(
+                'attn_uses_sequence_id only implemented with torch and triton attention.'  # pylint: disable=line-too-long
+            )
+        if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
+            raise ValueError(
+                'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'  # pylint: disable=line-too-long
+            )
+        if isinstance(self.logit_scale,
+                      str) and self.logit_scale != 'inv_sqrt_d_model':
+            raise ValueError(
+                f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."  # pylint: disable=line-too-long
+            )
+        if self.init_config.get('name', None) is None:
+            raise ValueError(
+                f"self.init_config={self.init_config!r} 'name' needs to be set."
+            )
+        if not self.learned_pos_emb and (not self.attn_config['alibi']):
+            warnings.warn(
+                'Positional information not being provided to the model.',
+                stacklevel=2)
+        if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp':
+            try:
+                # pylint: disable=import-outside-toplevel
+                import transformer_engine.pytorch as te
+                del te
+            except Exception as exc:
+                raise ImportError(
+                    # pylint: disable=line-too-long
+                    'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. '
+                    +
+                    'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n'
+                    + 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
+                    'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
+                ) from exc
+        if self.ffn_config['ffn_type'] == 'mptmlp':
+            self.ffn_config['fc_type'] = self.fc_type
+        elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
+            self.ffn_config['bias'] = not self.no_bias

+ 72 - 0
aphrodite/transformers_utils/configs/olmo.py

@@ -0,0 +1,72 @@
+# coding=utf-8
+# adapted from https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/configuration_olmo.py
+"""OLMo configuration"""
+from transformers import PretrainedConfig
+
+
+class OLMoConfig(PretrainedConfig):
+    model_type = 'olmo'
+    attribute_map = {
+        'num_attention_heads': 'n_heads',
+        'hidden_size': 'd_model',
+        'num_hidden_layers': 'n_layers',
+    }
+
+    # Note that the defaults for these attributes are equivalent to the base GPT2 model.
+    def __init__(
+        self,
+        d_model=768,
+        n_heads=12,
+        n_layers=12,
+        mlp_ratio=4,
+        mlp_hidden_size=None,
+        activation_type="swiglu",
+        block_type="sequential",
+        block_group_size=1,
+        alibi=False,
+        alibi_bias_max=8.0,
+        rope=False,
+        rope_full_precision=True,
+        multi_query_attention=False,
+        attention_layer_norm=False,
+        layer_norm_type="default",
+        layer_norm_with_affine=True,
+        attention_layer_norm_with_affine=True,
+        max_sequence_length=1024,
+        include_bias=True,
+        bias_for_layer_norm=None,
+        scale_logits=False,
+        vocab_size=50257,
+        embedding_size=50304,
+        weight_tying=True,
+        eos_token_id=50256,
+        pad_token_id=50256,
+        **kwargs,
+    ):
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.mlp_ratio = mlp_ratio
+        self.mlp_hidden_size = mlp_hidden_size
+        self.activation_type = activation_type
+        self.block_type = block_type
+        self.block_group_size = block_group_size
+        self.alibi = alibi
+        self.alibi_bias_max = alibi_bias_max
+        self.rope = rope
+        self.rope_full_precision = rope_full_precision
+        self.multi_query_attention = multi_query_attention
+        self.attention_layer_norm = attention_layer_norm
+        self.layer_norm_type = layer_norm_type
+        self.layer_norm_with_affine = layer_norm_with_affine
+        self.attention_layer_norm_with_affine = attention_layer_norm_with_affine
+        self.max_sequence_length = max_sequence_length
+        self.include_bias = include_bias
+        self.bias_for_layer_norm = bias_for_layer_norm
+        self.scale_logits = scale_logits
+        self.vocab_size = vocab_size
+        self.embedding_size = embedding_size
+        self.weight_tying = weight_tying
+        self.eos_token_id = eos_token_id
+        self.pad_token_id = pad_token_id
+        super().__init__(**kwargs)

+ 49 - 52
aphrodite/transformers_utils/tokenizer.py

@@ -1,6 +1,6 @@
 import os
 import tempfile
-from typing import List, Tuple, Union, Optional
+from typing import List, Optional, Tuple, Union
 
 import gguf
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
@@ -10,6 +10,7 @@ from transformers.convert_slow_tokenizer import import_protobuf
 from aphrodite.common.logger import init_logger
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.utils import make_async, LRUCache
+from aphrodite.transformers_utils.tokenizers import BaichuanTokenizer
 
 logger = init_logger(__name__)
 
@@ -19,90 +20,75 @@ def convert_gguf_to_tokenizer(checkpoint):
     # write vocab
     sentencepiece_model_pb2 = import_protobuf()
     vocab = sentencepiece_model_pb2.ModelProto()
-    vocab_size = len(result.fields["tokenizer.ggml.token_type"].data)
+    vocab_size = len(result.fields['tokenizer.ggml.token_type'].data)
     vocab.trainer_spec.model_type = 2  # BPE
     vocab.trainer_spec.vocab_size = vocab_size
     vocab.trainer_spec.byte_fallback = True
     vocab.normalizer_spec.remove_extra_whitespaces = False
-    tokens = result.fields["tokenizer.ggml.tokens"]
-    scores = result.fields["tokenizer.ggml.scores"]
-    types = result.fields["tokenizer.ggml.token_type"]
+    tokens = result.fields['tokenizer.ggml.tokens']
+    scores = result.fields['tokenizer.ggml.scores']
+    types = result.fields['tokenizer.ggml.token_type']
     for i in range(vocab_size):
         new_token = vocab.SentencePiece()
         new_token.piece = str(bytes(tokens.parts[tokens.data[i]]),
-                              encoding="utf-8")
+                              encoding='utf-8')
         new_token.score = scores.parts[scores.data[i]]
         # llama.cpp tokentype is the same with sentencepiece token type
         new_token.type = int(types.parts[types.data[i]])
         vocab.pieces.append(new_token)
-    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file:
+    with tempfile.NamedTemporaryFile(mode='wb', delete=False) as temp_file:
         temp_file.write(vocab.SerializeToString())
         temp_file_filename = temp_file.name
     tokenizer_args = {"vocab_file": temp_file_filename}
 
-    if "tokenizer.ggml.bos_token_id" in result.fields:
+    if 'tokenizer.ggml.bos_token_id' in result.fields:
         tokenizer_args["bos_token"] = vocab.pieces[int(
-            result.fields["tokenizer.ggml.bos_token_id"].parts[-1])].piece
-    if "tokenizer.ggml.eos_token_id" in result.fields:
+            result.fields['tokenizer.ggml.bos_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.eos_token_id' in result.fields:
         tokenizer_args["eos_token"] = vocab.pieces[int(
-            result.fields["tokenizer.ggml.eos_token_id"].parts[-1])].piece
-    if "tokenizer.ggml.padding_token_id" in result.fields:
+            result.fields['tokenizer.ggml.eos_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.padding_token_id' in result.fields:
         tokenizer_args["pad_token"] = vocab.pieces[int(
-            result.fields["tokenizer.ggml.padding_token_id"].parts[-1])].piece
-    if "tokenizer.ggml.unknown_token_id" in result.fields:
+            result.fields['tokenizer.ggml.padding_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.unknown_token_id' in result.fields:
         tokenizer_args["unk_token"] = vocab.pieces[int(
-            result.fields["tokenizer.ggml.unknown_token_id"].parts[-1])].piece
-    if "tokenizer.ggml.add_bos_token" in result.fields:
+            result.fields['tokenizer.ggml.unknown_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.add_bos_token' in result.fields:
         tokenizer_args["add_bos_token"] = bool(
-            result.fields["tokenizer.ggml.add_bos_token"].parts[-1])
-    if "tokenizer.ggml.add_eos_token" in result.fields:
+            result.fields['tokenizer.ggml.add_bos_token'].parts[-1])
+    if 'tokenizer.ggml.add_eos_token' in result.fields:
         tokenizer_args["add_eos_token"] = bool(
-            result.fields["tokenizer.ggml.add_eos_token"].parts[-1])
-    tokenizer = LlamaTokenizer(**tokenizer_args, legacy=False)
+            result.fields['tokenizer.ggml.add_eos_token'].parts[-1])
+    tokenizer = LlamaTokenizer(**tokenizer_args)
     os.unlink(temp_file_filename)
     return tokenizer
 
 
-# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
-_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
-
-
 def get_tokenizer(
     tokenizer_name: str,
     *args,
     tokenizer_mode: str = "auto",
     trust_remote_code: bool = False,
+    tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
     if tokenizer_name.endswith("gguf"):
         return convert_gguf_to_tokenizer(tokenizer_name)
+
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError(
                 "Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False
 
-    if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
-            and tokenizer_name != _FAST_LLAMA_TOKENIZER):
-        logger.info(
-            "For some LLaMA V1 models, initializing the fast tokenizer may "
-            "take a long time. To reduce the initialization time, consider "
-            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-            "tokenizer.")
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
             *args,
             trust_remote_code=trust_remote_code,
+            tokenizer_revision=tokenizer_revision,
             **kwargs)
-    except TypeError as e:
-        # The LLaMA tokenizer causes a protobuf error in some environments.
-        err_msg = (
-            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
-            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
-            "original tokenizer.")
-        raise RuntimeError(err_msg) from e
     except ValueError as e:
         # If the error pertains to the tokenizer class not existing or not
         # currently being imported, suggest using the --trust-remote-code flag.
@@ -117,6 +103,18 @@ def get_tokenizer(
             raise RuntimeError(err_msg) from e
         else:
             raise e
+    except AttributeError as e:
+        if "BaichuanTokenizer" in str(e):
+            # This is for the error "'BaichuanTokenizer' object has no
+            # attribute 'sp_model'".
+            tokenizer = BaichuanTokenizer.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                tokenizer_revision=tokenizer_revision,
+                **kwargs)
+        else:
+            raise e
 
     if not isinstance(tokenizer, PreTrainedTokenizerFast):
         logger.warning(
@@ -161,21 +159,18 @@ class TokenizerGroup:
         else:
             self.lora_tokenizers = None
 
-    def encode(
-        self,
-        prompt: str,
-        request_id: Optional[str] = None,  # pylint: disable=unused-argument
-        lora_request: Optional[LoRARequest] = None
-    ) -> List[int]:
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
         tokenizer = self.get_lora_tokenizer(lora_request)
         return tokenizer.encode(prompt)
 
     async def encode_async(
-        self,
-        prompt: str,
-        request_id: Optional[str] = None,  # pylint: disable=unused-argument
-        lora_request: Optional[LoRARequest] = None
-    ) -> List[int]:
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
         tokenizer = await self.get_lora_tokenizer_async(lora_request)
         return tokenizer.encode(prompt)
 
@@ -262,7 +257,7 @@ def detokenize_incrementally(
         # tokenizers (bigger = more conservative).
         # Subtract 1 extra to account for the generated token.
         prefix_offset = max(len(output_tokens) - 6, 0)
-        # If the first new token is a special token we can't skip 1 extra token
+        # If the first new token is a special token, we can't skip 1 extra token
         if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
             read_offset = max(len(output_tokens), 0)
         else:
@@ -286,12 +281,14 @@ def detokenize_incrementally(
             tokenizer,
             output_tokens[prefix_offset:read_offset],
             skip_special_tokens=skip_special_tokens,
-            spaces_between_special_tokens=spaces_between_special_tokens)
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
         new_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:],
             skip_special_tokens=skip_special_tokens,
-            spaces_between_special_tokens=spaces_between_special_tokens)
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
 
     if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
         # utf-8 char at the end means it's a potential unfinished byte sequence

+ 5 - 0
aphrodite/transformers_utils/tokenizers/__init__.py

@@ -0,0 +1,5 @@
+from aphrodite.transformers_utils.tokenizers.baichuan import BaichuanTokenizer
+
+__all__ = [
+    "BaichuanTokenizer",
+]

+ 261 - 0
aphrodite/transformers_utils/tokenizers/baichuan.py

@@ -0,0 +1,261 @@
+# yapf: disable
+# Adapted from
+# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/8f6e343d545c503b91429582231d1d354dac2740/tokenization_baichuan.py
+# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
+
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {},
+    "tokenizer_file": {},
+}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
+
+
+class BaichuanTokenizer(PreTrainedTokenizer):
+    """
+    Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",
+        bos_token="<s>",
+        eos_token="</s>",
+        pad_token=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        **kwargs,
+    ):
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = (
+            AddedToken(bos_token, lstrip=False, rstrip=False)
+            if isinstance(bos_token, str)
+            else bos_token
+        )
+        eos_token = (
+            AddedToken(eos_token, lstrip=False, rstrip=False)
+            if isinstance(eos_token, str)
+            else eos_token
+        )
+        unk_token = (
+            AddedToken(unk_token, lstrip=False, rstrip=False)
+            if isinstance(unk_token, str)
+            else unk_token
+        )
+        pad_token = (
+            AddedToken(pad_token, lstrip=False, rstrip=False)
+            if isinstance(pad_token, str)
+            else pad_token
+        )
+        self.vocab_file = vocab_file
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    @property
+    def vocab_size(self):
+        """Returns vocab size"""
+        return self.sp_model.get_piece_size()
+
+    def get_vocab(self):
+        """Returns vocab as a dict"""
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text):
+        """Returns a tokenized string."""
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.sp_model.piece_to_id(token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        token = self.sp_model.IdToPiece(index)
+        return token
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for i, token in enumerate(tokens):
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special and i != 0:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string
+
+    def save_vocabulary(
+        self, save_directory, filename_prefix: Optional[str] = None
+    ) -> Tuple[str]:
+        """
+        Save the vocabulary and special tokens file to a directory.
+
+        Args:
+            save_directory (`str`):
+                The directory in which to save the vocabulary.
+
+        Returns:
+            `Tuple(str)`: Paths to the files saved.
+        """
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory,
+            (filename_prefix + "-" if filename_prefix else "")
+            + VOCAB_FILES_NAMES["vocab_file"],
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(
+            out_vocab_file
+        ) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = bos_token_id + token_ids_0 + eos_token_id
+
+        if token_ids_1 is not None:
+            output = output + bos_token_id + token_ids_1 + eos_token_id
+
+        return output
+
+    def get_special_tokens_mask(
+        self,
+        token_ids_0: List[int],
+        token_ids_1: Optional[List[int]] = None,
+        already_has_special_tokens: bool = False,
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0,
+                token_ids_1=token_ids_1,
+                already_has_special_tokens=True,
+            )
+
+        bos_token_id = [1] if self.add_bos_token else []
+        eos_token_id = [1] if self.add_eos_token else []
+
+        if token_ids_1 is None:
+            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+        return (
+            bos_token_id
+            + ([0] * len(token_ids_0))
+            + eos_token_id
+            + bos_token_id
+            + ([0] * len(token_ids_1))
+            + eos_token_id
+        )
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+        sequence pair mask has the following format:
+
+        ```
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        ```
+
+        if token_ids_1 is None, only returns the first portion of the mask (0s).
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of ids.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+        """
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+        if token_ids_1 is not None:
+            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+        return output

+ 48 - 25
kernels/activation_kernels.cu

@@ -2,19 +2,16 @@
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#include <cmath>
+
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
 namespace aphrodite {
 
-template<typename T>
-__device__ __forceinline__ T silu(const T& x) {
-  // x * sigmoid(x)
-  return (T) (((float) x) / (1.0f + expf((float) -x)));
-}
-
-template<typename scalar_t>
-__global__ void silu_and_mul_kernel(
+// Activation and gating kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void act_and_mul_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
@@ -22,32 +19,58 @@ __global__ void silu_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = APHRODITE_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = APHRODITE_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = silu(x) * y;
+    out[token_idx * d + idx] = ACT_FN(x) * y;
   }
 }
 
+template<typename T>
+__device__ __forceinline__ T silu_kernel(const T& x) {
+  // x * sigmoid(x)
+  return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'none' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
+  const float f = (float) x;
+  constexpr float ALPHA = M_SQRT1_2;
+  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+}
+
 } // namespace aphrodite
 
+// Launch activation and gating kernel.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                                 \
+  int d = input.size(-1) / 2;                                                                 \
+  int64_t num_tokens = input.numel() / input.size(-1);                                        \
+  dim3 grid(num_tokens);                                                                      \
+  dim3 block(std::min(d, 1024));                                                              \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                           \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                               \
+  APHRODITE_DISPATCH_FLOATING_TYPES(                                                          \
+    input.scalar_type(),                                                                      \
+    "act_and_mul_kernel",                                                                     \
+    [&] {                                                                                     \
+      aphrodite::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(  \
+        out.data_ptr<scalar_t>(),                                                             \
+        input.data_ptr<scalar_t>(),                                                           \
+        d);                                                                                   \
+    });
+
 void silu_and_mul(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input)    // [..., 2 * d]
 {
-  int64_t num_tokens = input.numel() / input.size(-1);
-  int d = input.size(-1) / 2;
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(d, 1024));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  APHRODITE_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(),
-    "silu_and_mul_kernel",
-    [&] {
-      aphrodite::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(),
-        input.data_ptr<scalar_t>(),
-        d);
-    });
+  LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::silu_kernel);
+}
+
+void gelu_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_kernel);
 }
 
 namespace aphrodite {

+ 4 - 0
kernels/ops.h

@@ -57,6 +57,10 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
 
+void gelu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 void gelu_new(
   torch::Tensor& out,
   torch::Tensor& input);

+ 4 - 0
kernels/pybind.cpp

@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  ops.def(
+    "gelu_and_mul",
+    &gelu_and_mul,
+    "Activation function used in GeGLU.");
   ops.def(
     "gelu_new",
     &gelu_new,