123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 |
- import codecs
- import tempfile
- from dataclasses import dataclass
- from functools import lru_cache
- from pathlib import Path
- from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
- Union, cast)
- import requests
- from loguru import logger
- # yapf conflicts with isort for this block
- # yapf: disable
- from openai.types.chat import ChatCompletionContentPartImageParam
- from openai.types.chat import (
- ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
- from openai.types.chat import ChatCompletionContentPartTextParam
- from openai.types.chat import (
- ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
- # yapf: enable
- # pydantic needs the TypedDict from typing_extensions
- from pydantic import ConfigDict
- from transformers import PreTrainedTokenizer
- from typing_extensions import Required, TypedDict
- from aphrodite.common.config import ModelConfig
- from aphrodite.multimodal import MultiModalDataDict
- from aphrodite.multimodal.utils import (async_get_and_parse_audio,
- async_get_and_parse_image)
- from aphrodite.transformers_utils.tokenizer import AnyTokenizer
- class AudioURL(TypedDict, total=False):
- url: Required[str]
- """
- Either a URL of the audio or a data URL with base64 encoded audio data.
- """
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part that attaches an audio clip to a chat message."""

    audio_url: Required[AudioURL]
    """Where to obtain the audio for this part."""

    type: Required[Literal["audio_url"]]
    """Discriminator of the content part; always ``"audio_url"`` here."""
class CustomChatCompletionContentPartParam(TypedDict, total=False):
    """Catch-all content part: only ``type`` is declared, extras are kept."""

    # Tell pydantic to accept keys beyond the ones declared below.
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

    type: Required[str]
    """Discriminator naming the kind of content part."""
# Every content part shape accepted inside a message: the parts defined by
# the OpenAI client types, the audio part, and the open-ended custom part.
ChatCompletionContentPartParam = Union[
    OpenAIChatCompletionContentPartParam,
    ChatCompletionContentPartAudioParam,
    CustomChatCompletionContentPartParam,
]
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Chat Completion message whose role may be any string."""

    role: Required[str]
    """Role of the author of this message."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """Message body: plain text or a list of content parts."""

    name: str
    """Optional participant name.

    Lets the model tell apart several participants that share the same
    role.
    """
# A request message is either one of the OpenAI client message types or the
# custom-role message declared in this module.
ChatCompletionMessageParam = Union[
    OpenAIChatCompletionMessageParam,
    CustomChatCompletionMessageParam,
]
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict):
    """A single parsed chat turn reduced to a role and plain text."""

    # Role of the message author, copied from the incoming message.
    role: str
    # Text content of the turn.
    content: str
@dataclass(frozen=True)
class ChatMessageParseResult:
    """Immutable outcome of parsing one chat message."""

    # Conversation entries produced from the message (may be empty).
    messages: List[ConversationMessage]
    # Awaitables that resolve to the multimodal data referenced by the
    # message (may be empty).
    mm_futures: List[Awaitable[MultiModalDataDict]]
def load_chat_template(
        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
    """Resolve a chat template given as a URL, a file path or inline text.

    Args:
        chat_template: ``None``, a ``Path`` to a template file, or a string
            holding an ``http://``/``https://`` URL, a file path, or the
            template source itself.

    Returns:
        The template text, or ``None`` when no template was supplied.

    Raises:
        OSError: If a ``Path`` argument cannot be opened.
        ValueError: If a string argument can neither be fetched nor opened
            and contains none of ``{``, ``}`` or a newline, i.e. it looks
            like a location rather than template source.
    """
    if chat_template is None:
        return None

    try:
        # Only genuine URLs are fetched. Testing for a bare "http" prefix
        # would also capture local files such as "http_style.jinja".
        is_url = isinstance(chat_template, str) and chat_template.startswith(
            ("http://", "https://"))
        if is_url:
            # Bounded wait so a stalled server cannot block the caller
            # indefinitely.
            response = requests.get(str(chat_template), timeout=30)
            # An HTTP error page must never be used as the template.
            response.raise_for_status()
            # Decode in memory instead of round-tripping through a temporary
            # file that was never removed afterwards. Note this decodes the
            # body as UTF-8 rather than with the locale's default encoding.
            resolved_chat_template = response.content.decode("utf-8")
        else:
            with open(chat_template, "r") as f:
                resolved_chat_template = f.read()
    except OSError as e:
        # requests' exceptions derive from OSError, so download failures end
        # up here together with file-system errors.
        if isinstance(chat_template, Path):
            raise

        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (f"The supplied chat template ({chat_template}) "
                   "looks like a file path, but it failed to be "
                   f"opened. Reason: {e}")
            raise ValueError(msg) from e

        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")

    logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
    return resolved_chat_template
@lru_cache(maxsize=None)
def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
                  modality: Literal["image", "audio"]) -> Optional[str]:
    """Return the prompt placeholder string for a multimodal input.

    The lookup is keyed on ``model_config.hf_config.model_type`` and the
    requested modality. Results are memoized with ``lru_cache``.

    Returns:
        The placeholder text, or ``None`` for image models that do not put
        an image token in the prompt.

    Raises:
        TypeError: If the model type or the modality is not recognised.
    """
    # TODO: Let user specify how to insert image tokens into prompt
    # (similar to chat template)
    model_type = model_config.hf_config.model_type

    if modality == "audio":
        if model_type == "ultravox":
            return "<|reserved_special_token_0|>"
        raise TypeError(f"Unknown model type: {model_type}")

    if modality != "image":
        raise TypeError(f"Unknown modality: {modality}")

    # The branches below are mutually exclusive, so their order is free.
    if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
        # These models do not use image tokens in the prompt
        return None
    if model_type in ("chameleon", "internvl_chat"):
        return "<image>"
    if model_type == "minicpmv":
        return "(<image>./</image>)"
    if model_type == "phi3_v":
        # Workaround since this token is not defined in the tokenizer
        return "<|image_1|>"
    if model_type.startswith("llava"):
        image_token_id = model_config.hf_config.image_token_index
        return tokenizer.decode(image_token_id)

    raise TypeError(f"Unknown model type: {model_type}")
- # TODO: Let user specify how to insert multimodal tokens into prompt
- # (similar to chat template)
- def _get_full_multimodal_text_prompt(placeholder_token_str: str,
- text_prompt: str) -> str:
- """Combine multimodal prompts for a multimodal language model"""
- # NOTE: For now we assume all model architectures use the same
- # placeholder + text prompt format. This may change in the future.
- return f"{placeholder_token_str}\n{text_prompt}"
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
    """Flatten a list of content parts into one conversation message.

    Text parts are joined with newlines. At most one ``image_url`` or
    ``audio_url`` part is accepted; its loader awaitable is returned next to
    the message. When such a part is present and the model defines a
    placeholder that the text does not already contain, the placeholder is
    prepended to the text.

    Raises:
        NotImplementedError: On an unknown part type, or on a second
            multimodal part.
    """
    text_chunks: List[str] = []
    pending_mm: List[Awaitable[MultiModalDataDict]] = []
    mm_kind: Literal["image", "audio"] = "image"

    for item in parts:
        kind = item["type"]

        if kind == "text":
            text_chunks.append(
                cast(ChatCompletionContentPartTextParam, item)["text"])
            continue

        if kind not in ("image_url", "audio_url"):
            raise NotImplementedError(f"Unknown part type: {kind}")

        # Both multimodal kinds share the single-input restriction.
        if pending_mm:
            raise NotImplementedError(
                "Multiple multimodal inputs is currently not supported.")

        if kind == "image_url":
            mm_kind = "image"
            image_spec = cast(ChatCompletionContentPartImageParam,
                              item)["image_url"]
            detail = image_spec.get("detail", "auto")
            if detail != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported and "
                    "will be ignored.")
            pending_mm.append(async_get_and_parse_image(image_spec["url"]))
        else:
            mm_kind = "audio"
            audio_spec = cast(ChatCompletionContentPartAudioParam,
                              item)["audio_url"]
            pending_mm.append(async_get_and_parse_audio(audio_spec["url"]))

    combined_text = "\n".join(text_chunks)

    if pending_mm:
        placeholder = _mm_token_str(model_config, tokenizer, mm_kind)
        if placeholder is None:
            # The model takes no placeholder in the prompt.
            pass
        elif placeholder in combined_text:
            logger.warning(
                "Detected multi-modal token string in the text prompt. "
                "Skipping prompt formatting.")
        else:
            combined_text = _get_full_multimodal_text_prompt(
                placeholder_token_str=placeholder,
                text_prompt=combined_text,
            )

    return ChatMessageParseResult(
        messages=[ConversationMessage(role=role, content=combined_text)],
        mm_futures=pending_mm,
    )
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
    """Parse one incoming chat message.

    A missing or ``None`` content yields an empty result, a string content
    becomes a single conversation message, and anything else is handed to
    ``_parse_chat_message_content_parts`` as a list of content parts.
    """
    role = message["role"]
    body = message.get("content")

    if isinstance(body, str):
        plain = ConversationMessage(role=role, content=body)
        return ChatMessageParseResult(messages=[plain], mm_futures=[])

    if body is None:
        return ChatMessageParseResult(messages=[], mm_futures=[])

    return _parse_chat_message_content_parts(role, body, model_config,
                                             tokenizer)
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
    """Parse a list of chat messages in order.

    Returns:
        A pair ``(conversation, mm_futures)`` holding the conversation
        messages and the multimodal awaitables collected from every message,
        both in input order.
    """
    all_turns: List[ConversationMessage] = []
    all_mm: List[Awaitable[MultiModalDataDict]] = []

    for raw_message in messages:
        parsed = _parse_chat_message_content(raw_message, model_config,
                                             tokenizer)
        all_turns += parsed.messages
        all_mm += parsed.mm_futures

    return all_turns, all_mm
def apply_chat_template(
    tokenizer: AnyTokenizer,
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
    *,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> Union[str, List[int]]:
    """Render a conversation through the tokenizer's chat template.

    Args:
        tokenizer: Object providing ``chat_template`` and
            ``apply_chat_template``.
        conversation: Messages to render.
        chat_template: Template to use; ``None`` defers to the tokenizer's
            own template.
        tokenize: Forwarded to ``tokenizer.apply_chat_template``.
        **kwargs: Forwarded unchanged to ``tokenizer.apply_chat_template``.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns.

    Raises:
        ValueError: If neither the argument nor the tokenizer supplies a
            chat template.
    """
    # The tokenizer attribute is only consulted when no explicit template
    # was passed in.
    template_missing = (chat_template is None
                        and tokenizer.chat_template is None)
    if template_missing:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one.")

    return tokenizer.apply_chat_template(
        conversation=conversation,
        chat_template=chat_template,
        tokenize=tokenize,
        **kwargs,
    )
|