瀏覽代碼

move input parsing to utils

AlpinDale 6 月之前
父節點
當前提交
e81b8d1c52
共有 2 個文件被更改,包括 88 次插入和 85 次刪除
  1. 0 84
      aphrodite/multimodal/image.py
  2. 88 1
      aphrodite/multimodal/utils.py

+ 0 - 84
aphrodite/multimodal/image.py

@@ -1,101 +1,17 @@
 from functools import lru_cache
-from typing import List, Optional, Tuple, TypeVar
 
 import torch
 from loguru import logger
 from PIL import Image
-from transformers import PreTrainedTokenizerBase
 
 from aphrodite.common.config import ModelConfig
 from aphrodite.common.utils import is_list_of
 from aphrodite.inputs.registry import InputContext
 from aphrodite.transformers_utils.image_processor import get_image_processor
-from aphrodite.transformers_utils.tokenizer import get_tokenizer
 
 from .base import MultiModalInputs, MultiModalPlugin
 
 cached_get_image_processor = lru_cache(get_image_processor)
-cached_get_tokenizer = lru_cache(get_tokenizer)
-
-# Utilities for image input processors
-_T = TypeVar("_T", str, int)
-
-
-def repeat_and_pad_token(
-    token: _T,
-    *,
-    repeat_count: int = 1,
-    pad_token_left: Optional[_T] = None,
-    pad_token_right: Optional[_T] = None,
-) -> List[_T]:
-    replacement = [token] * repeat_count
-    if pad_token_left is not None:
-        replacement = [pad_token_left] + replacement
-    if pad_token_right is not None:
-        replacement = replacement + [pad_token_right]
-
-    return replacement
-
-
-def repeat_and_pad_image_tokens(
-    tokenizer: PreTrainedTokenizerBase,
-    prompt: Optional[str],
-    prompt_token_ids: List[int],
-    *,
-    image_token_id: int,
-    repeat_count: int = 1,
-    pad_token_left: Optional[int] = None,
-    pad_token_right: Optional[int] = None,
-) -> Tuple[Optional[str], List[int]]:
-    if prompt is None:
-        new_prompt = None
-    else:
-        image_token_str = tokenizer.decode(image_token_id)
-        pad_token_str_left = (None if pad_token_left is None else
-                              tokenizer.decode(pad_token_left))
-        pad_token_str_right = (None if pad_token_right is None else
-                               tokenizer.decode(pad_token_right))
-        replacement_str = "".join(
-            repeat_and_pad_token(
-                image_token_str,
-                repeat_count=repeat_count,
-                pad_token_left=pad_token_str_left,
-                pad_token_right=pad_token_str_right,
-            ))
-
-        image_token_count = prompt.count(image_token_str)
-        # This is an arbitrary number to distinguish between the two cases
-        if image_token_count > 16:
-            logger.warning("Please follow the prompt format that is "
-                           "documented on HuggingFace which does not involve "
-                           f"repeating {image_token_str} tokens.")
-        elif image_token_count > 1:
-            logger.warning("Multiple image input is not supported yet, "
-                           "so any extra image tokens will be treated "
-                           "as plain text.")
-
-        # The image tokens are removed to be consistent with HuggingFace
-        new_prompt = prompt.replace(image_token_str, replacement_str, 1)
-
-    new_token_ids: List[int] = []
-    for i, token in enumerate(prompt_token_ids):
-        if token == image_token_id:
-            replacement_ids = repeat_and_pad_token(
-                image_token_id,
-                repeat_count=repeat_count,
-                pad_token_left=pad_token_left,
-                pad_token_right=pad_token_right,
-            )
-            new_token_ids.extend(replacement_ids)
-
-            # No need to further scan the list since we only replace once
-            new_token_ids.extend(prompt_token_ids[i + 1:])
-            break
-        else:
-            new_token_ids.append(token)
-
-    return new_prompt, new_token_ids
-
 
 class ImagePlugin(MultiModalPlugin):
 

+ 88 - 1
aphrodite/multimodal/utils.py

@@ -1,15 +1,18 @@
 import base64
 import os
+from functools import lru_cache
 from io import BytesIO
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, TypeVar, Union
 
 import librosa
 import numpy as np
 import soundfile
+from loguru import logger
 from PIL import Image
 
 from aphrodite.common.connections import global_http_connection
 from aphrodite.multimodal.base import MultiModalDataDict
+from aphrodite.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
 
 APHRODITE_IMAGE_FETCH_TIMEOUT = int(
     os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT", 10))
@@ -18,6 +21,9 @@ APHRODITE_AUDIO_FETCH_TIMEOUT = int(
     os.getenv("APHRODITE_AUDIO_FETCH_TIMEOUT", 10))
 
 
+cached_get_tokenizer = lru_cache(get_tokenizer)
+
+
 def _load_image_from_bytes(b: bytes):
     image = Image.open(BytesIO(b))
     image.load()
@@ -157,3 +163,84 @@ def rescale_image_size(image: Image.Image,
     if transpose >= 0:
         image = image.transpose(Image.Transpose(transpose))
     return image
+
+
+# Utilities for input processors
+_T = TypeVar("_T", str, int)
+
+
+def repeat_and_pad_token(
+    token: _T,
+    *,
+    repeat_count: int = 1,
+    pad_token_left: Optional[_T] = None,
+    pad_token_right: Optional[_T] = None,
+) -> List[_T]:
+    replacement = [token] * repeat_count
+    if pad_token_left is not None:
+        replacement = [pad_token_left] + replacement
+    if pad_token_right is not None:
+        replacement = replacement + [pad_token_right]
+
+    return replacement
+
+
+def repeat_and_pad_placeholder_tokens(
+    tokenizer: AnyTokenizer,
+    prompt: Optional[str],
+    prompt_token_ids: List[int],
+    *,
+    placeholder_token_id: int,
+    repeat_count: int = 1,
+    pad_token_left: Optional[int] = None,
+    pad_token_right: Optional[int] = None,
+) -> Tuple[Optional[str], List[int]]:
+    if prompt is None:
+        new_prompt = None
+    else:
+        placeholder_token_str = tokenizer.decode(placeholder_token_id)
+        pad_token_str_left = (None if pad_token_left is None else
+                              tokenizer.decode(pad_token_left))
+        pad_token_str_right = (None if pad_token_right is None else
+                               tokenizer.decode(pad_token_right))
+        replacement_str = "".join(
+            repeat_and_pad_token(
+                placeholder_token_str,
+                repeat_count=repeat_count,
+                pad_token_left=pad_token_str_left,
+                pad_token_right=pad_token_str_right,
+            ))
+
+        placeholder_token_count = prompt.count(placeholder_token_str)
+        # This is an arbitrary number to distinguish between the two cases
+        if placeholder_token_count > 16:
+            logger.warning(
+                "Please follow the prompt format that is "
+                "documented on HuggingFace which does not involve "
+                "repeating %s tokens.", placeholder_token_str)
+        elif placeholder_token_count > 1:
+            logger.warning("Multiple multi-modal input is not supported yet, "
+                           "so any extra placeholder tokens will be treated "
+                           "as plain text.")
+
+        # The image tokens are removed to be consistent with HuggingFace
+        new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1)
+
+    new_token_ids: List[int] = []
+    for i, token in enumerate(prompt_token_ids):
+        if token == placeholder_token_id:
+            replacement_ids = repeat_and_pad_token(
+                placeholder_token_id,
+                repeat_count=repeat_count,
+                pad_token_left=pad_token_left,
+                pad_token_right=pad_token_right,
+            )
+            new_token_ids.extend(replacement_ids)
+
+            # No need to further scan the list since we only replace once
+            new_token_ids.extend(prompt_token_ids[i + 1:])
+            break
+        else:
+            new_token_ids.append(token)
+
+    return new_prompt, new_token_ids