from typing import Iterable, List, Literal, Optional, Tuple, TypedDict import torch import torch.nn as nn from PIL import Image from transformers import CLIPVisionConfig, LlavaNextConfig from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) from typing_extensions import NotRequired from aphrodite.attention import AttentionMetadata from aphrodite.common.config import CacheConfig, VisionLanguageConfig from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs from aphrodite.modeling.layers.logits_processor import LogitsProcessor from aphrodite.quantization.base_config import (QuantizationConfig) from aphrodite.modeling.layers.sampler import Sampler from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead from aphrodite.modeling.model_loader.weight_utils import default_weight_loader from aphrodite.modeling.models.clip import CLIPVisionModel from aphrodite.modeling.models.llama import LlamaModel from aphrodite.modeling.sampling_metadata import SamplingMetadata from aphrodite.multimodal import MULTIMODAL_REGISTRY, BatchedTensors from aphrodite.common.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_patch_grid_length, input_processor_for_clip) from .interfaces import SupportsVision from .llava import LlavaMultiModalProjector from .utils import merge_vision_embeddings _KEYS_TO_MODIFY_MAPPING = { "language_model.lm_head": "lm_head", "language_model.model": "language_model", } class LlavaNextImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: BatchedTensors """ Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` Note that `num_patches` may be different for each batch. """ image_sizes: NotRequired[torch.Tensor] """ Shape: `(batch_size, 2)` This should be in `(height, width)` format. """ LlavaNextImageInputs = LlavaNextImagePixelInputs # Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91 # NOTE: new_height and new_width are further incremented to properly invert the # floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133 def _get_llava_next_num_unpadded_features( height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int, ) -> Tuple[int, int]: current_height = npatches * num_patch_height current_width = npatches * num_patch_width aspect_ratio: float = width / height current_aspect_ratio: float = current_width / current_height if aspect_ratio > current_aspect_ratio: new_height = (height * current_width) // width if new_height % 2 == 1: new_height += 1 current_height = new_height else: new_width = (width * current_height) // height if new_width % 2 == 1: new_width += 1 current_width = new_width unpadded_features = current_height * current_width newline_features = current_height return (unpadded_features, newline_features) # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111 def get_llava_next_image_feature_size( hf_config: LlavaNextConfig, *, input_height: int, input_width: int, ) -> int: vision_config = hf_config.vision_config if isinstance(vision_config, CLIPVisionConfig): num_patches = get_clip_patch_grid_length( image_size=vision_config.image_size, patch_size=vision_config.patch_size, ) base_feature_size = num_patches * num_patches # Note: We follow the "wrong" width/height order # [ref: PR huggingface/transformers#31588] num_patch_width, num_patch_height = get_anyres_image_grid_shape( image_size=(input_height, input_width), grid_pinpoints=hf_config.image_grid_pinpoints, patch_size=vision_config.image_size, ) ( unpadded_feature_size, newline_feature_size, ) = _get_llava_next_num_unpadded_features(input_height, input_width, num_patches, num_patch_height, num_patch_width) return unpadded_feature_size + newline_feature_size + base_feature_size msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): hf_config = ctx.get_hf_config(LlavaNextConfig) vision_config = hf_config.vision_config # Result in the max possible feature size (2x2 grid of 336x336px tiles) dummy_height = dummy_width = 448 image_feature_size = get_llava_next_image_feature_size( hf_config, input_height=dummy_height, input_width=dummy_width, ) if isinstance(vision_config, CLIPVisionConfig): seq_data = dummy_seq_data_for_clip( vision_config, seq_len, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) mm_data = dummy_image_for_clip( vision_config, image_width_override=dummy_width, image_height_override=dummy_height, ) return seq_data, mm_data msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs model_config = ctx.model_config hf_config = ctx.get_hf_config(LlavaNextConfig) vision_config = hf_config.vision_config image_data = multi_modal_data["image"] if isinstance(image_data, Image.Image): width, height = image_data.size image_feature_size = get_llava_next_image_feature_size( hf_config, input_height=height, input_width=width, ) elif isinstance(image_data, torch.Tensor): raise NotImplementedError("Embeddings input is not supported yet") else: raise TypeError(f"Invalid image type: {type(image_data)}") vision_config = hf_config.vision_config if isinstance(vision_config, CLIPVisionConfig): return input_processor_for_clip( model_config, vision_config, llm_inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @MULTIMODAL_REGISTRY.register_image_input_mapper() @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): def __init__(self, config: LlavaNextConfig, vlm_config: VisionLanguageConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config self.vlm_config = vlm_config # TODO: Optionally initializes this for supporting embeddings. self.vision_tower = CLIPVisionModel(config=config.vision_config) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) self.quant_config = quant_config self.language_model = LlamaModel(config.text_config, cache_config, quant_config) self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.text_config.hidden_size, org_num_embeddings=self.language_model.org_vocab_size, quant_config=quant_config) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) self.sampler = Sampler() self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: if list(data.shape[1:]) != [2]: raise ValueError( f"The expected image sizes shape is batch dimension plus " f"{[2]}. You supplied {data.shape}.") return data def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) if pixel_values is None: return None if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") if not isinstance(image_sizes, torch.Tensor): raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") return LlavaNextImagePixelInputs( type="pixel_values", data=pixel_values, image_sizes=self._validate_image_sizes(image_sizes), ) def _select_image_features(self, image_features: torch.Tensor, *, strategy: str) -> torch.Tensor: # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa if strategy == "default": return image_features[:, 1:] elif strategy == "full": return image_features raise ValueError(f"Unexpected select feature strategy: {strategy}") def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, pixel_values: torch.Tensor) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower image_features = vision_tower(pixel_values, self.config.vision_feature_layer) return self._select_image_features( image_features, strategy=self.config.vision_feature_select_strategy, ) # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py def _merge_image_patch_embeddings(self, image_size: torch.Tensor, patch_embeddings: torch.Tensor, *, strategy: str) -> torch.Tensor: if strategy == "flat": return patch_embeddings.flatten(0, 1) if strategy.startswith("spatial"): height = width = self.config.vision_config.image_size \ // self.config.vision_config.patch_size base_patch_embeds = patch_embeddings[0] if height * width != base_patch_embeds.shape[0]: raise ValueError( "The number of patches is not consistent with the " "image size.") if patch_embeddings.shape[0] > 1: other_patch_embeds = patch_embeddings[1:] # image_aspect_ratio == "anyres" # Note: We follow the "wrong" width/height order # [ref: PR huggingface/transformers#31588] num_patch_width, num_patch_height = get_anyres_image_grid_shape( image_size, self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) other_patch_embeds = other_patch_embeds \ .view(num_patch_height, num_patch_width, height, width, -1) if "unpad" in strategy: other_patch_embeds = other_patch_embeds \ .permute(4, 0, 2, 1, 3).contiguous() \ .flatten(1, 2).flatten(2, 3) other_patch_embeds = unpad_image(other_patch_embeds, image_size) other_patch_embeds = torch.cat(( other_patch_embeds, self.image_newline[:, None, None] \ .expand(*other_patch_embeds.shape[:-1], 1) \ .to(other_patch_embeds.device), ), dim=-1) other_patch_embeds = other_patch_embeds \ .flatten(1, 2).transpose(0, 1) else: other_patch_embeds = other_patch_embeds \ .permute(0, 2, 1, 3, 4).contiguous() \ .flatten(0, 3) merged_patch_embeddings = torch.cat( (base_patch_embeds, other_patch_embeds), dim=0) else: if "unpad" in strategy: merged_patch_embeddings = torch.cat( (base_patch_embeds, self.image_newline[None] \ .to(base_patch_embeds.device) ), dim=0) else: merged_patch_embeddings = base_patch_embeds return merged_patch_embeddings raise ValueError(f"Unexpected patch merge strategy: {strategy}") def _process_image_pixels( self, inputs: LlavaNextImagePixelInputs, ) -> BatchedTensors: assert self.vision_tower is not None pixel_values = inputs["data"] if isinstance(pixel_values, torch.Tensor): b, num_patches, c, h, w = pixel_values.shape stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) stacked_image_features = self._image_pixels_to_features( self.vision_tower, stacked_pixel_values) stacked_patch_embeddings = self.multi_modal_projector( stacked_image_features) return stacked_patch_embeddings.view( b, num_patches, *stacked_patch_embeddings.shape[1:]) num_patches_per_batch = [v.shape[0] for v in pixel_values] stacked_pixel_values = torch.cat(pixel_values) stacked_image_features = self._image_pixels_to_features( self.vision_tower, stacked_pixel_values) return [ self.multi_modal_projector(image_features) for image_features in torch.split(stacked_image_features, num_patches_per_batch) ] def _process_image_input( self, image_input: LlavaNextImageInputs) -> BatchedTensors: patch_embeddings = self._process_image_pixels(image_input) image_sizes = image_input.get("image_sizes") if image_sizes is None: batch_size = len(image_input["data"]) vision_config = self.config.vision_config default_height = default_width = vision_config.image_size image_sizes = torch.as_tensor([[default_height, default_width] for _ in range(batch_size)]) return [ self._merge_image_patch_embeddings(image_sizes[i], patch_features_batch, strategy="spatial_unpad") for i, patch_features_batch in enumerate(patch_embeddings) ] def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> SamplerOutput: """Run forward pass for LlaVA-NeXT. One key thing to understand is the `input_ids` already accounts for the positions of the to-be-inserted image embeddings. Concretely, consider a text prompt: "\nUSER: What's the content of the image?\nASSISTANT:". Tokenizer outputs: [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]. The to-be-inserted image has a size of 576 (24 * 24) along the context length dimension. `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]. There will be 576 `32000` in the `input_ids`. (32000 is the token id for ``.) This way, the `positions` and `attn_metadata` are consistent with the `input_ids`. Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. pixel_values: The pixels in each grid patch for each input image. Expects a batch with shape `[1, num_patches, 3, h, w]`. image_sizes: The original `(height, width)` for each input image. Expects a batch with shape `[1, 2]`. See also: Each input maps to huggingface implementation, as follows: - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L690 - `image_sizes`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L691 """ image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, self.vlm_config.image_token_id) input_ids = None else: inputs_embeds = None hidden_states = self.language_model(input_ids, positions, kv_caches, attn_metadata, None, inputs_embeds=inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # only doing this for language model part for now. stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue # post_layernorm is not needed in CLIPVisionModel if "vision_model.post_layernorm" in name: continue for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): if key_to_modify in name: name = name.replace(key_to_modify, new_key) use_default_weight_loading = False if "vision" in name: if self.vision_tower is not None: # We only do sharding for language model and # not vision model for now. use_default_weight_loading = True else: for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue param = params_dict[name.replace(weight_name, param_name)] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: use_default_weight_loading = True if use_default_weight_loading: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight)