from functools import lru_cache
from typing import List, Optional, Tuple, TypeVar

import torch
from loguru import logger
from PIL import Image
from transformers import PreTrainedTokenizerBase

from aphrodite.common.config import ModelConfig
from aphrodite.inputs.registry import InputContext
from aphrodite.transformers_utils.image_processor import get_image_processor
from aphrodite.transformers_utils.tokenizer import get_tokenizer

from .base import MultiModalInputs, MultiModalPlugin

# Memoized wrappers so repeated lookups with the same arguments reuse the
# previously returned object (lru_cache's default maxsize applies).
cached_get_image_processor = lru_cache(get_image_processor)
cached_get_tokenizer = lru_cache(get_tokenizer)

# Utilities for image input processors
_T = TypeVar("_T", str, int)


def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> List[_T]:
    """Return ``token`` repeated ``repeat_count`` times with optional padding.

    Args:
        token: The token (string or id) to repeat.
        repeat_count: How many copies of ``token`` to emit.
        pad_token_left: If not ``None``, placed once before the repeated
            tokens.
        pad_token_right: If not ``None``, placed once after the repeated
            tokens.

    Returns:
        A new list: ``[pad_left?] + [token] * repeat_count + [pad_right?]``.
    """
    replacement = [token] * repeat_count
    if pad_token_left is not None:
        replacement = [pad_token_left] + replacement
    if pad_token_right is not None:
        replacement = replacement + [pad_token_right]

    return replacement


def repeat_and_pad_image_tokens(
    tokenizer: PreTrainedTokenizerBase,
    prompt: Optional[str],
    prompt_token_ids: List[int],
    *,
    image_token_id: int,
    repeat_count: int = 1,
    pad_token_left: Optional[int] = None,
    pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
    """Expand the first image token in a prompt and in its token ids.

    The first occurrence of the image token is replaced by ``repeat_count``
    copies of itself, optionally surrounded by padding tokens. The same
    expansion is applied to the text prompt (when one is given) and to
    ``prompt_token_ids``. Later occurrences are left untouched.

    Args:
        tokenizer: Used to decode the image and padding token ids into the
            strings that are searched for / inserted into ``prompt``.
        prompt: The text prompt, or ``None`` if only token ids are available.
        prompt_token_ids: The tokenized prompt.
        image_token_id: Id of the image placeholder token.
        repeat_count: Number of image tokens the placeholder expands to.
        pad_token_left: Optional token id inserted before the image tokens.
        pad_token_right: Optional token id inserted after the image tokens.

    Returns:
        A ``(new_prompt, new_token_ids)`` tuple. ``new_prompt`` is ``None``
        when ``prompt`` is ``None``; ``new_token_ids`` is always a new list.
    """
    if prompt is None:
        new_prompt = None
    else:
        image_token_str = tokenizer.decode(image_token_id)
        pad_token_str_left = (None if pad_token_left is None else
                              tokenizer.decode(pad_token_left))
        pad_token_str_right = (None if pad_token_right is None else
                               tokenizer.decode(pad_token_right))
        replacement_str = "".join(
            repeat_and_pad_token(
                image_token_str,
                repeat_count=repeat_count,
                pad_token_left=pad_token_str_left,
                pad_token_right=pad_token_str_right,
            ))

        image_token_count = prompt.count(image_token_str)
        # This is an arbitrary number to distinguish between the two cases
        if image_token_count > 16:
            logger.warning("Please follow the prompt format that is "
                           "documented on HuggingFace which does not involve "
                           f"repeating {image_token_str} tokens.")
        elif image_token_count > 1:
            logger.warning("Multiple image input is not supported yet, "
                           "so any extra image tokens will be treated "
                           "as plain text.")

        # Only the first image token is expanded; any further occurrences
        # stay in the prompt unchanged.
        new_prompt = prompt.replace(image_token_str, replacement_str, 1)

    new_token_ids: List[int] = []
    for i, token in enumerate(prompt_token_ids):
        if token == image_token_id:
            replacement_ids = repeat_and_pad_token(
                image_token_id,
                repeat_count=repeat_count,
                pad_token_left=pad_token_left,
                pad_token_right=pad_token_right,
            )
            new_token_ids.extend(replacement_ids)

            # No need to further scan the list since we only replace once
            new_token_ids.extend(prompt_token_ids[i + 1:])
            break
        else:
            new_token_ids.append(token)

    return new_prompt, new_token_ids


class ImagePlugin(MultiModalPlugin):
    """Multimodal plugin that handles inputs stored under the "image" key."""

    def get_data_key(self) -> str:
        """Return the key identifying this plugin's data: ``"image"``."""
        return "image"

    def _get_hf_image_processor(self, model_config: ModelConfig):
        """Return the cached image processor for ``model_config.model``.

        The lookup goes through ``cached_get_image_processor`` and forwards
        ``model_config.trust_remote_code``.
        """
        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code)

    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
        """Convert raw image data into ``MultiModalInputs``.

        Args:
            ctx: Input context; only ``ctx.model_config`` is read.
            data: The image input. A ``PIL.Image.Image`` is run through the
                image processor's ``preprocess``; anything else is rejected.

        Returns:
            ``MultiModalInputs`` wrapping the ``.data`` of the preprocessing
            result.

        Raises:
            RuntimeError: If no image processor is available for the model.
            NotImplementedError: If ``data`` is a ``torch.Tensor``.
            TypeError: If ``data`` is of any other type.
        """
        model_config = ctx.model_config
        if isinstance(data, Image.Image):
            image_processor = self._get_hf_image_processor(model_config)
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                batch_data = image_processor \
                    .preprocess(data, return_tensors="pt") \
                    .data
            except Exception:
                # Log which image failed, then let the original error
                # propagate unchanged.
                logger.error(f"Failed to process image ({data})")
                raise

            return MultiModalInputs(batch_data)
        elif isinstance(data, torch.Tensor):
            raise NotImplementedError("Embeddings input is not supported yet")

        raise TypeError(f"Invalid image type: {type(data)}")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """Return the default maximum token count, a fixed 3000.

        ``ctx`` is accepted but not consulted.
        """
        return 3000