"""Helpers for turning OpenAI-compatible chat messages into a conversation
list plus multi-modal data, and for applying chat templates to the result."""
import asyncio
import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
                    Mapping, Optional, Tuple, TypeVar, Union, cast)

from loguru import logger
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
                               ChatCompletionContentPartImageParam)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
                               ChatCompletionContentPartTextParam)
from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
                               ChatCompletionToolMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Required, TypeAlias, TypedDict

from aphrodite.common.config import ModelConfig
from aphrodite.multimodal import MultiModalDataDict
from aphrodite.multimodal.utils import (async_get_and_parse_audio,
                                        async_get_and_parse_image,
                                        get_and_parse_audio,
                                        get_and_parse_image)
from aphrodite.transformers_utils.tokenizer import (AnyTokenizer,
                                                    MistralTokenizer)


class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


class CustomChatCompletionContentPartParam(TypedDict, total=False):
    # Let pydantic keep unknown keys on custom content parts.
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

    type: Required[str]
    """The type of the content part."""


ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPartParam]


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""


ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]


# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
    role: Required[str]
    """The role of the message's author."""

    content: Optional[str]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""


ModalityStr = Literal["image", "audio", "video"]
_T = TypeVar("_T")


class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer
        # Per-modality limits. Empty when the model has no multimodal config,
        # in which case `add` falls back to a limit of 1 per modality.
        self._allowed_items = (model_config.multimodal_config.limit_per_prompt
                               if model_config.multimodal_config else {})
        self._consumed_items = {k: 0 for k in self._allowed_items}
        self._items: List[_T] = []

    @staticmethod
    @lru_cache(maxsize=None)
    def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
        # Memoized per (tokenizer, token id) so the same token is not
        # decoded again for every item.
        return tokenizer.decode(token_index)

    def _placeholder_str(self, modality: ModalityStr,
                         current_count: int) -> Optional[str]:
        """Return the prompt placeholder for the ``current_count``-th item.

        Returns None for models that take no placeholder in the prompt.

        Raises:
            TypeError: if the model type or the modality is not supported.
        """
        # TODO: Let user specify how to insert image tokens into prompt
        # (similar to chat template)
        hf_config = self._model_config.hf_config
        model_type = hf_config.model_type
        if modality == "image":
            if model_type == "phi3_v":
                # Workaround since this token is not defined in the tokenizer
                return f"<|image_{current_count}|>"
            if model_type == "minicpmv":
                return "(<image>./</image>)"
            if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
                              "pixtral"):
                # These models do not use image tokens in the prompt
                return None
            if model_type == "qwen":
                return f"Picture {current_count}: "
            if model_type.startswith("llava"):
                return self._cached_token_str(self._tokenizer,
                                              hf_config.image_token_index)
            if model_type in ("chameleon", "internvl_chat"):
                # Must be non-empty: falsy placeholders are silently dropped
                # by BaseMultiModalContentParser._add_placeholder.
                return "<image>"
            if model_type == "qwen2_vl":
                return "<|vision_start|><|image_pad|><|vision_end|>"

            raise TypeError(f"Unknown model type: {model_type}")
        elif modality == "audio":
            if model_type == "ultravox":
                return "<|reserved_special_token_0|>"
            raise TypeError(f"Unknown model type: {model_type}")
        elif modality == "video":
            if model_type == "qwen2_vl":
                return "<|vision_start|><|video_pad|><|vision_end|>"
            raise TypeError(f"Unknown model type: {model_type}")
        else:
            raise TypeError(f"Unknown modality: {modality}")

    @staticmethod
    def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
        """Merge per-item multi-modal dicts into one dict keyed by modality.

        List values are flattened into the merged list. A key that ends up
        with exactly one item maps to that item rather than a 1-element list.
        """
        mm_lists: Mapping[str, List[object]] = defaultdict(list)

        # Merge all the multi-modal items
        for single_mm_data in items:
            for mm_key, mm_item in single_mm_data.items():
                if isinstance(mm_item, list):
                    mm_lists[mm_key].extend(mm_item)
                else:
                    mm_lists[mm_key].append(mm_item)

        # Unpack any single item lists for models that don't expect multiple.
        return {
            mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
            for mm_key, mm_list in mm_lists.items()
        }

    def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
        """
        Add a multi-modal item to the current prompt and return the
        placeholder string to use, if any.

        Raises:
            ValueError: if this would exceed the per-prompt limit for
                ``modality``.
        """
        allowed_count = self._allowed_items.get(modality, 1)
        current_count = self._consumed_items.get(modality, 0) + 1
        if current_count > allowed_count:
            raise ValueError(
                f"At most {allowed_count} {modality}(s) may be provided in "
                "one request.")

        self._consumed_items[modality] = current_count
        self._items.append(item)

        return self._placeholder_str(modality, current_count)

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return a content parser that feeds items into this tracker."""
        raise NotImplementedError


class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
    """Tracker whose items are already-loaded multi-modal data dicts."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Return the combined multi-modal data, or None if nothing was added.
        """
        return self._combine(self._items) if self._items else None

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)


class AsyncMultiModalItemTracker(
        BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
    """Tracker whose items are awaitables resolving to multi-modal data."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await all tracked items together and return the combined data,
        or None if nothing was added."""
        if self._items:
            items = await asyncio.gather(*self._items)
            return self._combine(items)

        return None

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Parses multi-modal content parts of one message and counts the
    placeholders that were produced for them."""

    def __init__(self) -> None:
        super().__init__()

        # multimodal placeholder_string : count
        self._placeholder_counts: Dict[str, int] = defaultdict(int)

    def _add_placeholder(self, placeholder: Optional[str]):
        # None/empty placeholders mean "nothing to insert into the prompt".
        if placeholder:
            self._placeholder_counts[placeholder] += 1

    def mm_placeholder_counts(self) -> Dict[str, int]:
        """Return a snapshot copy of the placeholder counts."""
        return dict(self._placeholder_counts)

    @abstractmethod
    def parse_image(self, image_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError


class MultiModalContentParser(BaseMultiModalContentParser):
    """Loads media eagerly and registers it with a MultiModalItemTracker."""

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

    def parse_image(self, image_url: str) -> None:
        image = get_and_parse_image(image_url)

        placeholder = self._tracker.add("image", image)
        self._add_placeholder(placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio = get_and_parse_audio(audio_url)

        placeholder = self._tracker.add("audio", audio)
        self._add_placeholder(placeholder)


class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Registers not-yet-awaited media loads with an
    AsyncMultiModalItemTracker."""

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

    def parse_image(self, image_url: str) -> None:
        image_coro = async_get_and_parse_image(image_url)

        placeholder = self._tracker.add("image", image_coro)
        self._add_placeholder(placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio_coro = async_get_and_parse_audio(audio_url)

        placeholder = self._tracker.add("audio", audio_coro)
        self._add_placeholder(placeholder)


def load_chat_template(
        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
    """Resolve a chat template given as a file path or as inline template text.

    A string that cannot be opened as a file is treated as the template text
    itself (after decoding backslash escapes), provided it contains at least
    one of ``{``, ``}`` or a newline.

    Returns:
        The template text, or None if ``chat_template`` is None.

    Raises:
        OSError: if a ``Path`` was given and the file cannot be opened.
        ValueError: if a string that looks like a path cannot be opened.
    """
    if chat_template is None:
        return None
    try:
        with open(chat_template, "r", encoding="utf-8") as f:
            resolved_chat_template = f.read()
    except OSError as e:
        if isinstance(chat_template, Path):
            raise

        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {e}")
            raise ValueError(msg) from e

        # Opening as a file failed, so treat the argument as the template
        # itself and decode it so that escape sequences are interpreted.
        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")

    # loguru formats messages with str.format-style `{}` placeholders.
    logger.info("Using supplied chat template:\n{}", resolved_chat_template)
    return resolved_chat_template


# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
                                     text_prompt: str) -> str:
    """Combine multimodal prompts for a multimodal language model.

    Placeholders the user already wrote in ``text_prompt`` are left in place.
    Any that are missing are prepended, one per line. ``placeholder_counts``
    is not modified.

    Raises:
        ValueError: if ``text_prompt`` contains more occurrences of a
            placeholder than there are multimodal items for it.
    """
    # Look through the text prompt to check for missing placeholders
    missing_placeholders: List[str] = []
    for placeholder, expected_count in placeholder_counts.items():
        # For any existing placeholder in the text prompt, we leave it as is
        num_missing = expected_count - text_prompt.count(placeholder)
        if num_missing < 0:
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
                "actual multimodal data items.")

        missing_placeholders.extend([placeholder] * num_missing)

    # NOTE: For now we always add missing placeholders at the front of
    # the prompt. This may change to be customizable in the future.
    return "\n".join(missing_placeholders + [text_prompt])


# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)


def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
    """Flatten the content parts of one message into a single text message.

    Text and refusal parts are joined with newlines. Image/audio parts are
    handed to ``mm_tracker`` and represented by their placeholders.

    Raises:
        NotImplementedError: for an unrecognized part type.
    """
    texts: List[str] = []

    mm_parser = mm_tracker.create_parser()

    for part in parts:
        part_type = part["type"]
        if part_type == "text":
            text = _TextParser(part)["text"]
            texts.append(text)
        elif part_type == "image_url":
            image_url = _ImageParser(part)["image_url"]

            if image_url.get("detail", "auto") != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported and "
                    "will be ignored.")

            mm_parser.parse_image(image_url["url"])
        elif part_type == "audio_url":
            audio_url = _AudioParser(part)["audio_url"]

            mm_parser.parse_audio(audio_url["url"])
        elif part_type == "refusal":
            text = _RefusalParser(part)["refusal"]
            texts.append(text)
        else:
            raise NotImplementedError(f"Unknown part type: {part_type}")

    text_prompt = "\n".join(texts)
    mm_placeholder_counts = mm_parser.mm_placeholder_counts()
    if mm_placeholder_counts:
        text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
                                                       text_prompt)
    return [ConversationMessage(role=role, content=text_prompt)]


# No need to validate using Pydantic again
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
    """Convert one request message into conversation message(s).

    Plain-string content is treated as a single text part. Tool-call fields
    are carried over for assistant/tool roles, and a string ``name`` is kept.
    """
    role = message["role"]
    content = message.get("content")

    if content is None:
        content = []
    elif isinstance(content, str):
        content = [
            ChatCompletionContentPartTextParam(type="text", text=content)
        ]

    result = _parse_chat_message_content_parts(
        role,
        content,  # type: ignore
        mm_tracker,
    )

    for result_msg in result:
        if role == "assistant":
            parsed_msg = _AssistantParser(message)

            if "tool_calls" in parsed_msg:
                # Copy into a list so the (possibly lazy) iterable is
                # materialized for later post-processing.
                result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
        elif role == "tool":
            parsed_msg = _ToolParser(message)
            if "tool_call_id" in parsed_msg:
                result_msg["tool_call_id"] = parsed_msg["tool_call_id"]

        if "name" in message and isinstance(message["name"], str):
            result_msg["name"] = message["name"]

    return result


def _parse_tool_call_arguments(
    tool_call: ChatCompletionMessageToolCallParam
) -> ChatCompletionMessageToolCallParam:
    """Return a copy of ``tool_call`` with JSON-string arguments decoded.

    The input dict and its nested ``function`` dict are left untouched, so
    the caller's request messages keep their original string arguments.
    Non-string arguments (e.g. already decoded) are passed through as-is.

    Raises:
        json.JSONDecodeError: if string arguments are not valid JSON.
    """
    function = tool_call["function"]
    arguments: Any = function["arguments"]
    if isinstance(arguments, str):
        arguments = json.loads(arguments)
    return cast(ChatCompletionMessageToolCallParam, {
        **tool_call,
        "function": {
            **function, "arguments": arguments
        },
    })


def _postprocess_messages(messages: List[ConversationMessage]) -> None:
    """Decode tool-call arguments of assistant messages in place.

    Per the Transformers docs & maintainers, tool call arguments in
    assistant-role messages with tool_calls need to be dicts, not JSON
    strings; this is how tool-use chat templates expect them. For messages
    that have tool_calls, parse the string (which we get from the OpenAI
    format) to a dict.
    """
    for message in messages:
        if (message["role"] == "assistant" and "tool_calls" in message
                and isinstance(message["tool_calls"], list)):
            # Build new tool-call dicts instead of mutating the shared ones,
            # which still belong to the original request messages.
            message["tool_calls"] = [
                _parse_tool_call_arguments(item)
                for item in message["tool_calls"]
            ]


def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
    """Parse request messages, loading multi-modal data synchronously.

    Returns:
        The conversation messages and the combined multi-modal data (None
        if the messages contained no multi-modal items).
    """
    conversation: List[ConversationMessage] = []
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    for msg in messages:
        sub_messages = _parse_chat_message_content(msg, mm_tracker)

        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()


def parse_chat_messages_futures(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
    """Parse request messages, deferring multi-modal loading.

    Returns:
        The conversation messages and an awaitable that resolves to the
        combined multi-modal data (or None if there were no such items).
    """
    conversation: List[ConversationMessage] = []
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    for msg in messages:
        sub_messages = _parse_chat_message_content(msg, mm_tracker)

        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()


def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
    *,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render ``conversation`` with a Hugging Face tokenizer's chat template.

    ``chat_template`` overrides the tokenizer's own template when given.

    Raises:
        ValueError: if neither ``chat_template`` nor the tokenizer provides
            a template.
    """
    if chat_template is None and tokenizer.chat_template is None:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one.")

    return tokenizer.apply_chat_template(
        conversation=conversation,  # type: ignore[arg-type]
        chat_template=chat_template,
        tokenize=tokenize,
        **kwargs,
    )


def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: List[ChatCompletionMessageParam],
    chat_template: Optional[str] = None,
    **kwargs: Any,
) -> List[int]:
    """Apply a Mistral tokenizer's built-in chat template to ``messages``.

    ``chat_template`` cannot be used here. If one is supplied it is ignored
    and a warning is logged.
    """
    if chat_template is not None:
        logger.warning(
            "'chat_template' cannot be overridden for mistral tokenizer.")

    return tokenizer.apply_chat_template(
        messages=messages,
        **kwargs,
    )