
feat: add siglip encoder for llava family (#626)

* wip

* wip
AlpinDale committed 6 months ago
commit 30d7effc7c

+ 28 - 10
aphrodite/modeling/model_loader/loader.py

@@ -17,7 +17,7 @@ import torch
 from huggingface_hub import HfApi, hf_hub_download
 from loguru import logger
 from torch import nn
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, PretrainedConfig
 
 from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                      DeviceConfig, LoadConfig, LoadFormat,
@@ -139,6 +139,22 @@ def _get_model_initialization_kwargs(
     return extra_kwargs
 
 
+def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
+                cache_config: Optional[CacheConfig],
+                quant_config: Optional[QuantizationConfig], *,
+                lora_config: Optional[LoRAConfig],
+                multimodal_config: Optional[MultiModalConfig],
+                scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
+    extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
+                                                    multimodal_config,
+                                                    scheduler_config)
+
+    return model_class(config=hf_config,
+                       cache_config=cache_config,
+                       quant_config=quant_config,
+                       **extra_kwargs)
+
+
 def _initialize_model(
         model_config: ModelConfig,
         load_config: LoadConfig,
@@ -147,15 +163,17 @@ def _initialize_model(
         cache_config: CacheConfig,
         scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
     """Initialize a model with the given configurations."""
-    model_class = get_model_architecture(model_config)[0]
-    quant_config = _get_quantization_config(model_config, load_config)
-
-    return model_class(config=model_config.hf_config,
-                       cache_config=cache_config,
-                       quant_config=quant_config,
-                       **_get_model_initialization_kwargs(
-                           model_class, lora_config, multimodal_config,
-                           scheduler_config))
+    model_class, _ = get_model_architecture(model_config)
+
+    return build_model(
+        model_class,
+        model_config.hf_config,
+        quant_config=_get_quantization_config(model_config, load_config),
+        lora_config=lora_config,
+        multimodal_config=multimodal_config,
+        cache_config=cache_config,
+        scheduler_config=scheduler_config,
+    )
 
 
 class BaseModelLoader(ABC):
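
The new build_model helper factors the construction logic out of _initialize_model so that other call sites (such as the multimodal wrappers changed below) can reuse it. A minimal sketch of a direct call follows; the LlamaForCausalLM import path is assumed rather than taken from this diff, and the call presumes the usual aphrodite runtime (e.g. initialized distributed state) is already set up:

# Hedged usage sketch; LlamaForCausalLM's path is assumed, not part of this diff.
from transformers import LlamaConfig

from aphrodite.modeling.model_loader.loader import build_model
from aphrodite.modeling.models.llama import LlamaForCausalLM

model = build_model(
    LlamaForCausalLM,
    LlamaConfig(num_hidden_layers=2),  # tiny config, purely illustrative
    cache_config=None,                 # no KV-cache config in this sketch
    quant_config=None,                 # unquantized
    lora_config=None,                  # keyword-only, as in _initialize_model
    multimodal_config=None,
    scheduler_config=None,
)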

+ 1 - 7
aphrodite/modeling/model_loader/utils.py

@@ -28,13 +28,7 @@ def get_model_architecture(
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
 
-    for arch in architectures:
-        model_cls = ModelRegistry.load_model_cls(arch)
-        if model_cls is not None:
-            return (model_cls, arch)
-    raise ValueError(
-        f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+    return ModelRegistry.resolve_model_cls(architectures)
 
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:

+ 14 - 2
aphrodite/modeling/models/__init__.py

@@ -1,6 +1,6 @@
 import functools
 import importlib
-from typing import Dict, List, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type
 
 import torch.nn as nn
 from loguru import logger
@@ -124,7 +124,7 @@ class ModelRegistry:
         return getattr(module, model_cls_name, None)
 
     @staticmethod
-    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
+    def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
         if model_arch in _OOT_MODELS:
             return _OOT_MODELS[model_arch]
         if model_arch not in _MODELS:
@@ -142,6 +142,18 @@ class ModelRegistry:
 
         return ModelRegistry._get_model(model_arch)
 
+    @staticmethod
+    def resolve_model_cls(
+            architectures: List[str]) -> Tuple[Type[nn.Module], str]:
+        for arch in architectures:
+            model_cls = ModelRegistry._try_load_model_cls(arch)
+            if model_cls is not None:
+                return (model_cls, arch)
+
+        raise ValueError(
+            f"Model architectures {architectures} are not supported for now. "
+            f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
     @staticmethod
     def get_supported_archs() -> List[str]:
         return list(_MODELS.keys())
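
resolve_model_cls takes over the fallback loop that previously lived in get_model_architecture: it walks the architectures list from the HF config and returns the first registered class together with the architecture name that matched, raising the same ValueError when nothing matches. A short usage sketch:

# Usage sketch for the new registry API.
from aphrodite.modeling.models import ModelRegistry

model_cls, arch = ModelRegistry.resolve_model_cls(
    ["LlavaForConditionalGeneration"])
assert arch == "LlavaForConditionalGeneration"

# Unknown architectures raise a ValueError listing the supported ones.
try:
    ModelRegistry.resolve_model_cls(["NotARealArchitecture"])
except ValueError as err:
    print(err)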

+ 10 - 1
aphrodite/modeling/models/intern_vit.py

@@ -4,7 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from typing import Optional
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -15,6 +15,7 @@ from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.quantization import QuantizationConfig
 
 NORM2FN = {
@@ -268,3 +269,11 @@ class InternVisionModel(nn.Module):
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)
 
         return encoder_outputs
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
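
InternVisionModel.load_weights is the simplest instance of the per-component loading introduced in this commit: every checkpoint name must map directly onto a parameter, and a parameter-level weight_loader (attached by the parallel linear layers) takes precedence over plain copying. Roughly, the default fallback behaves like the sketch below; this is an approximation, not the aphrodite source:

import torch

def default_weight_loader(param: torch.nn.Parameter,
                          loaded_weight: torch.Tensor) -> None:
    # Approximate behaviour: a straight copy into the existing parameter.
    # Sharded layers attach their own weight_loader and never reach this.
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)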

+ 25 - 60
aphrodite/modeling/models/internvl.py

@@ -4,6 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+import itertools
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
@@ -17,7 +18,6 @@ from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.modeling.models import ModelRegistry
 from aphrodite.modeling.models.intern_vit import InternVisionModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
@@ -28,7 +28,8 @@ from aphrodite.quantization import QuantizationConfig
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
+from .utils import (filter_weights, init_aphrodite_registered_model,
+                    merge_vision_embeddings)
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -46,6 +47,7 @@ class InternVLImagePixelInputs(TypedDict):
     data: Union[torch.Tensor, List[torch.Tensor]]
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+
     Note that `num_patches` may be different for each batch, in which case
     the data is passed as a list instead of a batched tensor.
     """
@@ -281,10 +283,8 @@ class InternVLChatModel(nn.Module, SupportsVision):
         self.vision_model = InternVisionModel(
             config.vision_config, num_hidden_layers_override=num_hidden_layers)
 
-        llm_class = ModelRegistry.load_model_cls(
-            config.text_config.architectures[0])
-        self.language_model = llm_class(config.text_config, cache_config,
-                                        quant_config)
+        self.language_model = init_aphrodite_registered_model(
+            config.text_config, cache_config, quant_config)
 
         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.text_config.hidden_size
@@ -414,57 +414,22 @@ class InternVLChatModel(nn.Module, SupportsVision):
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-            (".gate_up_proj", ".w1", 0),
-            (".gate_up_proj", ".w3", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if self.config.text_config.tie_word_embeddings \
-                and "lm_head.weight" in name:
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                # We only do sharding for language model
-                # and not vision model for now.
-                if "vision_embed_tokens" in name and self.vision_embed_tokens:
-                    continue
-                if weight_name not in name:
-                    continue
-                param = params_dict[name.replace(weight_name, param_name)]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                if "wqkv" in name:
-                    config = self.config.text_config
-                    kv_groups = (config.num_attention_heads //
-                                 config.num_key_value_heads)
-                    head_dim = config.hidden_size // config.num_attention_heads
-                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
-                                                       head_dim,
-                                                       loaded_weight.shape[-1])
-                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
-                                             dim=1)
-                    wq = wq.reshape(-1, wq.shape[-1])
-                    wk = wk.reshape(-1, wk.shape[-1])
-                    wv = wv.reshape(-1, wv.shape[-1])
-                    weight_loader = param.weight_loader
-                    weight_loader(param, wq, 'q')
-                    weight_loader(param, wk, 'k')
-                    weight_loader(param, wv, 'v')
-                    continue
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+
+        # load vision encoder
+        vit_weights = filter_weights(vit_weights, "vision_model")
+        self.vision_model.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = filter_weights(mlp_weights, "mlp1")
+        mlp_params_dict = dict(self.mlp1.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
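
The rewritten InternVL load_weights replaces the single name-matching loop with three independent passes, one per component. itertools.tee is what makes this possible: the incoming weights object is a one-shot iterator, so each component needs its own copy before filter_weights narrows it down to the names under that component's prefix. A toy sketch of the pattern, using the prefixes from this file and dummy values:

# Toy sketch: tee the one-shot stream, then filter by prefix per component.
import itertools

def checkpoint_weights():
    yield "vision_model.embeddings.position_embedding", 0
    yield "mlp1.1.weight", 1
    yield "language_model.model.embed_tokens.weight", 2

vit_w, mlp_w, llm_w = itertools.tee(checkpoint_weights(), 3)
print([n for n, _ in vit_w if n.startswith("vision_model.")])
print([n for n, _ in mlp_w if n.startswith("mlp1.")])
print([n for n, _ in llm_w if n.startswith("language_model.")])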

+ 119 - 99
aphrodite/modeling/models/llava.py

@@ -1,33 +1,29 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+import itertools
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
-from transformers import CLIPVisionConfig, LlavaConfig
+from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.modeling.models.clip import CLIPVisionModel
-from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.quantization.base_config import QuantizationConfig
 
-from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                   get_max_clip_image_tokens, input_processor_for_clip)
+from .clip import (CLIPVisionModel, dummy_image_for_clip,
+                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
+                   input_processor_for_clip)
 from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
-
-_KEYS_TO_MODIFY_MAPPING = {
-    "language_model.lm_head": "lm_head",
-    "language_model.model": "language_model",
-}
+from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
+                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
+                     input_processor_for_siglip)
+from .utils import (filter_weights, init_aphrodite_registered_model,
+                    merge_vision_embeddings)
 
 
 # TODO: Run benchmark and decide if TP.
@@ -66,25 +62,48 @@ def get_max_llava_image_tokens(ctx: InputContext):
     vision_config = hf_config.vision_config
 
     if isinstance(vision_config, CLIPVisionConfig):
-        return get_max_clip_image_tokens(vision_config)
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
+        num_image_tokens = get_max_clip_image_tokens(vision_config)
+    elif isinstance(vision_config, SiglipVisionConfig):
+        num_image_tokens = get_max_siglip_image_tokens(vision_config)
+    else:
+        msg = f"Unsupported vision config: {type(vision_config)}"
+        raise NotImplementedError(msg)
+
+    strategy = hf_config.vision_feature_select_strategy
+    if strategy == "default":
+        return num_image_tokens - 1
+    elif strategy == "full":
+        return num_image_tokens
+    else:
+        raise ValueError(f"Unexpected select feature strategy: {strategy}")
 
 
 def dummy_data_for_llava(ctx: InputContext, seq_len: int):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config
 
+    image_feature_size = get_max_llava_image_tokens(ctx)
+
     if isinstance(vision_config, CLIPVisionConfig):
         seq_data = dummy_seq_data_for_clip(
             vision_config,
             seq_len,
             image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
         )
 
         mm_data = dummy_image_for_clip(vision_config)
         return seq_data, mm_data
+    elif isinstance(vision_config, SiglipVisionConfig):
+        seq_data = dummy_seq_data_for_siglip(
+            vision_config,
+            seq_len,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_siglip(vision_config)
+        return seq_data, mm_data
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -99,12 +118,49 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config
 
+    image_feature_size = get_max_llava_image_tokens(ctx)
+
     if isinstance(vision_config, CLIPVisionConfig):
         return input_processor_for_clip(
             model_config,
             vision_config,
             llm_inputs,
             image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return input_processor_for_siglip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+def _init_vision_tower(hf_config: LlavaConfig):
+    vision_config = hf_config.vision_config
+
+    # Initialize the vision tower only up to the required feature layer
+    vision_feature_layer = hf_config.vision_feature_layer
+    if vision_feature_layer < 0:
+        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+            + vision_feature_layer + 1
+    else:
+        num_hidden_layers = vision_feature_layer + 1
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return CLIPVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
+        )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return SiglipVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
         )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
@@ -127,36 +183,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.multimodal_config = multimodal_config
 
-        # Initialize the vision tower only up to the required feature layer
-        vision_feature_layer = config.vision_feature_layer
-        if vision_feature_layer < 0:
-            num_hidden_layers = config.vision_config.num_hidden_layers \
-                + vision_feature_layer + 1
-        else:
-            num_hidden_layers = vision_feature_layer + 1
-
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = CLIPVisionModel(
-            config.vision_config, num_hidden_layers_override=num_hidden_layers)
+        self.vision_tower = _init_vision_tower(config)
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act)
 
-        self.quant_config = quant_config
-        self.language_model = LlamaModel(config.text_config, cache_config,
-                                         quant_config)
-        self.unpadded_vocab_size = config.text_config.vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.text_config.hidden_size,
-            org_num_embeddings=self.language_model.org_vocab_size,
-            quant_config=quant_config)
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.text_config.vocab_size,
-                                                logit_scale)
-        self.sampler = Sampler()
+        self.language_model = init_aphrodite_registered_model(
+            config.text_config, cache_config, quant_config)
 
     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
         h = w = self.config.vision_config.image_size
@@ -197,8 +232,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
         raise ValueError(f"Unexpected select feature strategy: {strategy}")
 
-    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
-                                  pixel_values: torch.Tensor) -> torch.Tensor:
+    def _image_pixels_to_features(
+        self,
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
 
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
@@ -233,19 +271,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         **kwargs: object,
     ) -> SamplerOutput:
         """Run forward pass for LLaVA-1.5.
+
         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
         `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.
+
         Tokenizer outputs:
         `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
         278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
+
         To reserve space in KV cache, we have to insert placeholder tokens
         before they are inputted to the model, so the input processor prepends 
         additional image tokens (denoted as `32000`), resulting in:
         `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
         29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
         29901]`.
+
         We insert 575 tokens so that including the original image token in the
         input, there are a total of 576 (24 * 24) image tokens, which
         corresponds to the number of image tokens inputted to the language
@@ -258,7 +301,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each input image.
-
+        
         See also:
             :class:`LlavaImageInputs`
         """
@@ -266,7 +309,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)
 
             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
@@ -276,68 +320,44 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         else:
             inputs_embeds = None
 
-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            None,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  None,
+                                                  inputs_embeds=inputs_embeds)
 
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)
 
     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # only doing this for language model part for now.
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            # post_layernorm is not needed in CLIPVisionModel
-            if "vision_model.post_layernorm" in name:
-                continue
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            use_default_weight_loading = False
-            if "vision" in name:
-                if self.vision_tower is not None:
-                    # We only do sharding for language model and
-                    # not vision model for now.
-                    use_default_weight_loading = True
-            else:
-                for (param_name, weight_name,
-                     shard_id) in stacked_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    param = params_dict[name.replace(weight_name, param_name)]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    use_default_weight_loading = True
-            if use_default_weight_loading and name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+
+        # load vision encoder
+        vit_weights = filter_weights(vit_weights, "vision_tower")
+        self.vision_tower.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
+        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
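
_init_vision_tower keeps the existing trick of building the tower only up to the required feature layer, now shared by the CLIP and SigLIP branches. As a worked example of the arithmetic, assuming the common LLaVA-1.5 setup of a 24-layer CLIP tower with vision_feature_layer = -2:

# Worked example of the layer-count arithmetic in _init_vision_tower.
def num_layers_to_build(num_hidden_layers: int,
                        vision_feature_layer: int) -> int:
    if vision_feature_layer < 0:
        return num_hidden_layers + vision_feature_layer + 1
    return vision_feature_layer + 1

assert num_layers_to_build(24, -2) == 23  # penultimate layer -> drop one
assert num_layers_to_build(24, -1) == 24  # last layer -> full tower
assert num_layers_to_build(24, 5) == 6    # non-negative index is 0-based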

+ 141 - 104
aphrodite/modeling/models/llava_next.py

@@ -1,9 +1,10 @@
+import itertools
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
 from PIL import Image
-from transformers import CLIPVisionConfig, LlavaNextConfig
+from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired
@@ -12,21 +13,21 @@ from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.modeling.models.clip import CLIPVisionModel
-from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.quantization.base_config import QuantizationConfig
 
-from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+from .clip import (CLIPVisionModel, dummy_image_for_clip,
+                   dummy_seq_data_for_clip, get_clip_image_feature_size,
                    get_clip_patch_grid_length, input_processor_for_clip)
 from .interfaces import SupportsVision
 from .llava import LlavaMultiModalProjector
-from .utils import merge_vision_embeddings
+from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
+                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
+                     get_siglip_patch_grid_length, input_processor_for_siglip)
+from .utils import (filter_weights, init_aphrodite_registered_model,
+                    merge_vision_embeddings)
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -100,30 +101,42 @@ def get_llava_next_image_feature_size(
             image_size=vision_config.image_size,
             patch_size=vision_config.patch_size,
         )
-        base_feature_size = num_patches * num_patches
-
-        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-            image_size=(input_height, input_width),
-            grid_pinpoints=hf_config.image_grid_pinpoints,
-            patch_size=vision_config.image_size,
+        base_feature_size = get_clip_image_feature_size(vision_config)
+    elif isinstance(vision_config, SiglipVisionConfig):
+        num_patches = get_siglip_patch_grid_length(
+            image_size=vision_config.image_size,
+            patch_size=vision_config.patch_size,
         )
+        base_feature_size = get_siglip_image_feature_size(vision_config)
+    else:
+        msg = f"Unsupported vision config: {type(vision_config)}"
+        raise NotImplementedError(msg)
+
+    strategy = hf_config.vision_feature_select_strategy
+    if strategy == "default":
+        base_feature_size -= 1
+    elif strategy == "full":
+        pass
+    else:
+        raise ValueError(f"Unexpected select feature strategy: {strategy}")
 
-        (
-            unpadded_feature_size,
-            newline_feature_size,
-        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
-                                                  num_patches,
-                                                  num_patch_height,
-                                                  num_patch_width)
+    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+        image_size=(input_height, input_width),
+        grid_pinpoints=hf_config.image_grid_pinpoints,
+        patch_size=vision_config.image_size,
+    )
 
-        return unpadded_feature_size + newline_feature_size + base_feature_size
+    (
+        unpadded_feature_size,
+        newline_feature_size,
+    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
+                                              num_patches, num_patch_height,
+                                              num_patch_width)
 
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
+    return unpadded_feature_size + newline_feature_size + base_feature_size
 
 
 def get_max_llava_next_image_tokens(ctx: InputContext):
-
     return get_llava_next_image_feature_size(
         ctx.get_hf_config(LlavaNextConfig),
         input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
@@ -151,6 +164,21 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
             image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         )
 
+        return seq_data, mm_data
+    elif isinstance(vision_config, SiglipVisionConfig):
+        seq_data = dummy_seq_data_for_siglip(
+            vision_config,
+            seq_len,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_siglip(
+            vision_config,
+            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        )
+
         return seq_data, mm_data
 
     msg = f"Unsupported vision config: {type(vision_config)}"
@@ -190,6 +218,40 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return input_processor_for_siglip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+def _init_vision_tower(hf_config: LlavaNextConfig):
+    vision_config = hf_config.vision_config
+
+    # Initialize the vision tower only up to the required feature layer
+    vision_feature_layer = hf_config.vision_feature_layer
+    if vision_feature_layer < 0:
+        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+            + vision_feature_layer + 1
+    else:
+        num_hidden_layers = vision_feature_layer + 1
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return CLIPVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
+        )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return SiglipVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
+        )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -211,36 +273,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.multimodal_config = multimodal_config
 
-        # Initialize the vision tower only up to the required feature layer
-        vision_feature_layer = config.vision_feature_layer
-        if vision_feature_layer < 0:
-            num_hidden_layers = config.vision_config.num_hidden_layers \
-                + vision_feature_layer + 1
-        else:
-            num_hidden_layers = vision_feature_layer + 1
-
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = CLIPVisionModel(
-            config.vision_config, num_hidden_layers_override=num_hidden_layers)
+        self.vision_tower = _init_vision_tower(config)
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act)
 
-        self.quant_config = quant_config
-        self.language_model = LlamaModel(config.text_config, cache_config,
-                                         quant_config)
-        self.unpadded_vocab_size = config.text_config.vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.text_config.hidden_size,
-            org_num_embeddings=self.language_model.org_vocab_size,
-            quant_config=quant_config)
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.text_config.vocab_size,
-                                                logit_scale)
-        self.sampler = Sampler()
+        self.language_model = init_aphrodite_registered_model(
+            config.text_config, cache_config, quant_config)
 
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
@@ -306,8 +347,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         raise ValueError(f"Unexpected select feature strategy: {strategy}")
 
-    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
-                                  pixel_values: torch.Tensor) -> torch.Tensor:
+    def _image_pixels_to_features(
+        self,
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
 
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
@@ -445,19 +489,23 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         **kwargs: object,
     ) -> SamplerOutput:
         """Run forward pass for LlaVA-NeXT.
+
         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
         `"A chat between a curious human and an artificial intelligence
         assistant. The assistant gives helpful, detailed, and polite answers to
         the human's questions.
         USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
+
         Tokenizer outputs:
         `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
         29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
         6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
         29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
         9047, 13566, 29901]`.
+
         To reserve space in KV cache, we have to insert placeholder tokens
         before they are inputted to the model, so the input processor prepends 
         additional image tokens (denoted as `32000`), resulting in:
@@ -466,6 +514,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
         29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
         319, 1799, 9047, 13566, 29901]`.
+
         Unlike in LLaVA-1.5, the number of image tokens inputted to the language
         model depends on the original size of the input image. Including the
         original image token in the input, the required number of image tokens
@@ -487,7 +536,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)
 
             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
@@ -497,68 +547,55 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         else:
             inputs_embeds = None
 
-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            None,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  None,
+                                                  inputs_embeds=inputs_embeds)
 
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)
 
     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # only doing this for language model part for now.
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            # post_layernorm is not needed in CLIPVisionModel
-            if "vision_model.post_layernorm" in name:
-                continue
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            use_default_weight_loading = False
-            if "vision" in name:
-                if self.vision_tower is not None:
-                    # We only do sharding for language model and
-                    # not vision model for now.
-                    use_default_weight_loading = True
-            else:
-                for (param_name, weight_name,
-                     shard_id) in stacked_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    param = params_dict[name.replace(weight_name, param_name)]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    use_default_weight_loading = True
-            if use_default_weight_loading and name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
+            weights, 4)
+
+        # load vision encoder
+        vit_weights = filter_weights(vit_weights, "vision_tower")
+        self.vision_tower.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
+        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load newline
+        newline_weights = filter_weights(newline_weights, "image_newline")
+        for name, loaded_weight in newline_weights:
+            assert name == ""
+            param = self.image_newline
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
+    
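
One subtlety in the LLaVA-NeXT loader is the image_newline branch: filter_weights strips the whole prefix, so the bare parameter stored as "image_newline" arrives with an empty remaining name, which is exactly what the assert checks before copying it into self.image_newline. A small sketch of that edge case, using the filter_weights helper added in this commit:

# Why the newline loop asserts an empty name after prefix filtering.
import torch

from aphrodite.modeling.models.utils import filter_weights

weights = [("image_newline", torch.zeros(8))]
(name, tensor), = filter_weights(weights, "image_newline")
assert name == ""          # nothing left once the prefix is stripped
assert tensor.shape == (8,)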

+ 44 - 16
aphrodite/modeling/models/siglip.py

@@ -2,13 +2,13 @@
 within a vision language model."""
 
 import math
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from aphrodite_flash_attn import flash_attn_func
 from PIL import Image
 from torch import nn
-from transformers import SiglipConfig, SiglipVisionConfig
+from transformers import SiglipVisionConfig
 from transformers.models.siglip.modeling_siglip import SiglipAttention
 from xformers.ops import memory_efficient_attention
 
@@ -22,13 +22,15 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.multimodal.image import (cached_get_tokenizer,
                                         repeat_and_pad_image_tokens)
 from aphrodite.quantization import QuantizationConfig
 
 
 def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
-    assert image_size % patch_size == 0
+    # Since interpolation is applied, the image size need not be divisible
+    # assert image_size % patch_size == 0
     return image_size // patch_size
 
 
@@ -151,6 +153,7 @@ class SiglipVisionEmbeddings(nn.Module):
         class embedding unlike other ViTs) that allows the model to interpolate
         the pre-trained position encodings such that it can be usable on higher
         resolution images.
+
         Source:
         https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
         """
@@ -208,7 +211,7 @@ class SiglipVisionEmbeddings(nn.Module):
 
 
 # NOTE: Not used - kept for later when we TP the ViT
-# TODO: Implement TP version of Attention
+# TODO(ChristopherCho): Implement TP version of Attention
 class SiglipTPAttention(nn.Module):
 
     def __init__(
@@ -323,8 +326,8 @@ class SiglipTPAttention(nn.Module):
 
 
 # NOTE: Not used - kept for later when we TP the ViT
-# TODO: flash_attn_func is not working properly.
-#       It constantly throws a CUDA error.
+# TODO(ChristopherCho): flash_attn_func is not working properly.
+#                       It constantly throws a CUDA error.
 class SiglipFlashAttention2(SiglipTPAttention):
 
     def __init__(self, *args, **kwargs):
@@ -453,13 +456,13 @@ class SiglipEncoderLayer(nn.Module):
 
     def __init__(
         self,
-        config: SiglipConfig,
+        config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
 
-        # TODO: use TP'ed Attention block
+        # TODO(ChristopherCho): use TP'ed Attention block
         self.self_attn = SiglipAttention(config)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
@@ -473,7 +476,7 @@ class SiglipEncoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-    ) -> Tuple[torch.Tensor]:
+    ) -> Tuple[torch.Tensor, None]:
         residual = hidden_states
 
         hidden_states = self.layer_norm1(hidden_states)
@@ -492,22 +495,27 @@ class SiglipEncoder(nn.Module):
 
     def __init__(
         self,
-        config: SiglipConfig,
+        config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
+
+        if num_hidden_layers_override is None:
+            num_hidden_layers = config.num_hidden_layers
+        else:
+            num_hidden_layers = num_hidden_layers_override
+
         self.layers = nn.ModuleList([
-            SiglipEncoderLayer(
-                config,
-                quant_config=quant_config,
-            ) for _ in range(config.num_hidden_layers)
+            SiglipEncoderLayer(config, quant_config=quant_config)
+            for _ in range(num_hidden_layers)
         ])
 
     def forward(
         self,
         inputs_embeds: torch.Tensor,
-    ) -> Tuple:
+    ) -> torch.Tensor:
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
             hidden_states, _ = encoder_layer(hidden_states)
@@ -526,7 +534,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         super().__init__()
 
         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
-        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
+        # TODO(ChristopherCho): Implement aphrodite version of MHA
         self.attention = torch.nn.MultiheadAttention(
             config.hidden_size, config.num_attention_heads, batch_first=True)
         self.layernorm = nn.LayerNorm(config.hidden_size,
@@ -552,6 +560,7 @@ class SiglipVisionTransformer(nn.Module):
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
@@ -561,6 +570,7 @@ class SiglipVisionTransformer(nn.Module):
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim,
                                            eps=config.layer_norm_eps)
@@ -599,11 +609,13 @@ class SiglipVisionModel(nn.Module):
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
         self.vision_model = SiglipVisionTransformer(
             config,
             quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
         )
 
     def get_input_embeddings(self) -> nn.Module:
@@ -618,3 +630,19 @@ class SiglipVisionModel(nn.Module):
             pixel_values=pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
         )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        layer_count = len(self.vision_model.encoder.layers)
+
+        for name, loaded_weight in weights:
+            # omit layers when num_hidden_layers_override is set
+            if "vision_model.encoder.layers." in name:
+                layer_idx = int(name.split(".")[3])
+                if layer_idx >= layer_count:
+                    continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
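
SiglipVisionModel.load_weights follows the same per-component pattern but must also skip checkpoint entries for encoder layers dropped via num_hidden_layers_override; the layer index is recovered from position 3 of the dotted parameter name. For example:

# How the layer index is parsed out of a checkpoint name (sketch).
name = "vision_model.encoder.layers.26.self_attn.q_proj.weight"
layer_idx = int(name.split(".")[3])   # -> 26
layer_count = 26                      # e.g. a 27-layer tower trimmed by one
if layer_idx >= layer_count:
    print(f"skipping {name}: layer {layer_idx} was not instantiated")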

+ 55 - 4
aphrodite/modeling/models/utils.py

@@ -1,10 +1,57 @@
-from typing import Dict, List, Protocol, Tuple
+from typing import Dict, Iterable, List, Optional, Protocol, Tuple
 
 import torch
+import torch.nn as nn
 from torch.func import functional_call
+from transformers import PretrainedConfig
 
+from aphrodite.common.config import (CacheConfig, LoRAConfig, MultiModalConfig,
+                                     SchedulerConfig)
 from aphrodite.common.utils import is_pin_memory_available
+from aphrodite.modeling.model_loader.loader import build_model
+from aphrodite.modeling.models import ModelRegistry
 from aphrodite.multimodal import BatchedTensors
+from aphrodite.quantization import QuantizationConfig
+
+
+def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
+    """
+    Helper function to load weights for inner aphrodite models.
+
+    See also:
+        :ref:`init_aphrodite_registered_model`
+    """
+    for name, loaded_weight in weights:
+        name = name.split(".")
+        if prefix == name.pop(0):
+            name = ".".join(name)
+            yield name, loaded_weight
+
+
+def init_aphrodite_registered_model(
+    hf_config: PretrainedConfig,
+    cache_config: Optional[CacheConfig],
+    quant_config: Optional[QuantizationConfig],
+    *,
+    lora_config: Optional[LoRAConfig] = None,
+    multimodal_config: Optional[MultiModalConfig] = None,
+    scheduler_config: Optional[SchedulerConfig] = None,
+) -> nn.Module:
+    """
+    Helper function to initialize an inner model registered to aphrodite,
+    based on the arguments passed to the outer aphrodite model.
+    """
+    model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
+
+    return build_model(
+        model_class,
+        hf_config,
+        cache_config,
+        quant_config,
+        lora_config=lora_config,
+        multimodal_config=multimodal_config,
+        scheduler_config=scheduler_config,
+    )
 
 
 def merge_vision_embeddings(input_ids: torch.Tensor,
@@ -12,11 +59,12 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
                             vision_embeddings: BatchedTensors,
                             image_token_id: int) -> torch.Tensor:
     """
-    Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
-    in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
+    Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the
+    positions in ``inputs_embeds`` corresponding to placeholder image tokens in
+    ``input_ids``.
 
     Note:
-        This updates `inputs_embeds` in place.
+        This updates ``inputs_embeds`` in place.
     """
     mask = (input_ids == image_token_id)
     num_expected_tokens = mask.sum()
@@ -75,11 +123,14 @@ def set_cpu_offload_max_bytes(max_bytes: int) -> None:
 
 def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
     device = next(module.parameters()).device
+
     if device == torch.device("cpu"):
         return module
+
     global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
     if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
         return module
+
     pin_memory = is_pin_memory_available()
 
     # offload parameters to CPU
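
filter_weights and init_aphrodite_registered_model are the two utilities the model files above build on: the former narrows a checkpoint stream to one component and strips its prefix, the latter resolves the text backbone from hf_config.architectures via the registry and constructs it through build_model, so the wrappers no longer hard-code LlamaModel. A runnable sketch of the prefix filtering, with names shaped like a LLaVA checkpoint:

# Runnable sketch of filter_weights on LLaVA-style checkpoint names.
import torch

from aphrodite.modeling.models.utils import filter_weights

weights = [
    ("vision_tower.vision_model.embeddings.patch_embedding.weight",
     torch.zeros(1)),
    ("multi_modal_projector.linear_1.weight", torch.zeros(1)),
    ("language_model.model.embed_tokens.weight", torch.zeros(1)),
]

print([n for n, _ in filter_weights(weights, "language_model")])
# ['model.embed_tokens.weight'] -- prefix stripped, ready for the backbone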