# image.py

from functools import lru_cache
from typing import List, Optional, Tuple, TypeVar

import torch
from loguru import logger
from PIL import Image
from transformers import PreTrainedTokenizerBase

from aphrodite.common.config import ModelConfig
from aphrodite.common.utils import is_list_of
from aphrodite.inputs.registry import InputContext
from aphrodite.transformers_utils.image_processor import get_image_processor
from aphrodite.transformers_utils.tokenizer import get_tokenizer

from .base import MultiModalData, MultiModalInputs, MultiModalPlugin

cached_get_image_processor = lru_cache(get_image_processor)
cached_get_tokenizer = lru_cache(get_tokenizer)
# Utilities for image input processors
_T = TypeVar("_T", str, int)


def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> List[_T]:
    replacement = [token] * repeat_count
    if pad_token_left is not None:
        replacement = [pad_token_left] + replacement
    if pad_token_right is not None:
        replacement = replacement + [pad_token_right]

    return replacement
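

# A minimal illustrative sketch (not part of the original module) of how
# `repeat_and_pad_token` behaves; the literal token values below are
# assumptions chosen only for this example.
def _example_repeat_and_pad_token() -> None:
    # Repetition only: the placeholder is duplicated `repeat_count` times.
    assert repeat_and_pad_token("<image>", repeat_count=3) == ["<image>"] * 3
    # Optional pad tokens wrap the repeated block on either side.
    assert repeat_and_pad_token(
        32000,
        repeat_count=2,
        pad_token_left=1,
        pad_token_right=2,
    ) == [1, 32000, 32000, 2]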
def repeat_and_pad_image_tokens(
    tokenizer: PreTrainedTokenizerBase,
    prompt: Optional[str],
    prompt_token_ids: List[int],
    *,
    image_token_id: int,
    repeat_count: int = 1,
    pad_token_left: Optional[int] = None,
    pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
    if prompt is None:
        new_prompt = None
    else:
        image_token_str = tokenizer.decode(image_token_id)
        pad_token_str_left = (None if pad_token_left is None else
                              tokenizer.decode(pad_token_left))
        pad_token_str_right = (None if pad_token_right is None else
                               tokenizer.decode(pad_token_right))
        replacement_str = "".join(
            repeat_and_pad_token(
                image_token_str,
                repeat_count=repeat_count,
                pad_token_left=pad_token_str_left,
                pad_token_right=pad_token_str_right,
            ))

        image_token_count = prompt.count(image_token_str)
        # This is an arbitrary number to distinguish between the two cases
        if image_token_count > 16:
            logger.warning("Please follow the prompt format that is "
                           "documented on HuggingFace which does not involve "
                           f"repeating {image_token_str} tokens.")
        elif image_token_count > 1:
            logger.warning("Multiple image input is not supported yet, "
                           "so any extra image tokens will be treated "
                           "as plain text.")

        # The image tokens are removed to be consistent with HuggingFace
        new_prompt = prompt.replace(image_token_str, replacement_str, 1)

    new_token_ids: List[int] = []
    for i, token in enumerate(prompt_token_ids):
        if token == image_token_id:
            replacement_ids = repeat_and_pad_token(
                image_token_id,
                repeat_count=repeat_count,
                pad_token_left=pad_token_left,
                pad_token_right=pad_token_right,
            )
            new_token_ids.extend(replacement_ids)

            # No need to further scan the list since we only replace once
            new_token_ids.extend(prompt_token_ids[i + 1:])
            break
        else:
            new_token_ids.append(token)

    return new_prompt, new_token_ids
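

# A minimal illustrative sketch (not part of the original module) showing how
# `repeat_and_pad_image_tokens` expands a single image placeholder in the
# token-id list. The ids used here (32000 for the image token, 576 for the
# repeat count) are assumptions in the style of a LLaVA-like model. With
# `prompt=None` the tokenizer is never consulted, so no real tokenizer is
# needed for this demonstration; a real caller would pass one obtained via
# `cached_get_tokenizer`.
def _example_repeat_and_pad_image_tokens() -> None:
    prompt_token_ids = [1, 3148, 32000, 29871, 13]
    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        None,  # type: ignore[arg-type]  # unused because prompt is None
        None,
        prompt_token_ids,
        image_token_id=32000,
        repeat_count=576,
    )
    assert new_prompt is None
    # The single placeholder is replaced by 576 copies; the surrounding ids
    # are preserved in order.
    assert new_token_ids == [1, 3148] + [32000] * 576 + [29871, 13]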
class ImagePlugin(MultiModalPlugin):

    def get_data_key(self) -> str:
        return "image"

    def _get_hf_image_processor(self, model_config: ModelConfig):
        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code)

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
    ) -> MultiModalInputs:
        model_config = ctx.model_config

        # PIL images are run through the model's HuggingFace image processor.
        if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
            image_processor = self._get_hf_image_processor(model_config)
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                batch_data = image_processor \
                    .preprocess(data, return_tensors="pt") \
                    .data
            except Exception:
                logger.error(f"Failed to process image ({data})")
                raise

            return MultiModalInputs(batch_data)
        # Precomputed image embeddings are forwarded unchanged.
        elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
            return MultiModalInputs({"image_embeds": data})

        raise TypeError(f"Invalid image type: {type(data)}")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 3000
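

# A minimal illustrative sketch (not part of the original module) of the
# precomputed-embedding path in `ImagePlugin._default_input_mapper`. A real
# `InputContext` is supplied by the input registry; the bare namespace below
# is a stand-in used only because this branch does not use the model config.
def _example_image_plugin_embeds() -> None:
    from types import SimpleNamespace

    plugin = ImagePlugin()
    fake_ctx = SimpleNamespace(model_config=None)  # stand-in, not a real InputContext
    image_embeds = torch.empty(576, 1024)
    inputs = plugin._default_input_mapper(
        fake_ctx,  # type: ignore[arg-type]
        image_embeds,
    )
    # Tensors are passed through untouched under the "image_embeds" key.
    assert "image_embeds" in inputs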