import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_list_of
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn,
                    init_aphrodite_registered_model,
                    merge_multimodal_embeddings)


class LlavaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class LlavaImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


# TODO: Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


def get_max_llava_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)
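
    # NOTE (assumption about the per-encoder helpers above): they count every
    # encoder output token, including the CLS token for CLIP-style towers, so
    # the "default" strategy, which drops the CLS token, reserves one fewer
    # placeholder token than the "full" strategy.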
    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")


def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(vision_config, num_images)
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(vision_config, num_images)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
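    # The image input may be a single PIL image, a list of PIL images, or
    # pre-computed image embeddings (a tensor or a list of per-image tensors);
    # the number of placeholder tokens to reserve is derived accordingly.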
    if isinstance(image_data, Image.Image):
        image_feature_size = get_max_llava_image_tokens(ctx)
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [get_max_llava_image_tokens(ctx)
                              ] * len(image_data)
    elif isinstance(image_data, torch.Tensor):
        num_images, image_feature_size, hidden_size = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaConfig):
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
            + vision_feature_layer + 1
    else:
        num_hidden_layers = vision_feature_layer + 1
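    # For example, with a 24-layer CLIP tower and vision_feature_layer = -2
    # (the usual LLaVA-1.5 setting), only the first 24 - 2 + 1 = 23 layers
    # are instantiated.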

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initialize this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None
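
        # Inputs may arrive as a single tensor or as a list of per-image
        # tensors; flatten_bn(..., concat=True) collapses them into the
        # `(batch_size * num_images, ...)` layout documented in the
        # TypedDicts above.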
        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is that the `input_ids` already accounts
        for the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, we have to insert placeholder tokens
        before they are fed to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that, including the original image token in
        the input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens fed to the language model,
        i.e. the number of image tokens output by the vision encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
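        # `weights` may be a one-shot iterator, so it is duplicated with
        # itertools.tee before being filtered per sub-module below.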

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)
|