from typing import List, Optional import torch from torch import nn # TODO: We should port CLIPVisionModel's code over to not depend on # transformers' impl. from transformers import CLIPVisionModel, LlavaConfig from aphrodite.attention import AttentionMetadata from aphrodite.common.config import VisionLanguageConfig from aphrodite.modeling.layers.activation import get_act_fn from aphrodite.modeling.layers.linear import LinearMethodBase from aphrodite.modeling.layers.logits_processor import LogitsProcessor from aphrodite.modeling.layers.sampler import Sampler from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead from aphrodite.modeling.models.llama import LlamaModel from aphrodite.modeling.sampling_metadata import SamplingMetadata from aphrodite.modeling.hf_downloader import (default_weight_loader, hf_model_weights_iterator) from aphrodite.common.sequence import SamplerOutput _KEYS_TO_MODIFY_MAPPING = { "language_model.lm_head": "lm_head", "language_model.model": "language_model", } # TODO: Run benchmark and decide if TP. 
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that projects vision features to the text hidden size.

    The pipeline is ``linear_1`` (vision hidden size -> text hidden size),
    the activation named by ``projector_hidden_act``, then ``linear_2``
    (text hidden size -> text hidden size).
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` through linear -> act -> linear."""
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


def _merge_vision_embeddings(input_ids: torch.Tensor,
                             inputs_embeds: torch.Tensor,
                             vision_embeddings: torch.Tensor,
                             image_token_id: int) -> None:
    """Merge ``vision_embeddings`` into ``inputs_embeds`` in place.

    Every position where ``input_ids == image_token_id`` is overwritten, in
    order, with one row of ``vision_embeddings`` flattened to
    ``(-1, vision_embeddings.shape[-1])``.

    Args:
        input_ids: Token ids used to locate the image placeholder positions.
        inputs_embeds: Embeddings indexed by the placeholder mask; modified
            in place.
        vision_embeddings: Projected image embeddings; all leading
            dimensions are flattened into rows.
        image_token_id: Id of the image placeholder token.

    Raises:
        ValueError: If the number of placeholder tokens differs from the
            number of flattened vision embedding rows.
    """
    mask = (input_ids == image_token_id)
    # `reshape` (rather than `view`) also accepts non-contiguous inputs.
    flat_embeddings = vision_embeddings.reshape(
        -1, vision_embeddings.shape[-1])

    # Masked assignment broadcasts, so e.g. a single embedding row would be
    # silently copied into every placeholder. Check the counts explicitly.
    num_placeholders = int(mask.sum())
    num_rows = flat_embeddings.shape[0]
    if num_placeholders != num_rows:
        raise ValueError(
            f"Attempted to assign {num_rows} vision embedding(s) to "
            f"{num_placeholders} image placeholder token(s) "
            f"(image_token_id={image_token_id}); the counts must match.")

    inputs_embeds[mask] = flat_embeddings


class LlavaForConditionalGeneration(nn.Module):
    """LLaVA: optional CLIP vision tower, projector and Llama language model.

    The vision tower is only instantiated when the configured image input
    type is PIXEL_VALUES; otherwise the image input is fed directly to the
    multi modal projector.
    """

    def __init__(self,
                 config: "LlavaConfig",
                 vision_language_config: VisionLanguageConfig,
                 linear_method: Optional["LinearMethodBase"] = None) -> None:
        super().__init__()
        self.config = config

        self.vision_language_config = vision_language_config

        assert self.vision_language_config, (
            "Provide `image_input_type` and other vision "
            "related configurations through LLM entrypoint "
            "or engine arguments.")

        if self.vision_language_config.image_input_type == (
                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
            self.vision_tower = CLIPVisionModel(config.vision_config)
        else:
            # The image input is passed straight to the projector in
            # `forward`, so no tower is needed.
            self.vision_tower = None

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.linear_method = linear_method
        self.language_model = LlamaModel(config.text_config, linear_method)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def _validate_image_input(self, image_input: torch.Tensor) -> None:
        """Check that the non-batch dims match the configured image shape.

        Raises:
            ValueError: If ``image_input.shape[1:]`` differs from
                ``vision_language_config.image_input_shape[1:]``.
        """
        expected_dims = self.vision_language_config.image_input_shape[1:]
        if list(image_input.shape[1:]) != list(expected_dims):
            raise ValueError(
                f"The expected image tensor shape is batch dimension "
                f"plus "
                f"{self.vision_language_config.image_input_shape[1:]}."
                f" You supplied {image_input.shape}. "
                f"If you are using Aphrodite's endpoint, make sure your "
                f"supplied image input is consistent with "
                f"image_input_shape in engine args.")

    def _extract_image_features(self,
                                image_input: torch.Tensor) -> torch.Tensor:
        """Return the features that are fed to the multi modal projector.

        Without a vision tower the input is returned unchanged. With a vision
        tower, the hidden state at ``config.vision_feature_layer`` is taken
        and filtered by ``config.vision_feature_select_strategy``.

        Raises:
            ValueError: If the select strategy is neither "default" nor
                "full".
        """
        if self.vision_tower is None:
            return image_input

        # TODO: Maybe port minimal CLIPVisionModel over.
        image_outputs = self.vision_tower(image_input,
                                          output_hidden_states=True)
        image_features = image_outputs.hidden_states[
            self.config.vision_feature_layer]

        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        strategy = self.config.vision_feature_select_strategy
        if strategy == "default":
            # Drop the first position along the sequence dimension.
            return image_features[:, 1:]
        if strategy == "full":
            return image_features
        raise ValueError(
            f"Unexpected select feature strategy: "
            f"{self.config.vision_feature_select_strategy}")

    def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
            image_input: Optional[torch.Tensor] = None
    ) -> torch.Tensor:  # noqa: E501
        """Run forward pass for Llava 1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.
        Concretely, consider a text prompt:
        "<image>\\nUSER: What's the content of the image?\\nASSISTANT:".
        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        The model takes two types of image inputs:
        PIXEL_VALUES and IMAGE_FEATURES.
        The following shows how each maps to huggingface implementation.
        PIXEL_VALUES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
        IMAGE_FEATURES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
        before going through the multi modal projector.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model.
            kv_caches: Passed through to the language model.
            attn_metadata: Passed through to the language model.
            image_input: A batch of image inputs.
                For PIXEL_VALUES, expecting [1, 3, 336, 336].
                For IMAGE_FEATURES, expecting [1, 576, 1024].

        Returns:
            The hidden states produced by the language model.

        Raises:
            ValueError: If `image_input` has an unexpected shape, the feature
                select strategy is unknown, or the number of image
                placeholder tokens does not match the vision embeddings.
        """
        if image_input is None:
            inputs_embeds = None
        else:
            self._validate_image_input(image_input)
            image_features = self._extract_image_features(image_input)

            vision_embeddings = self.multi_modal_projector(image_features)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
            _merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vision_language_config.image_token_id)
            # The merged embeddings are passed instead of the token ids.
            input_ids = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from ``hidden_states`` using ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint weights into this module's parameters.

        Checkpoint names are first rewritten with ``_KEYS_TO_MODIFY_MAPPING``.
        Weights whose name contains "vision" are loaded with the default
        loader, or skipped entirely when there is no vision tower. The q/k/v
        and gate/up projection weights are loaded as shards of the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters. Everything else uses the
        parameter's own ``weight_loader`` if it has one, else the default.
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
                # Without a vision tower, vision weights are skipped.
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # No stacked mapping matched this weight.
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)