
feat: ministral support (#776)

AlpinDale 4 months ago
parent
commit
7222b84582

+ 35 - 15
aphrodite/common/config.py

@@ -17,7 +17,8 @@ from aphrodite.distributed import get_current_tp_rank_partition_size
 from aphrodite.modeling.models import ModelRegistry
 from aphrodite.platforms import current_platform
 from aphrodite.quantization import QUANTIZATION_METHODS
-from aphrodite.transformers_utils.config import get_config, get_hf_text_config
+from aphrodite.transformers_utils.config import (ConfigFormat, get_config,
+                                                 get_hf_text_config)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -133,6 +134,8 @@ class ModelConfig:
             the model name will be the same as `model`.
         limit_mm_per_prompt: Maximum number of data instances per modality
             per prompt. Only applicable for multimodal models.
+        config_format: The config format which will be loaded. Defaults to
+            'auto', which tries 'hf' first and falls back to 'mistral'.
     """
 
     def __init__(
@@ -162,6 +165,7 @@ class ModelConfig:
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, List[str]]] = None,
         limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+        config_format: ConfigFormat = ConfigFormat.AUTO,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -194,7 +198,8 @@ class ModelConfig:
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, rope_scaling, rope_theta)
+                                    code_revision, rope_scaling, rope_theta,
+                                    config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
@@ -226,14 +231,18 @@ class ModelConfig:
             # so no logging message needed.
             self.enforce_eager = False
 
-        if (not self.disable_sliding_window
-                and self.hf_text_config.model_type == "gemma2"
-                and self.hf_text_config.sliding_window is not None):
+        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
+        has_interleaved_attention = (sliding_window is not None) and (
+            isinstance(sliding_window, list) or
+            (self.hf_text_config.model_type in ["gemma2"]))
+        if (not self.disable_sliding_window and has_interleaved_attention):
+            sliding_window_len_min = get_min_sliding_window(
+                self.hf_text_config.sliding_window)
             print_warning_once(
-                "Gemma 2 uses sliding window attention for every odd layer, "
-                "which is currently not supported by Aphrodite. Disabling "
-                "sliding window and capping the max length to the sliding "
-                f"window size ({self.hf_text_config.sliding_window}).")
+                f"{self.hf_text_config.model_type} has interleaved attention, "
+                "which is currently not supported by Aphrodite. Disabling "
+                "sliding window and capping the max length to the sliding "
+                f"window size ({sliding_window_len_min}).")
             self.disable_sliding_window = True
 
         self.max_model_len = _get_and_verify_max_len(
@@ -469,7 +478,8 @@ class ModelConfig:
             return True
         return False
 
-    def get_hf_config_sliding_window(self) -> Optional[int]:
+    def get_hf_config_sliding_window(
+            self) -> Union[Optional[int], List[Optional[int]]]:
         """Get the sliding window size, or None if disabled.
         """
 
@@ -481,7 +491,7 @@ class ModelConfig:
             return None
         return getattr(self.hf_text_config, "sliding_window", None)
 
-    def get_sliding_window(self) -> Optional[int]:
+    def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
         """Get the sliding window size, or None if disabled.
         """
         # If user disables sliding window, return None.
@@ -811,6 +821,7 @@ class LoadFormat(str, enum.Enum):
     SHARDED_STATE = "sharded_state"
     GGUF = "gguf"
     BITSANDBYTES = "bitsandbytes"
+    MISTRAL = "mistral"
 
 
 @dataclass
@@ -853,7 +864,7 @@ class LoadConfig:
                 "Ignoring the following patterns when downloading weights: "
                 f"{self.ignore_patterns}")
         else:
-            self.ignore_patterns = ["original/**/*", "consolidated*"]
+            self.ignore_patterns = ["original/**/*"]
 
     def _verify_load_format(self) -> None:
         if not isinstance(self.load_format, str):
@@ -1704,7 +1715,7 @@ def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
     max_model_len: Optional[int],
     disable_sliding_window: bool,
-    sliding_window_len: Optional[int],
+    sliding_window_len: Optional[Union[int, List[Optional[int]]]],
     rope_scaling_arg: Optional[Dict[str, Any]],
 ) -> int:
     """Get and verify the model's maximum length."""
@@ -1739,9 +1750,11 @@ def _get_and_verify_max_len(
     # If sliding window is manually disabled, max_length should be less
     # than the sliding window length in the model config.
     if disable_sliding_window and sliding_window_len is not None:
+        sliding_window_len_min = get_min_sliding_window(sliding_window_len)
         max_len_key = "sliding_window" \
-            if sliding_window_len < derived_max_model_len else max_len_key
-        derived_max_model_len = min(derived_max_model_len, sliding_window_len)
+            if sliding_window_len_min < derived_max_model_len else max_len_key
+        derived_max_model_len = min(derived_max_model_len,
+                                    sliding_window_len_min)
 
     # If none of the keys were found in the config, use a default and
     # log a warning.
@@ -1795,6 +1808,13 @@ def _get_and_verify_max_len(
     return int(max_model_len)
 
 
+def get_min_sliding_window(
+        sliding_window: Union[int, List[Optional[int]]]) -> int:
+    if isinstance(sliding_window, list):
+        return min(s for s in sliding_window if s is not None)
+    return sliding_window
+
+
 def get_served_model_name(model: str,
                           served_model_name: Optional[Union[str, List[str]]]):
     """

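For reference, the new get_min_sliding_window helper collapses an interleaved per-layer window list into a single cap. A minimal sketch of its behaviour, copied from the change above with made-up window values:

from typing import List, Optional, Union

def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    # A list means interleaved attention: None entries are full-attention
    # layers, integers are windowed layers; the smallest window becomes the cap.
    if isinstance(sliding_window, list):
        return min(s for s in sliding_window if s is not None)
    return sliding_window

print(get_min_sliding_window([None, 32768, None, 32768]))  # -> 32768
print(get_min_sliding_window(4096))                        # -> 4096
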
+ 18 - 15
aphrodite/engine/args_tools.py

@@ -7,11 +7,12 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
 
 from loguru import logger
 
-from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                                     EngineConfig, LoadConfig, LoRAConfig,
-                                     ModelConfig, ParallelConfig,
-                                     PromptAdapterConfig, SchedulerConfig,
-                                     SpeculativeConfig, TokenizerPoolConfig)
+from aphrodite.common.config import (CacheConfig, ConfigFormat, DecodingConfig,
+                                     DeviceConfig, EngineConfig, LoadConfig,
+                                     LoadFormat, LoRAConfig, ModelConfig,
+                                     ParallelConfig, PromptAdapterConfig,
+                                     SchedulerConfig, SpeculativeConfig,
+                                     TokenizerPoolConfig)
 from aphrodite.common.utils import FlexibleArgumentParser, is_cpu
 from aphrodite.executor.executor_base import ExecutorBase
 from aphrodite.quantization import QUANTIZATION_METHODS
@@ -75,6 +76,7 @@ class EngineArgs:
     device: str = "auto"
     # Load Options
     load_format: str = "auto"
+    config_format: str = "auto"
     dtype: str = "auto"
     ignore_patterns: Optional[Union[str, List[str]]] = None
     # Parallel Options
@@ -359,16 +361,7 @@ class EngineArgs:
             '--load-format',
             type=str,
             default=EngineArgs.load_format,
-            choices=[
-                'auto',
-                'pt',
-                'safetensors',
-                'npcache',
-                'dummy',
-                'tensorizer',
-                'sharded_state',
-                'bitsandbytes',
-            ],
+            choices=[f.value for f in LoadFormat],
             help='Category: Model Options\n'
             'The format of the model weights to load.\n\n'
             '* "auto" will try to load the weights in the safetensors format '
@@ -385,6 +378,15 @@ class EngineArgs:
             'Examples section for more information.\n'
             '* "bitsandbytes" will load the weights using bitsandbytes '
             'quantization.\n')
+        parser.add_argument(
+            '--config-format',
+            default=EngineArgs.config_format,
+            choices=[f.value for f in ConfigFormat],
+            help='The format of the model config to load.\n\n'
+            '* "auto" will try to load the config in hf format '
+            'if available else it will try to load in mistral format. '
+            'Mistral format is specific to mistral models and is not '
+            'compatible with other models.')
         parser.add_argument(
             '--dtype',
             type=str,
@@ -911,6 +913,7 @@ class EngineArgs:
             skip_tokenizer_init=self.skip_tokenizer_init,
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
+            config_format=self.config_format,
         )
 
         cache_config = CacheConfig(
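
The new --config-format flag mirrors the existing --load-format flag. A hedged sketch of setting both programmatically; the model id is illustrative and all other EngineArgs fields are left at their defaults:

from aphrodite.engine.args_tools import EngineArgs

# Hypothetical model id; any repo that ships only params.json and
# consolidated*.safetensors weights would exercise the same code path.
args = EngineArgs(
    model="mistralai/Ministral-8B-Instruct-2410",
    config_format="mistral",  # read params.json instead of config.json
    load_format="mistral",    # load consolidated*.safetensors weights
)
# The engine then forwards config_format into ModelConfig, as shown above.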

+ 9 - 3
aphrodite/modeling/model_loader/loader.py

@@ -18,6 +18,7 @@ from huggingface_hub import HfApi, hf_hub_download
 from loguru import logger
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                      DeviceConfig, LoadConfig, LoadFormat,
@@ -237,12 +238,17 @@ class DefaultModelLoader(BaseModelLoader):
         is_local = os.path.isdir(model_name_or_path)
         load_format = self.load_config.load_format
         use_safetensors = False
+        index_file = SAFE_WEIGHTS_INDEX_NAME
         # Some quantized models use .pt files for storing the weights.
         if load_format == LoadFormat.AUTO:
             allow_patterns = ["*.safetensors", "*.bin"]
         elif load_format == LoadFormat.SAFETENSORS:
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
+        elif load_format == LoadFormat.MISTRAL:
+            use_safetensors = True
+            allow_patterns = ["consolidated*.safetensors"]
+            index_file = "consolidated.safetensors.index.json"
         elif load_format == LoadFormat.PT:
             allow_patterns = ["*.pt"]
         elif load_format == LoadFormat.NPCACHE:
@@ -280,10 +286,10 @@ class DefaultModelLoader(BaseModelLoader):
             # any files not found in the index.
             if not is_local:
                 download_safetensors_index_file_from_hf(
-                    model_name_or_path, self.load_config.download_dir,
-                    revision)
+                    model_name_or_path, index_file,
+                    self.load_config.download_dir, revision)
             hf_weights_files = filter_duplicate_safetensors_files(
-                hf_weights_files, hf_folder)
+                hf_weights_files, hf_folder, index_file)
         else:
             hf_weights_files = filter_files_not_needed_for_inference(
                 hf_weights_files)
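
To make the LoadFormat.MISTRAL weight selection concrete, a small sketch with invented file names; the loader applies these patterns when globbing or downloading, then de-duplicates against the index file:

import fnmatch

files = [
    "model-00001-of-00002.safetensors",     # HF-style shard, ignored
    "consolidated.safetensors",             # mistral-format weights, kept
    "consolidated.safetensors.index.json",  # index file, fetched separately
    "params.json",
]
allow_patterns = ["consolidated*.safetensors"]
selected = [f for f in files
            if any(fnmatch.fnmatch(f, p) for p in allow_patterns)]
print(selected)  # ['consolidated.safetensors']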

+ 11 - 10
aphrodite/modeling/model_loader/weight_utils.py

@@ -17,7 +17,6 @@ from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from loguru import logger
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from aphrodite.common.config import LoadConfig, ModelConfig
 from aphrodite.common.utils import print_warning_once
@@ -243,6 +242,7 @@ def download_weights_from_hf(
 
 def download_safetensors_index_file_from_hf(
     model_name_or_path: str,
+    index_file: str,
     cache_dir: Optional[str],
     revision: Optional[str] = None,
 ) -> None:
@@ -260,36 +260,37 @@ def download_safetensors_index_file_from_hf(
             # Download the safetensors index file.
             hf_hub_download(
                 repo_id=model_name_or_path,
-                filename=SAFE_WEIGHTS_INDEX_NAME,
+                filename=index_file,
                 cache_dir=cache_dir,
                 revision=revision,
                 local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
             )
         # If file not found on remote or locally, we should not fail since
-        # only some models will have SAFE_WEIGHTS_INDEX_NAME.
+        # only some models will have index_file.
         except huggingface_hub.utils.EntryNotFoundError:
-            logger.info(f"No {SAFE_WEIGHTS_INDEX_NAME} found in remote.")
+            logger.info(f"No {index_file} found in remote.")
         except huggingface_hub.utils.LocalEntryNotFoundError:
-            logger.info(f"No {SAFE_WEIGHTS_INDEX_NAME} found in local cache.")
+            logger.info(f"No {index_file} found in local cache.")
 
 
 # For models like Mistral-7B-v0.3, there are both sharded
 # safetensors files and a consolidated safetensors file.
 # Passing both of these to the weight loader functionality breaks.
-# So, we use the SAFE_WEIGHTS_INDEX_NAME to
+# So, we use the index_file to
 # look up which safetensors files should be used.
 def filter_duplicate_safetensors_files(hf_weights_files: List[str],
-                                       hf_folder: str) -> List[str]:
+                                       hf_folder: str,
+                                       index_file: str) -> List[str]:
     # model.safetensors.index.json is a mapping from keys in the
     # torch state_dict to safetensors file holding that weight.
-    index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
+    index_file_name = os.path.join(hf_folder, index_file)
     if not os.path.isfile(index_file_name):
         return hf_weights_files
 
     # Iterate through the weight_map (weight_name: safetensors files)
     # to identify weights that we should use.
-    with open(index_file_name) as index_file:
-        weight_map = json.load(index_file)["weight_map"]
+    with open(index_file_name, "r") as f:
+        weight_map = json.load(f)["weight_map"]
     weight_files_in_index = set()
     for weight_name in weight_map:
         weight_files_in_index.add(
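
A hedged sketch of the de-duplication that the index file enables; the weight map below is a made-up miniature of a consolidated.safetensors.index.json, and only files it references are kept:

import os

hf_folder = "/tmp/ckpt"  # hypothetical snapshot directory
weight_map = {
    "model.embed_tokens.weight": "consolidated-00001-of-00002.safetensors",
    "lm_head.weight": "consolidated-00002-of-00002.safetensors",
}
hf_weights_files = [
    os.path.join(hf_folder, "consolidated-00001-of-00002.safetensors"),
    os.path.join(hf_folder, "consolidated-00002-of-00002.safetensors"),
    os.path.join(hf_folder, "consolidated.safetensors"),  # not indexed, dropped
]
in_index = {os.path.join(hf_folder, f) for f in weight_map.values()}
print([f for f in hf_weights_files if f in in_index])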

+ 45 - 0
aphrodite/modeling/models/llama.py

@@ -372,6 +372,27 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
     }
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm"
+    }
+
+
 
     def __init__(
         self,
@@ -469,6 +490,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         weights_list = list(weights)
         for name, loaded_weight in progress_bar(weights_list,
                                                 desc="Loading modules..."):
+            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
@@ -545,3 +567,26 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
+    # This function is used to remap the mistral format as
+    # used by Mistral and Llama <=2
+    def maybe_remap_mistral(
+            self, name: str,
+            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
+        def permute(w, n_heads):
+            attn_in = self.config.head_dim * n_heads
+            attn_out = self.config.hidden_size
+            return w.view(n_heads, attn_in // n_heads // 2, 2,
+                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)
+        mapping = self.mistral_mapping
+        modules = name.split(".")
+        # rotary embeds should be sliced
+        if "wk" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads)
+        elif "wq" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads)
+        for item in modules:
+            if item in mapping and mapping[item] not in name:
+                name = name.replace(item, mapping[item])
+        return name, loaded_weight
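
To show what maybe_remap_mistral does to checkpoint keys, a standalone sketch using the mapping from the diff and one representative key (the extra rotary permutation applied to wq/wk tensors is omitted here):

mistral_mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
    "wk": "k_proj",
    "wv": "v_proj",
    "wo": "o_proj",
    "attention_norm": "input_layernorm",
    "feed_forward": "mlp",
    "w1": "gate_proj",
    "w2": "down_proj",
    "w3": "up_proj",
    "ffn_norm": "post_attention_layernorm",
    "tok_embeddings": "model.embed_tokens",
    "output": "lm_head",
    "norm": "model.norm",
}

name = "layers.0.attention.wq.weight"  # consolidated.safetensors key
for item in name.split("."):
    if item in mistral_mapping and mistral_mapping[item] not in name:
        name = name.replace(item, mistral_mapping[item])
print(name)  # model.layers.0.self_attn.q_proj.weight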

+ 173 - 28
aphrodite/transformers_utils/config.py

@@ -1,12 +1,18 @@
 import contextlib
+import enum
+import json
 import os
 from pathlib import Path
-from typing import Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union
 
+import huggingface_hub
+from huggingface_hub import (file_exists, hf_hub_download,
+                             try_to_load_from_cache)
 from loguru import logger
 from transformers import GenerationConfig, PretrainedConfig
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
 
 from aphrodite.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                                   InternVLChatConfig,
@@ -22,6 +28,8 @@ if APHRODITE_USE_MODELSCOPE:
 else:
     from transformers import AutoConfig
 
+MISTRAL_CONFIG_NAME = "params.json"
+
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
@@ -39,6 +47,35 @@ for name, cls in _CONFIG_REGISTRY.items():
         AutoConfig.register(name, cls)
 
 
+class ConfigFormat(str, enum.Enum):
+    AUTO = "auto"
+    HF = "hf"
+    MISTRAL = "mistral"
+
+
+def file_or_path_exists(model: Union[str, Path], config_name, revision,
+                        token) -> bool:
+    if Path(model).exists():
+        return (Path(model) / config_name).is_file()
+
+    # Offline mode support: Check if config file is cached already
+    cached_filepath = try_to_load_from_cache(repo_id=model,
+                                             filename=config_name,
+                                             revision=revision)
+    if isinstance(cached_filepath, str):
+        # The config file exists in cache- we can continue trying to load
+        return True
+
+    # NB: file_exists will only check for the existence of the config file on
+    # hf_hub. This will fail in offline mode.
+    try:
+        return file_exists(model, config_name, revision=revision, token=token)
+    except huggingface_hub.errors.OfflineModeIsEnabled:
+        # Don't raise in offline mode, all we know is that we don't have this
+        # file cached.
+        return False
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
@@ -46,38 +83,77 @@ def get_config(
     code_revision: Optional[str] = None,
     rope_scaling: Optional[dict] = None,
     rope_theta: Optional[float] = None,
+    config_format: ConfigFormat = ConfigFormat.AUTO,
     **kwargs,
 ) -> PretrainedConfig:
-
     # Separate model folder from file path for GGUF models
+
     is_gguf = check_gguf_file(model)
     if is_gguf:
         kwargs["gguf_file"] = Path(model).name
         model = Path(model).parent
 
-    try:
-        config = AutoConfig.from_pretrained(
-            model,
-            trust_remote_code=trust_remote_code,
-            revision=revision,
-            code_revision=code_revision,
-            **kwargs)
-    except ValueError as e:
-        if (not trust_remote_code and
-                "requires you to execute the configuration file" in str(e)):
-            err_msg = (
-                "Failed to load the model config. If the model is a custom "
-                "model not yet available in the HuggingFace transformers "
-                "library, consider setting `trust_remote_code=True` in LLM "
-                "or using the `--trust-remote-code` flag in the CLI.")
-            raise RuntimeError(err_msg) from e
+    if config_format == ConfigFormat.AUTO:
+        if is_gguf or file_or_path_exists(model,
+                                          HF_CONFIG_NAME,
+                                          revision=revision,
+                                          token=kwargs.get("token")):
+            config_format = ConfigFormat.HF
+        elif file_or_path_exists(model,
+                                 MISTRAL_CONFIG_NAME,
+                                 revision=revision,
+                                 token=kwargs.get("token")):
+            config_format = ConfigFormat.MISTRAL
         else:
-            raise e
-    if config.model_type in _CONFIG_REGISTRY:
-        config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model,
-                                              revision=revision,
-                                              code_revision=code_revision)
+            # If we're in offline mode and found no valid config format, then
+            # raise an offline mode error to indicate to the user that they
+            # don't have files cached and may need to go online.
+            # This is conveniently triggered by calling file_exists().
+            file_exists(model,
+                        HF_CONFIG_NAME,
+                        revision=revision,
+                        token=kwargs.get("token"))
+
+            raise ValueError(f"No supported config format found in {model}")
+
+    if config_format == ConfigFormat.HF:
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model, revision=revision, code_revision=code_revision, **kwargs)
+
+        # Use custom model class if it's in our registry
+        model_type = config_dict.get("model_type")
+        if model_type in _CONFIG_REGISTRY:
+            config_class = _CONFIG_REGISTRY[model_type]
+            config = config_class.from_pretrained(model,
+                                                  revision=revision,
+                                                  code_revision=code_revision)
+        else:
+            try:
+                config = AutoConfig.from_pretrained(
+                    model,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    code_revision=code_revision,
+                    **kwargs,
+                )
+            except ValueError as e:
+                if (not trust_remote_code
+                        and "requires you to execute the configuration file"
+                        in str(e)):
+                    err_msg = (
+                        "Failed to load the model config. If the model "
+                        "is a custom model not yet available in the "
+                        "HuggingFace transformers library, consider setting "
+                        "`trust_remote_code=True` in LLM or using the "
+                        "`--trust-remote-code` flag in the CLI.")
+                    raise RuntimeError(err_msg) from e
+                else:
+                    raise e
+
+    elif config_format == ConfigFormat.MISTRAL:
+        config = load_params_config(model, revision)
+    else:
+        raise ValueError(f"Unsupported config format: {config_format}")
 
     # Special architecture mapping check for GGUF models
     if is_gguf:
@@ -86,13 +162,82 @@ def get_config(
                 f"Can't get gguf config for {config.model_type}.")
         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
         config.update({"architectures": [model_type]})
-    for key, value in [("rope_scaling", rope_scaling),
-                       ("rope_theta", rope_theta)]:
+
+    for key, value in [
+        ("rope_scaling", rope_scaling),
+        ("rope_theta", rope_theta),
+    ]:
         if value is not None:
-            logger.info(f"Updating {key} from "
-                        f"{getattr(config, key, None)} to {value}")
+            logger.info(
+                "Updating {} from {!r} to {!r}",
+                key,
+                getattr(config, key, None),
+                value,
+            )
             config.update({key: value})
 
+
+    return config
+
+
+def load_params_config(model, revision) -> PretrainedConfig:
+    # This function loads a params.json config which
+    # should be used when loading models in mistral format
+
+    config_file_name = "params.json"
+
+    config_path = Path(model) / config_file_name
+
+    if not config_path.is_file():
+        config_path = Path(
+            hf_hub_download(model, config_file_name, revision=revision))
+
+    with open(config_path, "r") as file:
+        config_dict = json.load(file)
+
+    config_mapping = {
+        "dim": "hidden_size",
+        "norm_eps": "rms_norm_eps",
+        "n_kv_heads": "num_key_value_heads",
+        "n_layers": "num_hidden_layers",
+        "n_heads": "num_attention_heads",
+        "hidden_dim": "intermediate_size",
+    }
+
+    def recurse_elems(elem: Any):
+        if isinstance(elem, dict):
+            config_dict = {}
+            for key, value in elem.items():
+                key = config_mapping.get(key, key)
+                config_dict[key] = recurse_elems(value)
+            return PretrainedConfig(**config_dict)
+        else:
+            return elem
+
+    config_dict["model_type"] = config_dict.get("model_type", "transformer")
+    config_dict["hidden_act"] = config_dict.get("activation", "silu")
+    config_dict["tie_word_embeddings"] = config_dict.get(
+        "tie_embeddings", False)
+    config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000)
+    config_dict["max_position_embeddings"] = config_dict.get(
+        "max_position_embeddings", 128_000)
+
+    if config_dict.get("moe") is not None:
+        config_dict["architectures"] = ["MixtralForCausalLM"]
+    else:
+        config_dict["architectures"] = ["MistralForCausalLM"]
+
+    if config_dict.get("vision_encoder") is not None:
+        multimodal_config = config_dict.pop("vision_encoder")
+
+        config_dict = {
+            "text_config": config_dict,
+            "vision_config": multimodal_config
+        }
+        config_dict["architectures"] = ["PixtralForConditionalGeneration"]
+        config_dict["model_type"] = "pixtral"
+
+    config = recurse_elems(config_dict)
     return config
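
Finally, a hedged sketch of what load_params_config effectively does to a made-up, abbreviated params.json: keys are renamed through config_mapping, architecture defaults are filled in, and the result is wrapped in a PretrainedConfig:

from transformers import PretrainedConfig

params = {  # illustrative subset of a mistral-format params.json
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "hidden_dim": 14336,
    "norm_eps": 1e-5,
}
config_mapping = {
    "dim": "hidden_size",
    "norm_eps": "rms_norm_eps",
    "n_kv_heads": "num_key_value_heads",
    "n_layers": "num_hidden_layers",
    "n_heads": "num_attention_heads",
    "hidden_dim": "intermediate_size",
}
renamed = {config_mapping.get(k, k): v for k, v in params.items()}
renamed["architectures"] = ["MistralForCausalLM"]  # no "moe" key in params
config = PretrainedConfig(**renamed)
print(config.hidden_size, config.num_key_value_heads, config.architectures)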