feat: add Solar model support (#701)

* feat: add support for SolarForCausalLM models

* progress bar
AlpinDale 6 months ago
parent
commit
0c6f03b7e4
2 changed files with 559 additions and 0 deletions
  1. aphrodite/modeling/models/__init__.py (+1, -0)
  2. aphrodite/modeling/models/solar.py (+558, -0)
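
For context, a minimal offline-inference sketch of what this change enables once the registry entry below is in place. The `LLM`/`SamplingParams` entry points follow Aphrodite's vLLM-style API, and the model id is only illustrative; neither is part of this diff.

```python
# Hypothetical usage sketch (not part of the commit). Assumes Aphrodite's
# vLLM-style offline API and an arbitrary SolarForCausalLM checkpoint.
from aphrodite import LLM, SamplingParams

# Any checkpoint whose config.json lists "SolarForCausalLM" should now be
# routed to aphrodite/modeling/models/solar.py by the registry entry below.
llm = LLM(model="upstage/solar-pro-preview-instruct")  # illustrative model id
params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Solar extends the Llama architecture by"], params)
print(outputs[0].outputs[0].text)
```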

+ 1 - 0
aphrodite/modeling/models/__init__.py

@@ -76,6 +76,7 @@ _GENERATION_MODELS = {
     "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
     "ChameleonForConditionalGeneration": ("chameleon",
                                           "ChameleonForConditionalGeneration"),
+    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
 }
 
 _EMBEDDING_MODELS = {
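
The single line added above is the whole integration surface for model resolution: `_GENERATION_MODELS` maps the architecture string found in a checkpoint's `config.json` to a `(module, class)` pair under `aphrodite/modeling/models/`. A rough sketch of the lookup this implies, assuming a simple importlib-based resolver (the real logic lives in Aphrodite's model registry and is not part of this diff):

```python
# Sketch of how a (module, class) registry entry is typically resolved.
# The actual resolver in Aphrodite may differ; this only illustrates why
# one mapping line is enough to expose the new model.
import importlib

_GENERATION_MODELS = {
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
}

def resolve_model_cls(architecture: str):
    module_name, cls_name = _GENERATION_MODELS[architecture]
    module = importlib.import_module(
        f"aphrodite.modeling.models.{module_name}")
    return getattr(module, cls_name)

# e.g. resolve_model_cls("SolarForCausalLM") -> solar.SolarForCausalLM
```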

+ 558 - 0
aphrodite/modeling/models/solar.py

@@ -0,0 +1,558 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Solar model compatible with HuggingFace weights."""
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig, LoRAConfig
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import is_hip, progress_bar
+from aphrodite.distributed import (get_pp_group,
+                                   get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
+from aphrodite.modeling.models.interfaces import SupportsLoRA
+from aphrodite.modeling.models.utils import (PPMissingLayer,
+                                             is_pp_missing_parameter,
+                                             make_layers)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.quantization.compressed_tensors.utils import (
+    get_compressed_tensors_cache_scale)
+
+
+class SolarMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=hidden_size,
+            output_sizes=[intermediate_size] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
+        self.down_proj = RowParallelLinear(input_size=intermediate_size,
+                                           output_size=hidden_size,
+                                           bias=bias,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj")
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class SolarAttention(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            input_size=self.total_num_heads * self.head_dim,
+            output_size=hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class SolarDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+                config, "original_max_position_embeddings", None):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        # Support abacusai/Smaug-72B-v0.1 with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False)
+        self.self_attn = SolarAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=getattr(config, "num_key_value_heads",
+                                 config.num_attention_heads),
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            bias=attention_bias,
+            cache_config=cache_config,
+            prefix=f"{prefix}.self_attn",
+        )
+        self.mlp = SolarMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            bias=getattr(config, "mlp_bias", False),
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class SolarModel(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+                                            and get_pp_group().is_last_rank):
+            self.embed_tokens = VocabParallelEmbedding(
+                self.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: SolarDecoderLayer(config=config,
+                                             cache_config=cache_config,
+                                             quant_config=quant_config,
+                                             prefix=prefix),
+            prefix=f"{prefix}.layers")
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        bskcn_h_1 = None
+        bskcn_h_2 = None
+        bskcn_r_1 = None
+        bskcn_r_2 = None
+        bskcn_tv = (self.config.bskcn_tv[0]
+                    if self.training else self.config.bskcn_tv[1])
+
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.config.bskcn_1:
+                bskcn_h_1 = hidden_states.clone()
+                bskcn_r_1 = residual.clone()
+            if i in self.config.bskcn_2:
+                bskcn_h_2 = hidden_states.clone()
+                bskcn_r_2 = residual.clone()
+            if i in self.config.bskcn_3:
+                hidden_states = bskcn_h_1*bskcn_tv + hidden_states*(1-bskcn_tv)
+                residual = bskcn_r_1*bskcn_tv + residual*(1-bskcn_tv)
+            if i in self.config.bskcn_4:
+                hidden_states = bskcn_h_2*bskcn_tv + hidden_states*(1-bskcn_tv)
+                residual = bskcn_r_2*bskcn_tv + residual*(1-bskcn_tv)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i - self.start_layer],
+                attn_metadata,
+                residual,
+            )
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class SolarForCausalLM(nn.Module, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
+        "lm_head"
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.model = SolarModel(config,
+                                cache_config,
+                                quant_config,
+                                lora_config=lora_config,
+                                prefix="model")
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
+            )
+            if config.tie_word_embeddings:
+                self.lm_head.weight = self.model.embed_tokens.weight
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                    config.vocab_size,
+                                                    logit_scale)
+            self.sampler = Sampler()
+        else:
+            self.lm_head = PPMissingLayer()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        model_output = self.model(input_ids, positions, kv_caches,
+                                  attn_metadata, intermediate_tensors)
+        return model_output
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if scale_name := get_compressed_tensors_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = loaded_weight[0]
+                weight_loader(param, loaded_weight)
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+                quantization_param_path, tp_rank, tp_size,
+                self.config.num_hidden_layers,
+                self.config.__class__.model_type):
+            if not isinstance(self.model.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.model.layers[layer_idx].self_attn
+
+            if is_hip():
+                # The scaling factor convention we are assuming is
+                # quantized_value * scaling_factor ~= true_value
+                # which is consistent with the practice of setting
+                # scaling_factor = tensor_amax / FPtype_max
+                scaling_factor *= 2
+            if hasattr(layer_self_attn, "kv_scale"):
+                layer_self_attn.attn._kv_scale = scaling_factor
+            else:
+                raise RuntimeError("Self attention has no KV cache scaling "
+                                   "factor attribute!")
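
The only Solar-specific logic in `SolarModel.forward` above is the `bskcn_*` handling: hidden states and residuals are snapshotted at the layer indices in `config.bskcn_1`/`config.bskcn_2` and linearly blended back in at the indices in `config.bskcn_3`/`config.bskcn_4`, with the blend weight `bskcn_tv` taken from a (training, inference) pair in the config. A standalone sketch of that blend, with made-up layer indices, shapes, and a stand-in decoder layer:

```python
# Standalone illustration of the bskcn skip-connection blend used in
# SolarModel.forward; indices, shapes, and the layer itself are made up.
import torch

def fake_layer(x: torch.Tensor) -> torch.Tensor:
    # stand-in for one SolarDecoderLayer forward pass
    return torch.tanh(x)

bskcn_1, bskcn_3 = {8}, {20}   # snapshot at layer 8, blend back at layer 20
bskcn_tv = 0.2                 # inference-time blend weight from the config

hidden_states = torch.randn(4, 4096)
saved = None
for layer_idx in range(32):
    if layer_idx in bskcn_1:
        saved = hidden_states.clone()
    if layer_idx in bskcn_3:
        # mirrors: bskcn_h_1 * bskcn_tv + hidden_states * (1 - bskcn_tv)
        hidden_states = saved * bskcn_tv + hidden_states * (1 - bskcn_tv)
    hidden_states = fake_layer(hidden_states)
```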