
models: add support for IBM Granite (PowerLM) models (#978)

AlpinDale 2 months ago
Parent
Commit
adb6982090

+ 1 - 0
aphrodite/modeling/models/__init__.py

@@ -66,6 +66,7 @@ _GENERATION_MODELS = {
     "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
     "SolarForCausalLM": ("solar", "SolarForCausalLM"),
     "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
+    "GraniteForCausalLM": ("granite", "GraniteForCausalLM")
 }
 
 _EMBEDDING_MODELS = {

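With the "GraniteForCausalLM" architecture registered above, Granite/PowerLM checkpoints resolve to the new granite module at load time. A minimal usage sketch (the `ibm/PowerLM-3b` model ID and the vLLM-style `LLM`/`SamplingParams` entry points are assumptions, not part of this commit):

# Hypothetical example; model ID and API surface are assumptions.
from aphrodite import LLM, SamplingParams

llm = LLM(model="ibm/PowerLM-3b")
params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Explain KV caching in one sentence."], params)
print(outputs[0].outputs[0].text)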
+ 565 - 0
aphrodite/modeling/models/granite.py

@@ -0,0 +1,565 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only IBM Granite model compatible with HuggingFace weights."""
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig, LoRAConfig
+from aphrodite.common.sequence import IntermediateTensors
+from aphrodite.common.utils import is_hip
+from aphrodite.distributed import (get_pp_group,
+                                   get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.quantization.compressed_tensors.utils import (
+    get_compressed_tensors_cache_scale)
+from aphrodite.transformers_utils.configs.granite import GraniteConfig
+
+from .interfaces import SupportsLoRA
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
+
+
+class GraniteMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=hidden_size,
+            output_sizes=[intermediate_size] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class GraniteAttention(nn.Module):
+    def __init__(
+        self,
+        config: GraniteConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        # Some configs (e.g. Mistral-Nemo's MistralConfig) provide an optional
+        # head_dim; fall back to hidden_size // num_heads when it is absent.
+        self.head_dim = getattr(
+            config, "head_dim", self.hidden_size // self.total_num_heads
+        )
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = config.attention_multiplier
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            input_size=self.total_num_heads * self.head_dim,
+            output_size=hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class GraniteDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: GraniteConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.residual_multiplier = config.residual_multiplier
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling[
+                "original_max_position_embeddings"
+            ] = config.original_max_position_embeddings
+        max_position_embeddings = getattr(
+            config, "max_position_embeddings", 8192
+        )
+        # Support abacusai/Smaug-72B-v0.1 with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        self.self_attn = GraniteAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=getattr(
+                config, "num_key_value_heads", config.num_attention_heads
+            ),
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            bias=attention_bias,
+            cache_config=cache_config,
+            prefix=f"{prefix}.self_attn",
+        )
+        self.mlp = GraniteMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            bias=getattr(config, "mlp_bias", False),
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        hidden_states = residual + hidden_states * self.residual_multiplier
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states * self.residual_multiplier
+        return hidden_states
+
+
+class GraniteModel(nn.Module):
+    def __init__(
+        self,
+        config: GraniteConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+        if get_pp_group().is_first_rank or (
+            config.tie_word_embeddings and get_pp_group().is_last_rank
+        ):
+            self.embed_tokens = VocabParallelEmbedding(
+                self.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                quant_config=quant_config,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: GraniteDecoderLayer(
+                config=config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=f"{prefix}.layers",
+        )
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            # Granite scales the input embeddings; do it only on the first
+            # pipeline rank so later ranks do not apply it a second time.
+            hidden_states *= self.config.embedding_multiplier
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                kv_caches[i - self.start_layer],
+                attn_metadata,
+            )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class GraniteForCausalLM(nn.Module, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: GraniteConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.lora_config = lora_config
+        self.model = GraniteModel(
+            config,
+            cache_config,
+            quant_config,
+            lora_config=lora_config,
+            prefix="model",
+        )
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config
+                else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
+            )
+            if config.tie_word_embeddings:
+                self.lm_head.weight = self.model.embed_tokens.weight
+            # GraniteConfig defines `logits_scaling` as a divisor for the
+            # output logits; pass its reciprocal as the logits scale.
+            logit_scale = 1.0 / getattr(config, "logits_scaling", 1.0)
+            self.logits_processor = LogitsProcessor(
+                self.unpadded_vocab_size, config.vocab_size, logit_scale
+            )
+            self.sampler = Sampler()
+        else:
+            self.lm_head = PPMissingLayer()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        model_output = self.model(
+            input_ids, positions, kv_caches, attn_metadata, intermediate_tensors
+        )
+        return model_output
+
+    def compute_logits(
+        self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(
+            self.lm_head, hidden_states, sampling_metadata
+        )
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def make_empty_intermediate_tensors(
+        self, batch_size: int, dtype: torch.dtype, device: torch.device
+    ) -> IntermediateTensors:
+        return IntermediateTensors(
+            {
+                "hidden_states": torch.zeros(
+                    (batch_size, self.config.hidden_size),
+                    dtype=dtype,
+                    device=device,
+                ),
+                "residual": torch.zeros(
+                    (batch_size, self.config.hidden_size),
+                    dtype=dtype,
+                    device=device,
+                ),
+            }
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if (
+                "rotary_emb.cos_cached" in name
+                or "rotary_emb.sin_cached" in name
+            ):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            # With tie_word_embeddings, we can skip lm_head.weight
+            # The weight might appear unnecessarily in the files if the model is
+            # processed with quantization, LoRA, fine-tuning, etc.
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if scale_name := get_compressed_tensors_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
+                loaded_weight = loaded_weight[0]
+                weight_loader(param, loaded_weight)
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if isinstance(self.model.layers[layer_idx], nn.Identity):
+                # This layer lives on another pipeline rank; skip it instead
+                # of dereferencing an unset `layer_self_attn` below.
+                continue
+            layer_self_attn = self.model.layers[layer_idx].self_attn
+            if is_hip():
+                # The scaling factor convention we are assuming is
+                # quantized_value * scaling_factor ~= true_value
+                # which is consistent with the practice of setting
+                # scaling_factor = tensor_amax / FPtype_max
+                scaling_factor *= 2
+            if hasattr(layer_self_attn, "kv_scale"):
+                layer_self_attn.attn._kv_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling "
+                    "factor attribute!"
+                )

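Compared with the Llama code it was adapted from, the module above wires in the Granite/PowerLM scaling knobs: input embeddings are multiplied by `embedding_multiplier`, the attention softmax scale comes from `attention_multiplier` instead of 1/sqrt(head_dim), and each residual branch is scaled by `residual_multiplier`. A self-contained PyTorch sketch of the residual pattern in `GraniteDecoderLayer.forward` (illustrative only; `attn` and `mlp` stand in for the tensor-parallel layers above):

import torch
from torch import nn

def granite_block(x: torch.Tensor, attn: nn.Module, mlp: nn.Module,
                  ln1: nn.Module, ln2: nn.Module,
                  residual_multiplier: float) -> torch.Tensor:
    # Pre-norm attention; the branch output is scaled before being added back.
    x = x + attn(ln1(x)) * residual_multiplier
    # Pre-norm MLP; same scaled-residual pattern.
    x = x + mlp(ln2(x)) * residual_multiplier
    return x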
+ 2 - 1
aphrodite/transformers_utils/config.py

@@ -17,7 +17,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
 
 import aphrodite.common.envs as envs
 from aphrodite.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
-                                                  EAGLEConfig,
+                                                  EAGLEConfig, GraniteConfig,
                                                   InternVLChatConfig,
                                                   JAISConfig, MedusaConfig,
                                                   MLPSpeculatorConfig,
@@ -46,6 +46,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "internvl_chat": InternVLChatConfig,
     "ultravox": UltravoxConfig,
     "eagle": EAGLEConfig,
+    "granite": GraniteConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():

+ 2 - 0
aphrodite/transformers_utils/configs/__init__.py

@@ -5,6 +5,7 @@ from aphrodite.transformers_utils.configs.eagle import EAGLEConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from aphrodite.transformers_utils.configs.falcon import RWConfig
+from aphrodite.transformers_utils.configs.granite import GraniteConfig
 from aphrodite.transformers_utils.configs.internvl import InternVLChatConfig
 from aphrodite.transformers_utils.configs.jais import JAISConfig
 from aphrodite.transformers_utils.configs.medusa import MedusaConfig
@@ -24,4 +25,5 @@ __all__ = [
     "MedusaConfig",
     "UltravoxConfig",
     "EAGLEConfig",
+    "GraniteConfig",
 ]

+ 186 - 0
aphrodite/transformers_utils/configs/granite.py

@@ -0,0 +1,186 @@
+# coding=utf-8
+# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Granite model configuration"""
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class GraniteConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of
+    a [`GraniteModel`]. It is used to instantiate a Granite
+    model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar
+    configuration to that of the Granite-3B.
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to
+    control the model outputs. Read the documentation from [`PretrainedConfig`]
+    for more information.
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the Granite model. Defines the number of
+            different tokens that can be represented by the `inputs_ids`
+            passed when calling [`GraniteModel`]
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 11008):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the
+            Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to
+            implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi
+            Head Attention (MHA); if `num_key_value_heads=1` the model will use
+            Multi Query Attention (MQA); otherwise GQA is used. When converting
+            a multi-head checkpoint to a GQA checkpoint, each group key and
+            value head should be constructed by meanpooling all the original
+            heads within that group. For more details, check out
+            [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not
+            specified, will default to `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the
+            decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for
+            initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values
+            attentions (not used by all models). Only relevant if
+            `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE
+            embeddings. Currently supports two scaling strategies: linear and
+            dynamic. Their scaling factor must be a float greater than 1. The
+            expected format is
+            `{"type": strategy name, "factor": scaling factor}`.
+            When using this flag, don't update `max_position_embeddings` to
+            the expected new maximum. See the following thread for more
+            information on how these scaling strategies behave:
+            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/.
+            This is an experimental feature, subject to breaking API changes
+            in future versions.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output
+            projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        mlp_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in up_proj, down_proj and gate_proj layers
+            in the MLP layers.
+        embedding_multiplier (`float`, *optional*, defaults to 1.0):
+            Multiplier applied to the input embeddings.
+        logits_scaling (`float`, *optional*, defaults to 1.0):
+            Divisor applied to the output logits.
+        residual_multiplier (`float`, *optional*, defaults to 1.0):
+            Multiplier applied to each residual connection.
+        attention_multiplier (`float`, *optional*, defaults to 1.0):
+            Scaling factor applied to the attention scores (used as the
+            attention softmax scale).
+    ```python
+    >>> from transformers import GraniteModel, GraniteConfig
+    >>> # Initializing a Granite granite-3b style configuration
+    >>> configuration = GraniteConfig()
+    >>> # Initializing a model from the granite-3b style configuration
+    >>> model = GraniteModel(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "granite"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        mlp_bias=False,
+        embedding_multiplier=1.0,
+        logits_scaling=1.0,
+        residual_multiplier=1.0,
+        attention_multiplier=1.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.mlp_bias = mlp_bias
+        self.embedding_multiplier = embedding_multiplier
+        self.logits_scaling = logits_scaling
+        self.residual_multiplier = residual_multiplier
+        self.attention_multiplier = attention_multiplier
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        rope_config_validation(self)
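For reference, the config can be constructed directly; `rope_config_validation` (imported from `transformers.modeling_rope_utils`) checks the `rope_scaling` dictionary at construction time. A short sketch with placeholder hyperparameters (not taken from any published checkpoint):

from aphrodite.transformers_utils.configs.granite import GraniteConfig

# Placeholder values for illustration; real checkpoints ship their own config.json.
config = GraniteConfig(
    vocab_size=49152,
    hidden_size=2304,
    num_hidden_layers=40,
    num_attention_heads=36,
    num_key_value_heads=4,
    embedding_multiplier=12.0,
    attention_multiplier=0.015625,
    residual_multiplier=0.22,
    logits_scaling=8.0,
)
print(config.model_type)          # "granite"
print(config.num_key_value_heads) # 4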