# image.py

from functools import lru_cache
from typing import List, Optional, Tuple, TypeVar

import torch
from loguru import logger
from PIL import Image
from transformers import PreTrainedTokenizerBase

from aphrodite.common.config import ModelConfig
from aphrodite.inputs.registry import InputContext
from aphrodite.transformers_utils.image_processor import get_image_processor
from aphrodite.transformers_utils.tokenizer import get_tokenizer

from .base import MultiModalInputs, MultiModalPlugin

cached_get_image_processor = lru_cache(get_image_processor)
cached_get_tokenizer = lru_cache(get_tokenizer)

# Utilities for image input processors
_T = TypeVar("_T", str, int)


def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> List[_T]:
    # Repeat `token` `repeat_count` times and optionally wrap the result
    # with a single padding token on either side.
    replacement = [token] * repeat_count
    if pad_token_left is not None:
        replacement = [pad_token_left] + replacement
    if pad_token_right is not None:
        replacement = replacement + [pad_token_right]

    return replacement
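
# A minimal usage sketch (the token IDs below are hypothetical and chosen
# only to illustrate the output shape): one placeholder repeated three
# times, wrapped by pad tokens on both sides.
#
#   repeat_and_pad_token(32000, repeat_count=3,
#                        pad_token_left=1, pad_token_right=2)
#   => [1, 32000, 32000, 32000, 2]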


def repeat_and_pad_image_tokens(
    tokenizer: PreTrainedTokenizerBase,
    prompt: Optional[str],
    prompt_token_ids: List[int],
    *,
    image_token_id: int,
    repeat_count: int = 1,
    pad_token_left: Optional[int] = None,
    pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
    if prompt is None:
        new_prompt = None
    else:
        image_token_str = tokenizer.decode(image_token_id)
        pad_token_str_left = (None if pad_token_left is None else
                              tokenizer.decode(pad_token_left))
        pad_token_str_right = (None if pad_token_right is None else
                               tokenizer.decode(pad_token_right))
        replacement_str = "".join(
            repeat_and_pad_token(
                image_token_str,
                repeat_count=repeat_count,
                pad_token_left=pad_token_str_left,
                pad_token_right=pad_token_str_right,
            ))

        image_token_count = prompt.count(image_token_str)
        # This is an arbitrary threshold to distinguish between the two
        # cases below.
        if image_token_count > 16:
            logger.warning("Please follow the prompt format that is "
                           "documented on HuggingFace, which does not "
                           f"involve repeating {image_token_str} tokens.")
        elif image_token_count > 1:
            logger.warning("Multiple image input is not supported yet, "
                           "so any extra image tokens will be treated "
                           "as plain text.")

        # Only the first image token is expanded, to stay consistent with
        # the HuggingFace prompt format.
        new_prompt = prompt.replace(image_token_str, replacement_str, 1)

    new_token_ids: List[int] = []
    for i, token in enumerate(prompt_token_ids):
        if token == image_token_id:
            replacement_ids = repeat_and_pad_token(
                image_token_id,
                repeat_count=repeat_count,
                pad_token_left=pad_token_left,
                pad_token_right=pad_token_right,
            )
            new_token_ids.extend(replacement_ids)

            # No need to scan further since only the first occurrence
            # is replaced.
            new_token_ids.extend(prompt_token_ids[i + 1:])
            break
        else:
            new_token_ids.append(token)

    return new_prompt, new_token_ids
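
# Hedged sketch of how a LLaVA-style input processor might call this helper.
# The tokenizer, prompt, token ID and repeat count below are assumptions for
# illustration, not values defined in this module:
#
#   new_prompt, new_token_ids = repeat_and_pad_image_tokens(
#       tokenizer,
#       "USER: <image>\nWhat is in the picture? ASSISTANT:",
#       prompt_token_ids,           # prompt tokenized beforehand
#       image_token_id=32000,       # ID of the "<image>" placeholder
#       repeat_count=576,           # e.g. number of image patch features
#   )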


class ImagePlugin(MultiModalPlugin):

    def get_data_key(self) -> str:
        return "image"

    def _get_hf_image_processor(self, model_config: ModelConfig):
        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code)

    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
        model_config = ctx.model_config
        if isinstance(data, Image.Image):
            image_processor = self._get_hf_image_processor(model_config)
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")

            try:
                batch_data = image_processor \
                    .preprocess(data, return_tensors="pt") \
                    .data
            except Exception:
                logger.error(f"Failed to process image ({data})")
                raise

            return MultiModalInputs(batch_data)
        elif isinstance(data, torch.Tensor):
            raise NotImplementedError("Embeddings input is not supported yet")

        raise TypeError(f"Invalid image type: {type(data)}")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 3000
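
# A sketch of the default mapping path, assuming an InputContext `ctx` has
# already been obtained from the registry (not shown here): a PIL image is
# run through the model's HuggingFace image processor and the resulting
# tensors are wrapped in MultiModalInputs. The exact keys (e.g.
# "pixel_values") depend on the processor and are an assumption here.
#
#   from PIL import Image
#   plugin = ImagePlugin()
#   inputs = plugin._default_input_mapper(ctx, Image.open("example.jpg"))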