from typing import Dict, Tuple, Type, Union

import torch
from loguru import logger
from PIL import Image

from aphrodite.common.config import ModelConfig, VisionLanguageConfig
from aphrodite.common.sequence import SequenceData
from aphrodite.transformers_utils.image_processor import \
    cached_get_image_processor

from .base import MultiModalData, MultiModalPlugin


def _get_dummy_seq_data(seq_len: int,
                        vlm_config: VisionLanguageConfig) -> SequenceData:
    # NOTE: We assume that the <image> token is repeated
    # `image_feature_size` times and then concatenated with the text prompt.
    # TODO: Enable other ways of inserting the image into the prompt
    token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size
    token_ids += [0] * (seq_len - vlm_config.image_feature_size)

    return SequenceData(token_ids)
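
# Illustrative sketch of the dummy layout above (the numbers are hypothetical,
# not taken from any real config): with image_token_id=32000,
# image_feature_size=576 and seq_len=2048, the dummy sequence is
#     [32000] * 576 + [0] * (2048 - 576)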


def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor:
    if vlm_config.image_processor is None:
        values_dtype = torch.float16
    else:
        values_dtype = torch.uint8

    return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype)


def get_dummy_image_data(
    seq_len: int,
    model_config: ModelConfig,
    vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
    """Standard dummy data factory for image data (to be used in
    :meth:`aphrodite.multimodal.MultiModalRegistry.register_dummy_data`)."""
    seq_data = _get_dummy_seq_data(seq_len, vlm_config)
    values = _get_dummy_values(vlm_config)

    config_input_type = vlm_config.image_input_type
    ImageInputType = VisionLanguageConfig.ImageInputType

    fake_mm_data: MultiModalData
    if config_input_type == ImageInputType.PIXEL_VALUES:
        fake_mm_data = ImagePixelData(values)
    elif config_input_type == ImageInputType.IMAGE_FEATURES:
        fake_mm_data = ImageFeatureData(values)
    else:
        raise NotImplementedError

    return seq_data, fake_mm_data
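
# Registration sketch (hedged): the registry object, decorator form, and model
# class below are illustrative assumptions referenced by the docstring above;
# none of them are defined in this file.
#
#     from aphrodite.multimodal import MULTIMODAL_REGISTRY
#
#     @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
#     class MyVisionLanguageModel(torch.nn.Module):
#         ...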


class ImagePixelData(MultiModalData):
    """
    The pixel data of an image. Can be one of:

    - :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace
      processor is available to the model.
    - :class:`torch.Tensor`: The raw pixel data which is passed to the model
      without additional pre-processing.
    """

    def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
        if isinstance(image, Image.Image):
            # So that this class can be created inside the Image context
            # manager.
            image.load()

        self.image = image
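
# Usage sketch ("example.jpg" is a placeholder path): because ImagePixelData
# calls image.load() up front, it can safely be constructed inside the Image
# context manager, even though the file handle is closed on exit:
#
#     with Image.open("example.jpg") as img:
#         data = ImagePixelData(img)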


class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):

    def get_data_type(self) -> Type[ImagePixelData]:
        return ImagePixelData

    def _get_hf_image_processor(self, model_config: ModelConfig,
                                vlm_config: VisionLanguageConfig):
        if vlm_config is None or vlm_config.image_processor is None:
            return None
        return cached_get_image_processor(
            vlm_config.image_processor,
            trust_remote_code=model_config.trust_remote_code,
            revision=vlm_config.image_processor_revision,
        )

    def _default_input_processor(
            self, data: ImagePixelData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        image = data.image
        image_processor = self._get_hf_image_processor(model_config,
                                                       vlm_config)

        if isinstance(image, Image.Image):
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                return image_processor.preprocess(image, return_tensors="pt") \
                    .to(model_config.dtype).data
            except Exception:
                logger.error(f"Failed to process image ({image})")
                raise
        elif isinstance(image, torch.Tensor):
            pixel_values = image.to(model_config.dtype)

            return {"pixel_values": pixel_values}

        raise TypeError(f"Invalid image type: {type(image)}")
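
# Sketch for the raw-tensor path (the shape below is a placeholder; the real
# shape comes from vlm_config.image_input_shape): a torch.Tensor bypasses HF
# preprocessing, so it must already be resized and normalized as the model
# expects:
#
#     pixels = torch.zeros(1, 3, 336, 336)
#     data = ImagePixelData(pixels)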


class ImageFeatureData(MultiModalData):
    """
    The feature vector of an image, passed directly to the model.

    This should be the output of the vision tower.
    """

    def __init__(self, image_features: torch.Tensor) -> None:
        self.image_features = image_features
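
# Sketch for precomputed features (the shape is illustrative, not taken from
# any real model): wrapping vision-tower outputs so they are fed directly to
# the language model without running the vision encoder inside the engine:
#
#     features = torch.rand(1, 576, 1024)
#     data = ImageFeatureData(features)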


class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):

    def get_data_type(self) -> Type[ImageFeatureData]:
        return ImageFeatureData

    def _default_input_processor(
            self, data: ImageFeatureData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        image_features = data.image_features.to(model_config.dtype)

        return {"image_features": image_features}