import codecs
import json
import tempfile
from dataclasses import dataclass
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Optional,
Tuple, Union, cast)
import requests
from loguru import logger
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
ChatCompletionContentPartTextParam)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from typing_extensions import Required, TypeAlias, TypedDict
from aphrodite.common.config import ModelConfig
from aphrodite.multimodal import MultiModalDataDict
from aphrodite.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image)
from aphrodite.transformers_utils.tokenizer import AnyTokenizer

class AudioURL(TypedDict, total=False):
url: Required[str]
"""
Either a URL of the audio or a data URL with base64 encoded audio data.
"""
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
audio_url: Required[AudioURL]
type: Required[Literal["audio_url"]]
"""The type of the content part."""

class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
type: Required[str]
"""The type of the content part."""

ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPartParam]
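
# A sketch of a mixed content-part list this alias admits (the URL is a
# placeholder):
#
#     parts: List[ChatCompletionContentPartParam] = [
#         {"type": "text", "text": "What is in this image?"},
#         {"type": "image_url",
#          "image_url": {"url": "https://example.com/cat.png"}},
#     ]
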
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
tool_call_id: Optional[str]
"""Tool call that this message is responding to."""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
"""The tool calls generated by the model, such as function calls."""

ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam]
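
# Since CustomChatCompletionMessageParam only requires "role", non-OpenAI
# roles are accepted; a minimal sketch (the role name is made up):
#
#     msg: ChatCompletionMessageParam = {"role": "critic",
#                                        "content": "Rate the answer."}
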
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
content: Optional[str]
"""The contents of the message"""
tool_call_id: Optional[str]
"""Tool call that this message is responding to."""
name: Optional[str]
"""The name of the function to call"""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
"""The tool calls generated by the model, such as function calls."""

@dataclass(frozen=True)
class ChatMessageParseResult:
messages: List[ConversationMessage]
mm_futures: List[Awaitable[MultiModalDataDict]]
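
# The multimodal futures are deliberately left unawaited here: parsing
# stays synchronous, and the caller decides when to await the data.
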
def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None:
return None
try:
chat_template_str = str(chat_template)
        if chat_template_str.startswith("http"):
            # Fetch remote templates into a local temporary file first,
            # failing loudly on HTTP errors instead of saving an error page.
            response = requests.get(chat_template_str)
            response.raise_for_status()
            temp = tempfile.NamedTemporaryFile(delete=False)
            temp.write(response.content)
            temp.close()
            chat_template = temp.name
with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
if isinstance(chat_template, Path):
raise
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
        # If opening the file fails, treat the argument as the template
        # itself, decoding it so escape sequences (e.g. "\n") are
        # interpreted correctly.
        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
return resolved_chat_template
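
# Usage sketch for load_chat_template (the file path is hypothetical):
#
#     load_chat_template(Path("examples/chatml.jinja"))   # read from disk
#     load_chat_template("{{ messages[0]['content'] }}")  # inline template
#     load_chat_template(None)                            # -> None
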
@lru_cache(maxsize=None)
def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
if modality == "image":
model_type = model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type == "minicpmv":
return "(./)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return ""
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
raise TypeError("No audio models are supported yet.")
else:
raise TypeError(f"Unknown modality: {modality}")
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_token_str: str,
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model"""
# NOTE: For now we assume all model architectures use the same
# placeholder + text prompt format. This may change in the future.
return f"{placeholder_token_str}\n{text_prompt}"
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
modality: Literal["image", "audio"] = "image"
for part in parts:
part_type = part["type"]
if part_type == "text":
            text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text)
elif part_type == "image_url":
modality = "image"
if len(mm_futures) > 0:
                raise NotImplementedError(
                    "Multiple multimodal inputs are currently not supported.")
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
elif part_type == "audio_url":
modality = "audio"
if len(mm_futures) > 0:
                raise NotImplementedError(
                    "Multiple multimodal inputs are currently not supported.")
audio_url = cast(ChatCompletionContentPartAudioParam,
part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if mm_futures:
placeholder_token_str = _mm_token_str(model_config, tokenizer,
modality)
if placeholder_token_str is not None:
if placeholder_token_str in text_prompt:
logger.warning(
"Detected multi-modal token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = _get_full_multimodal_text_prompt(
placeholder_token_str=placeholder_token_str,
text_prompt=text_prompt,
)
messages = [ConversationMessage(role=role, content=text_prompt)]
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
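
# A hedged sketch of the result for a text + image_url message, assuming
# the model's placeholder decodes to "<image>" (the URL is a placeholder):
#
#     result = _parse_chat_message_content_parts(
#         role="user",
#         parts=[{"type": "text", "text": "What is this?"},
#                {"type": "image_url",
#                 "image_url": {"url": "https://example.com/cat.png"}}],
#         model_config=model_config,
#         tokenizer=tokenizer,
#     )
#     # result.messages[0]["content"] == "<image>\nWhat is this?"
#     # len(result.mm_futures) == 1
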
# No need to validate using Pydantic again
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)

def _parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
if content is None:
content = []
elif isinstance(content, str):
content = [
ChatCompletionContentPartTextParam(type="text", text=content)
]
    result = _parse_chat_message_content_parts(role, content, model_config,
                                               tokenizer)

    for result_msg in result.messages:
if role == 'assistant':
parsed_msg = _AssistantParser(message)
if "tool_calls" in parsed_msg:
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool":
parsed_msg = _ToolParser(message)
if "tool_call_id" in parsed_msg:
result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
if "name" in message and isinstance(message["name"], str):
result_msg["name"] = message["name"]
return result

def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
for msg in messages:
parse_result = _parse_chat_message_content(msg, model_config,
tokenizer)
conversation.extend(parse_result.messages)
mm_futures.extend(parse_result.mm_futures)
return conversation, mm_futures
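
# Typical call site (async context assumed; request, model_config, and
# tokenizer are hypothetical names supplied by the serving engine):
#
#     conversation, mm_futures = parse_chat_messages(
#         request.messages, model_config, tokenizer)
#     mm_data = await mm_futures[0] if mm_futures else None
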
def apply_chat_template(
tokenizer: AnyTokenizer,
conversation: List[ConversationMessage],
chat_template: Optional[str],
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> Union[str, List[int]]:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")
    # Per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls must be dicts, not JSON
    # strings; that is the format tool-use chat templates expect going
    # forward. So, for messages that have tool_calls, parse the string
    # (which we receive in OpenAI format) into a dict.
for message in conversation:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
for i in range(len(message["tool_calls"])):
args: str = message["tool_calls"][i]["function"]["arguments"]
parsed_args: Dict = json.loads(args)
message["tool_calls"][i]["function"]["arguments"] = parsed_args
prompt = tokenizer.apply_chat_template(
conversation=conversation,
chat_template=chat_template,
tokenize=tokenize,
**kwargs,
)
return prompt
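
# End-to-end sketch tying the helpers together (args.chat_template is a
# hypothetical CLI value; add_generation_prompt is forwarded to the HF
# tokenizer via **kwargs):
#
#     template = load_chat_template(args.chat_template)
#     conversation, _ = parse_chat_messages(messages, model_config,
#                                           tokenizer)
#     prompt = apply_chat_template(tokenizer, conversation, template,
#                                  add_generation_prompt=True)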