@@ -1,9 +1,10 @@
+import itertools
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
 from PIL import Image
-from transformers import CLIPVisionConfig, LlavaNextConfig
+from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired
@@ -12,21 +13,21 @@ from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.modeling.models.clip import CLIPVisionModel
-from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.quantization.base_config import QuantizationConfig
 
-from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+from .clip import (CLIPVisionModel, dummy_image_for_clip,
+                   dummy_seq_data_for_clip, get_clip_image_feature_size,
                    get_clip_patch_grid_length, input_processor_for_clip)
 from .interfaces import SupportsVision
 from .llava import LlavaMultiModalProjector
-from .utils import merge_vision_embeddings
+from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
+                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
+                     get_siglip_patch_grid_length, input_processor_for_siglip)
+from .utils import (filter_weights, init_aphrodite_registered_model,
+                    merge_vision_embeddings)
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -100,30 +101,42 @@ def get_llava_next_image_feature_size(
             image_size=vision_config.image_size,
             patch_size=vision_config.patch_size,
         )
-        base_feature_size = num_patches * num_patches
-
-        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-            image_size=(input_height, input_width),
-            grid_pinpoints=hf_config.image_grid_pinpoints,
-            patch_size=vision_config.image_size,
+        base_feature_size = get_clip_image_feature_size(vision_config)
+    elif isinstance(vision_config, SiglipVisionConfig):
+        num_patches = get_siglip_patch_grid_length(
+            image_size=vision_config.image_size,
+            patch_size=vision_config.patch_size,
         )
+        base_feature_size = get_siglip_image_feature_size(vision_config)
+    else:
+        msg = f"Unsupported vision config: {type(vision_config)}"
+        raise NotImplementedError(msg)
+
+    strategy = hf_config.vision_feature_select_strategy
+    if strategy == "default":
+        base_feature_size -= 1
+    elif strategy == "full":
+        pass
+    else:
+        raise ValueError(f"Unexpected select feature strategy: {strategy}")
 
-        (
-            unpadded_feature_size,
-            newline_feature_size,
-        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
-                                                  num_patches,
-                                                  num_patch_height,
-                                                  num_patch_width)
+    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+        image_size=(input_height, input_width),
+        grid_pinpoints=hf_config.image_grid_pinpoints,
+        patch_size=vision_config.image_size,
+    )
 
-        return unpadded_feature_size + newline_feature_size + base_feature_size
+    (
+        unpadded_feature_size,
+        newline_feature_size,
+    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
+                                              num_patches, num_patch_height,
+                                              num_patch_width)
 
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
+    return unpadded_feature_size + newline_feature_size + base_feature_size
 
 
 def get_max_llava_next_image_tokens(ctx: InputContext):
-
     return get_llava_next_image_feature_size(
         ctx.get_hf_config(LlavaNextConfig),
         input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
@@ -151,6 +164,21 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
             image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         )
 
+        return seq_data, mm_data
+    elif isinstance(vision_config, SiglipVisionConfig):
+        seq_data = dummy_seq_data_for_siglip(
+            vision_config,
+            seq_len,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_siglip(
+            vision_config,
+            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        )
+
         return seq_data, mm_data
 
     msg = f"Unsupported vision config: {type(vision_config)}"
@@ -190,6 +218,40 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return input_processor_for_siglip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+def _init_vision_tower(hf_config: LlavaNextConfig):
+    vision_config = hf_config.vision_config
+
+    # Initialize the vision tower only up to the required feature layer
+    vision_feature_layer = hf_config.vision_feature_layer
+    if vision_feature_layer < 0:
+        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+            + vision_feature_layer + 1
+    else:
+        num_hidden_layers = vision_feature_layer + 1
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return CLIPVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
+        )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return SiglipVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
+        )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -211,36 +273,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.multimodal_config = multimodal_config
 
-        # Initialize the vision tower only up to the required feature layer
-        vision_feature_layer = config.vision_feature_layer
-        if vision_feature_layer < 0:
-            num_hidden_layers = config.vision_config.num_hidden_layers \
-                + vision_feature_layer + 1
-        else:
-            num_hidden_layers = vision_feature_layer + 1
-
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = CLIPVisionModel(
-            config.vision_config, num_hidden_layers_override=num_hidden_layers)
+        self.vision_tower = _init_vision_tower(config)
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act)
 
-        self.quant_config = quant_config
-        self.language_model = LlamaModel(config.text_config, cache_config,
-                                         quant_config)
-        self.unpadded_vocab_size = config.text_config.vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.text_config.hidden_size,
-            org_num_embeddings=self.language_model.org_vocab_size,
-            quant_config=quant_config)
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.text_config.vocab_size,
-                                                logit_scale)
-        self.sampler = Sampler()
+        self.language_model = init_aphrodite_registered_model(
+            config.text_config, cache_config, quant_config)
 
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
@@ -306,8 +347,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         raise ValueError(f"Unexpected select feature strategy: {strategy}")
 
-    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
-                                  pixel_values: torch.Tensor) -> torch.Tensor:
+    def _image_pixels_to_features(
+        self,
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
 
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
@@ -445,19 +489,23 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         **kwargs: object,
     ) -> SamplerOutput:
         """Run forward pass for LlaVA-NeXT.
+
         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
         `"A chat between a curious human and an artificial intelligence
         assistant. The assistant gives helpful, detailed, and polite answers to
         the human's questions.
         USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
+
         Tokenizer outputs:
         `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
         29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
         6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
         29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
         9047, 13566, 29901]`.
+
         To reserve space in KV cache, we have to insert placeholder tokens
         before they are inputted to the model, so the input processor prepends
         additional image tokens (denoted as `32000`), resulting in:
@@ -466,6 +514,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
         29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
         319, 1799, 9047, 13566, 29901]`.
+
         Unlike in LLaVA-1.5, the number of image tokens inputted to the language
         model depends on the original size of the input image. Including the
         original image token in the input, the required number of image tokens
@@ -487,7 +536,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)
 
             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
@@ -497,68 +547,55 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         else:
             inputs_embeds = None
 
-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            None,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  None,
+                                                  inputs_embeds=inputs_embeds)
 
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)
 
     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # only doing this for language model part for now.
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            # post_layernorm is not needed in CLIPVisionModel
-            if "vision_model.post_layernorm" in name:
-                continue
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            use_default_weight_loading = False
-            if "vision" in name:
-                if self.vision_tower is not None:
-                    # We only do sharding for language model and
-                    # not vision model for now.
-                    use_default_weight_loading = True
-            else:
-                for (param_name, weight_name,
-                     shard_id) in stacked_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    param = params_dict[name.replace(weight_name, param_name)]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    use_default_weight_loading = True
-            if use_default_weight_loading and name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
+            weights, 4)
+
+        # load vision encoder
+        vit_weights = filter_weights(vit_weights, "vision_tower")
+        self.vision_tower.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
+        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load newline
+        newline_weights = filter_weights(newline_weights, "image_newline")
+        for name, loaded_weight in newline_weights:
+            assert name == ""
+            param = self.image_newline
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
+
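
Note on the new _init_vision_tower helper: it builds only the encoder layers that feed the selected feature layer, counting a negative vision_feature_layer from the end of the stack. A minimal sketch of that arithmetic follows; the 24-layer CLIP ViT-L and vision_feature_layer = -2 figures are the usual LLaVA-style defaults and are not taken from this diff.

def num_layers_to_init(vision_feature_layer: int, total_layers: int) -> int:
    # Same arithmetic as _init_vision_tower above: negative indices count
    # from the end of the encoder stack, so -1 means "all layers".
    if vision_feature_layer < 0:
        return total_layers + vision_feature_layer + 1
    return vision_feature_layer + 1


# 24-layer CLIP ViT-L with vision_feature_layer = -2 -> build 23 layers;
# selecting layer index 5 would build 6.
assert num_layers_to_init(-2, 24) == 23
assert num_layers_to_init(5, 24) == 6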
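
The reworked load_weights fans one checkpoint iterator out to the four sub-modules with itertools.tee, then hands each consumer only its own entries via filter_weights from .utils (not shown in this diff). The implied contract is that the component prefix is stripped from the returned names, which is why the image_newline loop can assert name == "". Below is a self-contained sketch of that contract; filter_weights_sketch and the checkpoint names are hypothetical stand-ins, not the real .utils implementation.

import itertools
from typing import Iterable, Iterator, Tuple

import torch


def filter_weights_sketch(
        weights: Iterable[Tuple[str, torch.Tensor]],
        prefix: str) -> Iterator[Tuple[str, torch.Tensor]]:
    # Keep only weights owned by `prefix` and strip the prefix:
    # "vision_tower.a.b" -> "a.b"; an exact match becomes "".
    for name, tensor in weights:
        if name == prefix or name.startswith(prefix + "."):
            yield name[len(prefix):].lstrip("."), tensor


# One pass over the checkpoint, four independent consumers.
checkpoint = [
    ("vision_tower.vision_model.embeddings.class_embedding", torch.zeros(1)),
    ("multi_modal_projector.linear_1.weight", torch.zeros(1)),
    ("image_newline", torch.zeros(1)),
    ("language_model.model.embed_tokens.weight", torch.zeros(1)),
]
vit_w, mlp_w, newline_w, llm_w = itertools.tee(checkpoint, 4)
print([name for name, _ in filter_weights_sketch(vit_w, "vision_tower")])
print([name for name, _ in filter_weights_sketch(newline_w, "image_newline")])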