"""LLaVA-NeXT multimodal model: input processing and model definition."""
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_list_of
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn,
                    init_aphrodite_registered_model,
                    merge_multimodal_embeddings)

# NOTE(review): not referenced anywhere in this file; kept because other
# modules may import it.
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

# Results in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class LlavaNextImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_llava_next_num_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    """Count the grid features left after removing aspect-ratio padding.

    The full feature grid is `npatches * num_patch_height` rows by
    `npatches * num_patch_width` columns. The padded dimension (the one that
    does not limit the scale of the original image) is shrunk by twice the
    integer padding computed for one side.

    Args:
        original_height: Height of the original image.
        original_width: Width of the original image.
        npatches: Multiplier applied to each grid dimension (callers in this
            file pass the patch grid length of the vision encoder).
        num_patch_height: Number of tiles along the height.
        num_patch_width: Number of tiles along the width.

    Returns:
        `(unpadded_features, newline_features)`, where `unpadded_features`
        is the area of the shrunken grid and `newline_features` is its
        height (one newline feature per row).
    """
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Image is relatively wider than the grid: padding is vertical.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        # Image is relatively taller (or equal): padding is horizontal.
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return the number of image features for an image of the given size.

    The result is the sum of the unpadded grid features, the newline
    features and the base (single-tile) feature size of the vision encoder.

    Args:
        hf_config: The LLaVA-NeXT HF config.
        input_height: Height of the input image.
        input_width: Width of the input image.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
        ValueError: If `vision_feature_select_strategy` is not `"default"`
            or `"full"`.
    """
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_clip_image_feature_size(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_patches = get_siglip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_siglip_image_feature_size(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        # Mirrors `_select_image_features`, which drops the first feature
        # for the "default" strategy.
        base_feature_size -= 1
    elif strategy == "full":
        pass
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    (
        unpadded_feature_size,
        newline_feature_size,
    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                              num_patches, num_patch_height,
                                              num_patch_width)

    return unpadded_feature_size + newline_feature_size + base_feature_size


def get_max_llava_next_image_tokens(ctx: InputContext):
    """Return the image feature size for the maximum-size dummy image."""
    return get_llava_next_image_feature_size(
        ctx.get_hf_config(LlavaNextConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    """Build dummy sequence data and dummy images for this model.

    Delegates to the CLIP or SigLIP dummy-data helpers, overriding the image
    feature size and image dimensions with the maximum values defined in
    this module.

    Returns:
        `(seq_data, mm_data)` as produced by the encoder-specific helpers.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_next_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    """Compute per-image feature sizes and delegate to the encoder processor.

    Inputs without image data are returned unchanged. Otherwise the image
    feature size is derived from the image data, which may be a PIL image, a
    list of PIL images, an embedding tensor, or a list of embedding tensors,
    and passed on as `image_feature_size_override`.

    Raises:
        TypeError: If the image data has an unsupported type.
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [
            get_llava_next_image_feature_size(hf_config,
                                              input_height=img.height,
                                              input_width=img.width)
            for img in image_data
        ]
    elif isinstance(image_data, torch.Tensor):
        # The 3-way unpack is kept so a tensor that is not 3-D still fails
        # here; only the middle dimension is used.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaNextConfig):
    """Create the CLIP or SigLIP vision tower for the given config.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        # Negative indices count back from the last hidden layer.
        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
            + vision_feature_layer + 1
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
    """LLaVA-NeXT: vision tower, multimodal projector and language model."""

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the vision tower, projector, language model and newline."""
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)

        # Uninitialized here; filled in by `load_weights`.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every entry of `data` has shape `(2,)`.

        Returns:
            `data` unchanged.

        Raises:
            ValueError: If any entry has a different shape.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that every entry of `data` has trailing shape `(3, h, w)`.

        `h` and `w` both equal the vision config's `image_size`; the leading
        (patch) dimension of each entry is not checked.

        Returns:
            `data` unchanged.

        Raises:
            ValueError: If any entry has a different trailing shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Extract and validate image inputs from the forward kwargs.

        `pixel_values` takes precedence over `image_embeds` when both are
        given.

        Returns:
            `None` when neither `pixel_values` nor `image_embeds` is given,
            otherwise the matching typed input dict.

        Raises:
            ValueError: If a provided input has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # NOTE(review): `image_sizes=None` is rejected here, although
            # `_process_image_input` has a fallback for missing sizes.
            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Select features according to `strategy`.

        `"default"` drops index 0 along dim 1; `"full"` keeps everything.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the configured feature selection."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge the per-tile embeddings of one image into one sequence.

        Args:
            image_size: Original `(height, width)` of the image.
            patch_embeddings: Embeddings of the base tile (index 0) followed
                by the remaining tiles.
            strategy: `"flat"`, or a string starting with `"spatial"`; if it
                contains `"unpad"`, padding is removed via `unpad_image` and
                `image_newline` is appended.

        Raises:
            ValueError: If the base tile's feature count does not match the
                patch grid, or if the strategy is not recognized.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = (
                    get_anyres_image_grid_shape(
                        (orig_height, orig_width),
                        self.config.image_grid_pinpoints,
                        self.config.vision_config.image_size,
                    ))
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches].view(
                    num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # Rearranged to (hidden, grid_h * h, grid_w * w) before
                    # unpadding.
                    other_patch_embeds = (
                        other_patch_embeds
                        .permute(4, 0, 2, 1, 3).contiguous()
                        .flatten(1, 2).flatten(2, 3))
                    other_patch_embeds = unpad_image(
                        other_patch_embeds, (orig_height, orig_width))
                    # Append one newline column at the end of the last dim.
                    newline_column = (
                        self.image_newline[:, None, None]
                        .expand(*other_patch_embeds.shape[:-1], 1)
                        .to(other_patch_embeds.device))
                    other_patch_embeds = torch.cat(
                        (other_patch_embeds, newline_column), dim=-1)
                    other_patch_embeds = (
                        other_patch_embeds.flatten(1, 2).transpose(0, 1))
                else:
                    other_patch_embeds = (
                        other_patch_embeds
                        .permute(0, 2, 1, 3, 4).contiguous()
                        .flatten(0, 3))

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    # Only the base tile: append a single newline embedding.
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None].to(
                             base_patch_embeds.device)),
                        dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Encode pixel values into projected patch embeddings.

        A batched 5-D tensor is processed in one pass and reshaped back to
        `(b, num_patches, ...)`. A list of tensors is concatenated for the
        vision tower, then split per image before projection.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Variable number of patches per image: remember the split sizes.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Turn parsed image inputs into a list of embeddings.

        Precomputed embeddings are returned wrapped in a one-element list.
        Pixel inputs are encoded and merged per image with the
        `"spatial_unpad"` strategy.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # Fall back to the vision encoder's native square size.
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers
        to the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082,
        20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248,
        568, 6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001,
        29901, 29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319,
        1799, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor
        prepends additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082,
        20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248,
        568, 6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001,
        29901, 29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967,
        29973, 319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the
        language model depends on the original size of the input image.
        Including the original image token in the input, the required number
        of image tokens is given by
        :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        Returns:
            The output of the language model backbone.
            NOTE(review): the `SamplerOutput` annotation looks inaccurate,
            since the value returned is named `hidden_states` - confirm.

        See also:
            :class:`LlavaNextImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The embeddings now carry all the information; the backbone
            # is called with `input_ids=None`.
            input_ids = None
        else:
            inputs_embeds = None

        # NOTE(review): `intermediate_tensors` is accepted but `None` is
        # passed to the backbone in its place - confirm this is intended.
        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into each sub-component.

        The weight iterator is duplicated four times with `itertools.tee`,
        one copy each for the vision tower, the projector, the image-newline
        parameter and the language model.
        """
        # prepare weight iterators for components
        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
            weights, 4)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load newline
        newline_weights = filter_weights(newline_weights, "image_newline")
        for name, loaded_weight in newline_weights:
            # presumably `filter_weights` strips the prefix, leaving an empty
            # name for this top-level parameter - verify against the helper
            assert name == ""
            param = self.image_newline
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)