@@ -1,8 +1,7 @@
-from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import torch
 import torch.nn as nn
-from loguru import logger
 from PIL import Image
 from transformers import CLIPVisionConfig, LlavaNextConfig
 from transformers.models.llava_next.modeling_llava_next import (
@@ -11,22 +10,23 @@ from typing_extensions import NotRequired
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, VisionLanguageConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.inputs import INPUT_REGISTRY, InputContext
+from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.quantization.base_config import (QuantizationConfig)
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.clip import CLIPVisionModel
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.multimodal import MULTIMODAL_REGISTRY
-from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                   get_clip_patch_grid_length)
+                   get_clip_patch_grid_length, input_processor_for_clip)
 from .interfaces import SupportsVision
-from .llava import LlavaMultiModalProjector, merge_vision_embeddings
+from .llava import LlavaMultiModalProjector
+from .utils import merge_vision_embeddings
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -36,16 +36,27 @@ _KEYS_TO_MODIFY_MAPPING = {
 
 class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
+    data: BatchedTensors
+    """
+    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+
+    Note that `num_patches` may be different for each batch.
+    """
 
     image_sizes: NotRequired[torch.Tensor]
-    """Shape: (batch_size, 2)"""
+    """
+    Shape: `(batch_size, 2)`
+
+    This should be in `(height, width)` format.
+    """
 
 
 LlavaNextImageInputs = LlavaNextImagePixelInputs
 
 
+# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
+# NOTE: new_height and new_width are further incremented to properly invert the
+# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
 def _get_llava_next_num_unpadded_features(
     height: int,
     width: int,
@@ -53,7 +64,6 @@ def _get_llava_next_num_unpadded_features(
     num_patch_height: int,
     num_patch_width: int,
 ) -> Tuple[int, int]:
-    # Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
     current_height = npatches * num_patch_height
     current_width = npatches * num_patch_width
 
@@ -61,9 +71,13 @@ def _get_llava_next_num_unpadded_features(
     current_aspect_ratio: float = current_width / current_height
     if aspect_ratio > current_aspect_ratio:
         new_height = (height * current_width) // width
+        if new_height % 2 == 1:
+            new_height += 1
         current_height = new_height
     else:
         new_width = (width * current_height) // height
+        if new_width % 2 == 1:
+            new_width += 1
         current_width = new_width
 
     unpadded_features = current_height * current_width
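# A standalone sketch of the rounding added above, with illustrative numbers
# (the helper name below is hypothetical): the floordiv'd height is bumped to
# the next even value, which per the NOTE above is intended to invert the
# floordiv performed on the HF side.
def _rounded_height(orig_height: int, orig_width: int,
                     current_width: int) -> int:
    new_height = (orig_height * current_width) // orig_width
    if new_height % 2 == 1:
        new_height += 1
    return new_height

# e.g. a 682x1024 image on a 48-wide feature grid: floordiv gives 31,
# which is rounded up to 32.
assert _rounded_height(682, 1024, 48) == 32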
@@ -71,7 +85,8 @@ def _get_llava_next_num_unpadded_features(
     return (unpadded_features, newline_features)
 
 
-def _get_llava_next_image_feature_size(
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
+def get_llava_next_image_feature_size(
     hf_config: LlavaNextConfig,
     *,
     input_height: int,
@@ -86,7 +101,9 @@ def _get_llava_next_image_feature_size(
         )
         base_feature_size = num_patches * num_patches
 
-        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+        # Note: We follow the "wrong" width/height order
+        # [ref: PR huggingface/transformers#31588]
+        num_patch_width, num_patch_height = get_anyres_image_grid_shape(
             image_size=(input_height, input_width),
             grid_pinpoints=hf_config.image_grid_pinpoints,
             patch_size=vision_config.image_size,
@@ -107,14 +124,16 @@ def _get_llava_next_image_feature_size(
 
 
 def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
-    multimodal_config = ctx.get_multimodal_config()
     hf_config = ctx.get_hf_config(LlavaNextConfig)
     vision_config = hf_config.vision_config
 
-    #TODO: change the logic for dummy data to support dynamic shape
-    _, _, dummy_height, dummy_width = multimodal_config.image_input_shape
-    image_feature_size = _get_llava_next_image_feature_size(
-        hf_config, input_height=dummy_height, input_width=dummy_width)
+    # Results in the max possible feature size (2x2 grid of 336x336px tiles)
+    dummy_height = dummy_width = 448
+    image_feature_size = get_llava_next_image_feature_size(
+        hf_config,
+        input_height=dummy_height,
+        input_width=dummy_width,
+    )
 
     if isinstance(vision_config, CLIPVisionConfig):
         seq_data = dummy_seq_data_for_clip(
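# A quick, self-contained check of the 448x448 choice above (values are
# illustrative): with the standard LLaVA-NeXT grid pinpoints, a 448x448 dummy
# image selects the 672x672 pinpoint, i.e. a 2x2 grid of 336px tiles.
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape)

_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
assert get_anyres_image_grid_shape((448, 448), _pinpoints, 336) == (2, 2)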
@@ -136,27 +155,47 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
     raise NotImplementedError(msg)
 
 
-def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
+def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
 
-    if isinstance(image, Image.Image):
+    model_config = ctx.model_config
+    hf_config = ctx.get_hf_config(LlavaNextConfig)
+    vision_config = hf_config.vision_config
 
-        # Temporary patch before dynamic number of image tokens is supported
-        _, _, h, w = ctx.get_multimodal_config().image_input_shape
-        if (w, h) != (image.width, image.height):
-            logger.warning(
-                "Dynamic image shape is currently not supported. "
-                "Resizing input image to (%d, %d).", w, h)
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        width, height = image_data.size
 
-            image = image.resize((w, h))
+        image_feature_size = get_llava_next_image_feature_size(
+            hf_config,
+            input_height=height,
+            input_width=width,
+        )
+    elif isinstance(image_data, torch.Tensor):
+        raise NotImplementedError("Embeddings input is not supported yet")
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
 
-        return MULTIMODAL_REGISTRY._get_plugin("image") \
-            ._default_input_mapper(ctx, image)
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return input_processor_for_clip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
 
-    raise TypeError(f"Invalid type for 'image': {type(image)}")
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
     def __init__(self,
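# A simplified sketch of what registering input_processor_for_llava_next
# achieves (the helper below is hypothetical, not the real
# input_processor_for_clip): the single image placeholder token in the prompt
# is repeated once per image feature, so the right number of positions is
# reserved for the dynamically-sized image.
from typing import List

def _expand_image_token(token_ids: List[int], image_token_id: int,
                        image_feature_size: int) -> List[int]:
    expanded: List[int] = []
    for token in token_ids:
        if token == image_token_id:
            expanded.extend([image_token_id] * image_feature_size)
        else:
            expanded.append(token)
    return expanded

# e.g. a 3-token prompt whose image resolves to 2144 features -> 2146 tokens
assert len(_expand_image_token([1, 32000, 9], 32000, 2144)) == 2146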
@@ -169,8 +208,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.vlm_config = vlm_config
 
+        # TODO: Optionally initialize this to support embeddings input.
         self.vision_tower = CLIPVisionModel(config=config.vision_config)
-
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
@@ -193,24 +232,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
 
-    def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
-        _, num_channels, _, _ = self.vlm_config.image_input_shape
-
-        # Note that this is different from that of vLLM vision_language_config
-        # since the image is resized by the HuggingFace preprocessor
-        height = width = self.config.vision_config.image_size
-
-        if list(data.shape[2:]) != [num_channels, height, width]:
-            raise ValueError(
-                f"The expected image tensor shape is batch dimension plus "
-                f"num_patches plus {[num_channels, height, width]}. "
-                f"You supplied {data.shape}. "
-                f"If you are using vLLM's entrypoint, make sure your "
-                f"supplied image input is consistent with "
-                f"image_input_shape in engine args.")
-
-        return data
-
     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
         if list(data.shape[1:]) != [2]:
             raise ValueError(
@@ -220,14 +241,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
+            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
 
-        if pixel_values is None or image_sizes is None:
+        if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
+        if not isinstance(pixel_values, (torch.Tensor, list)):
             raise ValueError("Incorrect type of pixel values. "
                              f"Got type: {type(pixel_values)}")
 
@@ -237,7 +258,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         return LlavaNextImagePixelInputs(
             type="pixel_values",
-            data=self._validate_image_pixels(pixel_values),
+            data=pixel_values,
             image_sizes=self._validate_image_sizes(image_sizes),
         )
 
@@ -264,15 +285,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             strategy=self.config.vision_feature_select_strategy,
         )
 
+    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
     def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                       patch_embeddings: torch.Tensor, *,
                                       strategy: str) -> torch.Tensor:
-        # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
         if strategy == "flat":
             return patch_embeddings.flatten(0, 1)
 
         if strategy.startswith("spatial"):
-            orig_width, orig_height = image_size
             height = width = self.config.vision_config.image_size \
                 // self.config.vision_config.patch_size
 
@@ -286,13 +306,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             other_patch_embeds = patch_embeddings[1:]
 
             # image_aspect_ratio == "anyres"
+            # Note: We follow the "wrong" width/height order
+            # [ref: PR huggingface/transformers#31588]
             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                (orig_width, orig_height),
+                image_size,
                 self.config.image_grid_pinpoints,
                 self.config.vision_config.image_size,
             )
             other_patch_embeds = other_patch_embeds \
-                .view(num_patch_width, num_patch_height, height, width, -1)
+                .view(num_patch_height, num_patch_width, height, width, -1)
 
             if "unpad" in strategy:
                 other_patch_embeds = other_patch_embeds \
@@ -330,44 +352,53 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         raise ValueError(f"Unexpected patch merge strategy: {strategy}")
 
     def _process_image_pixels(
-            self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
+        self,
+        inputs: LlavaNextImagePixelInputs,
+    ) -> BatchedTensors:
         assert self.vision_tower is not None
 
         pixel_values = inputs["data"]
 
-        b, num_patches, c, h, w = pixel_values.shape
-        stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
+        if isinstance(pixel_values, torch.Tensor):
+            b, num_patches, c, h, w = pixel_values.shape
+            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
+            stacked_image_features = self._image_pixels_to_features(
+                self.vision_tower, stacked_pixel_values)
+            stacked_patch_embeddings = self.multi_modal_projector(
+                stacked_image_features)
 
+            return stacked_patch_embeddings.view(
+                b, num_patches, *stacked_patch_embeddings.shape[1:])
+
+        num_patches_per_batch = [v.shape[0] for v in pixel_values]
+        stacked_pixel_values = torch.cat(pixel_values)
         stacked_image_features = self._image_pixels_to_features(
             self.vision_tower, stacked_pixel_values)
 
-        return stacked_image_features.view(b, num_patches,
-                                           *stacked_image_features.shape[-2:])
+        return [
+            self.multi_modal_projector(image_features) for image_features in
+            torch.split(stacked_image_features, num_patches_per_batch)
+        ]
 
     def _process_image_input(
-            self, image_input: LlavaNextImageInputs) -> torch.Tensor:
-        assert self.vision_tower is not None
-        image_features = self._process_image_pixels(image_input)
-
-        patch_embeddings = self.multi_modal_projector(image_features)
+            self, image_input: LlavaNextImageInputs) -> BatchedTensors:
+        patch_embeddings = self._process_image_pixels(image_input)
 
         image_sizes = image_input.get("image_sizes")
         if image_sizes is None:
-            batch_size = image_input["data"].shape[0]
+            batch_size = len(image_input["data"])
             vision_config = self.config.vision_config
-            default_width = default_height = vision_config.image_size
-            image_sizes = torch.as_tensor([[default_width, default_height]
+            default_height = default_width = vision_config.image_size
+            image_sizes = torch.as_tensor([[default_height, default_width]
                                            for _ in range(batch_size)])
 
-        merged_patch_embeddings = [
+        return [
            self._merge_image_patch_embeddings(image_sizes[i],
-                                               patch_features,
+                                               patch_features_batch,
                                                strategy="spatial_unpad")
-            for i, patch_features in enumerate(patch_embeddings)
+            for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
-        return torch.stack(merged_patch_embeddings, dim=0)
-
     def forward(
         self,
         input_ids: torch.Tensor,
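# A minimal, runnable illustration of the list-of-tensors path above (shapes
# are made up): per-image patch stacks of different lengths are concatenated
# for a single vision-tower pass and then split back per image.
import torch

per_image = [torch.randn(5, 3, 336, 336), torch.randn(3, 3, 336, 336)]
num_patches_per_batch = [p.shape[0] for p in per_image]
stacked = torch.cat(per_image)                      # (8, 3, 336, 336)
restored = torch.split(stacked, num_patches_per_batch)
assert [r.shape[0] for r in restored] == num_patches_per_batch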
@@ -401,8 +432,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each grid patch for each input image.
-                Expects a batch with shape `[1, num_patches, 3, 336, 336]`.
-            image_sizes: The original `(width, height)` for each input image.
+                Expects a batch with shape `[1, num_patches, 3, h, w]`.
+            image_sizes: The original `(height, width)` for each input image.
                 Expects a batch with shape `[1, 2]`.
 
         See also:
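# Shape-level illustration of the inputs documented above (toy values): one
# image split into a base view plus four 336px grid tiles, together with its
# original size in (height, width) order.
import torch

pixel_values = torch.randn(1, 5, 3, 336, 336)  # [batch, 1 + num_patches, C, H, W]
image_sizes = torch.tensor([[1080, 1920]])     # [batch, 2] as (height, width)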