@@ -1,9 +1,9 @@
-from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
-                    Union)
+from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
from loguru import logger
+from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
@@ -17,20 +17,17 @@ from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.modeling.models.clip import (CLIPVisionModel,
-                                            dummy_feature_data_for_clip,
-                                            dummy_pixel_data_for_clip,
-                                            dummy_seq_data_for_clip,
-                                            get_clip_patch_grid_length)
-from aphrodite.modeling.models.interfaces import SupportsVision
+from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
-from aphrodite.modeling.models.llava import (LlavaMultiModalProjector,
-                                             merge_vision_embeddings)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalData
-from aphrodite.multimodal.image import ImagePixelData
+from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization.base_config import QuantizationConfig

+from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+                   get_clip_patch_grid_length)
+from .interfaces import SupportsVision
+from .llava import LlavaMultiModalProjector, merge_vision_embeddings
+
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
@@ -46,17 +43,7 @@ class LlavaNextImagePixelInputs(TypedDict):
    """Shape: (batch_size, 2)"""


-class LlavaNextImageFeatureInputs(TypedDict):
-    type: Literal["image_features"]
-    data: torch.Tensor
-    """Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)"""
-
-    image_sizes: NotRequired[torch.Tensor]
-    """Shape: (batch_size, 2)"""
-
-
-LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
-                             LlavaNextImageFeatureInputs]
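+# Only pixel-value inputs are supported; image-feature inputs are not.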
+LlavaNextImageInputs = LlavaNextImagePixelInputs


def _get_llava_next_num_unpadded_features(
@@ -137,20 +124,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
        image_feature_size_override=image_feature_size,
    )

-    image_input_type = multimodal_config.image_input_type
-    ImageInputType = VisionLanguageConfig.ImageInputType
-    mm_data: MultiModalData
-    if image_input_type == ImageInputType.PIXEL_VALUES:
-        mm_data = dummy_pixel_data_for_clip(
-            vision_config,
-            image_width_override=dummy_width,
-            image_height_override=dummy_height,
-        )
-    elif image_input_type == ImageInputType.IMAGE_FEATURES:
-        mm_data = dummy_feature_data_for_clip(
-            vision_config,
-            image_feature_size_override=image_feature_size,
-        )
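+    # Use a dummy image sized to the computed dummy width and height.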
+    mm_data = dummy_image_for_clip(
+        vision_config,
+        image_width_override=dummy_width,
+        image_height_override=dummy_height,
+    )

    return seq_data, mm_data

@@ -158,31 +136,26 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
    raise NotImplementedError(msg)


-def _pixel_mapper(ctx: InputContext,
-                  data: ImagePixelData) -> Dict[str, torch.Tensor]:
-    image = data.image
+def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:

-    if isinstance(image, torch.Tensor):
-        pixel_values = image.to(ctx.model_config.dtype)
-        batch_size, _, _, h, w = pixel_values.shape
-        image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])
+    if isinstance(image, Image.Image):

-        return {"pixel_values": pixel_values, "image_sizes": image_sizes}
+        # Temporary patch before dynamic number of image tokens is supported
+        _, _, h, w = ctx.get_multimodal_config().image_input_shape
+        if (w, h) != (image.width, image.height):
+            logger.warning(
+                "Dynamic image shape is currently not supported. "
+                "Resizing input image to ({}, {}).", w, h)

-    # Temporary patch before dynamic number of image tokens is supported
-    _, _, h, w = ctx.get_multimodal_config().image_input_shape
-    if (w, h) != (image.width, image.height):
-        logger.warning("Dynamic image shape is currently not supported. "
-                       f"Resizing input image to ({w}, {h}).")
+            image = image.resize((w, h))

-        data.image = image.resize((w, h))
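+        # Delegate to the registry's default image input mapper.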
+        return MULTIMODAL_REGISTRY._get_plugin("image") \
+            ._default_input_mapper(ctx, image)

-    return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
-        ._default_input_mapper(ctx, data)
+    raise TypeError(f"Invalid type for 'image': {type(image)}")


-@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper)
+@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

@@ -196,11 +169,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
        self.config = config
        self.vlm_config = vlm_config

-        if self.vlm_config.image_input_type == (
-                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
-            self.vision_tower = CLIPVisionModel(config=config.vision_config)
-        else:
-            raise TypeError("Image features are not supported by LLaVA-NeXT")
+        self.vision_tower = CLIPVisionModel(config=config.vision_config)

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
@@ -226,9 +195,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
    def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
        _, num_channels, _, _ = self.vlm_config.image_input_shape

-        # Note that this is different from that of Aphrodite
-        # vision_language_config since the image is resized by the HuggingFace
-        # preprocessor
+        # Note that this is different from that of vLLM vision_language_config
+        # since the image is resized by the HuggingFace preprocessor
        height = width = self.config.vision_config.image_size

        if list(data.shape[2:]) != [num_channels, height, width]:
@@ -236,7 +204,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
                f"The expected image tensor shape is batch dimension plus "
                f"num_patches plus {[num_channels, height, width]}. "
                f"You supplied {data.shape}. "
-                f"If you are using Aphrodite's entrypoint, make sure your "
+                f"If you are using vLLM's entrypoint, make sure your "
                f"supplied image input is consistent with "
                f"image_input_shape in engine args.")

@@ -254,36 +222,23 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
-        image_features = kwargs.pop("image_features", None)
-
-        expected_input_type = self.vlm_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-
-        if expected_input_type == ImageInputType.PIXEL_VALUES:
-            if image_features is not None:
-                raise ValueError(
-                    "Expected pixel values but got image features")
-            if pixel_values is None:
-                return None

-            if not isinstance(pixel_values, torch.Tensor):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
+        if pixel_values is None or image_sizes is None:
+            return None

-            if not isinstance(image_sizes, torch.Tensor):
-                raise ValueError("Incorrect type of image sizes. "
-                                 f"Got type: {type(image_sizes)}")
+        if not isinstance(pixel_values, torch.Tensor):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")

-            return LlavaNextImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_image_pixels(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes),
-            )
+        if not isinstance(image_sizes, torch.Tensor):
+            raise ValueError("Incorrect type of image sizes. "
+                             f"Got type: {type(image_sizes)}")

-        assert expected_input_type != ImageInputType.IMAGE_FEATURES, (
-            "Failed to validate this at initialization time")
-
-        return None
+        return LlavaNextImagePixelInputs(
+            type="pixel_values",
+            data=self._validate_image_pixels(pixel_values),
+            image_sizes=self._validate_image_sizes(image_sizes),
+        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
@@ -390,11 +345,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

    def _process_image_input(
            self, image_input: LlavaNextImageInputs) -> torch.Tensor:
-        if image_input["type"] == "pixel_values":
-            assert self.vision_tower is not None
-            image_features = self._process_image_pixels(image_input)
-        else:
-            image_features = image_input["data"]
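+        # Only pixel values reach here, so the vision tower is always used.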
+        assert self.vision_tower is not None
+        image_features = self._process_image_pixels(image_input)

        patch_embeddings = self.multi_modal_projector(image_features)