123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535 |
- import asyncio
- import codecs
- import json
- from abc import ABC, abstractmethod
- from collections import defaultdict
- from functools import lru_cache, partial
- from pathlib import Path
- from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
- Mapping, Optional, Tuple, TypeVar, Union, cast)
- from loguru import logger
- # yapf conflicts with isort for this block
- # yapf: disable
- from openai.types.chat import (ChatCompletionAssistantMessageParam,
- ChatCompletionContentPartImageParam)
- from openai.types.chat import (
- ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
- from openai.types.chat import (ChatCompletionContentPartRefusalParam,
- ChatCompletionContentPartTextParam)
- from openai.types.chat import (
- ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
- from openai.types.chat import (ChatCompletionMessageToolCallParam,
- ChatCompletionToolMessageParam)
- # yapf: enable
- # pydantic needs the TypedDict from typing_extensions
- from pydantic import ConfigDict
- from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
- from typing_extensions import Required, TypeAlias, TypedDict
- from aphrodite.common.config import ModelConfig
- from aphrodite.multimodal import MultiModalDataDict
- from aphrodite.multimodal.utils import (async_get_and_parse_audio,
- async_get_and_parse_image,
- get_and_parse_audio,
- get_and_parse_image)
- from aphrodite.transformers_utils.tokenizer import (AnyTokenizer,
- MistralTokenizer)
class AudioURL(TypedDict, total=False):
    # Value of the ``audio_url`` field of an audio content part.
    # ``total=False`` makes keys optional by default; ``url`` is opted back
    # in as mandatory through ``Required``.
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    # Audio content part, i.e. ``{"type": "audio_url", "audio_url": {...}}``.
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


class CustomChatCompletionContentPartParam(TypedDict, total=False):
    # Catch-all content part: only ``type`` is required, and the pydantic
    # config ``extra="allow"`` keeps any additional keys during validation
    # instead of rejecting them.
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

    type: Required[str]
    """The type of the content part."""


# Every content-part shape accepted in a message's ``content`` list: the
# OpenAI SDK parts plus the audio, refusal and custom parts.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPartParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    # ``role`` is a plain ``str`` (not a fixed set of literals) so custom
    # roles validate; ``total=False`` leaves every other key optional.
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""


# Any message accepted by the parsing functions in this module: an OpenAI SDK
# message or a custom-role message.
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
    # Flattened message handed to chat templates: multi-part content has been
    # collapsed into the single string ``content`` by
    # ``_parse_chat_message_content_parts``.
    role: Required[str]
    """The role of the message's author."""

    content: Optional[str]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    # NOTE(review): ``_parse_chat_message_content`` copies any string ``name``
    # from the incoming message regardless of role, so this is not only a
    # function name.
    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""


# Modalities the multi-modal trackers and parsers know how to place.
ModalityStr = Literal["image", "audio", "video"]

# Item type stored by a BaseMultiModalItemTracker (parsed data, or an
# awaitable resolving to it in the async tracker).
_T = TypeVar("_T")
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Collects the multi-modal items attached to one request.

    Every :meth:`add` call records an item, enforces the per-prompt limit
    taken from ``model_config.multimodal_config.limit_per_prompt`` and hands
    back the placeholder text the model expects for that item.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer

        mm_config = model_config.multimodal_config
        # No multimodal config means no explicit limits; ``add`` then falls
        # back to a limit of one item per modality.
        self._allowed_items = mm_config.limit_per_prompt if mm_config else {}
        self._consumed_items = dict.fromkeys(self._allowed_items, 0)
        self._items: List[_T] = []

    @staticmethod
    @lru_cache(maxsize=None)
    def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
        # Decoding is memoized per (tokenizer, token index) pair.
        return tokenizer.decode(token_index)

    def _placeholder_str(self, modality: ModalityStr,
                         current_count: int) -> Optional[str]:
        """Return the prompt placeholder for the ``current_count``-th item.

        ``None`` means the model takes no placeholder for this modality.
        Raises ``TypeError`` for an unsupported model type or modality.
        """
        # TODO: Let user specify how to insert image tokens into prompt
        # (similar to chat template)
        hf_config = self._model_config.hf_config
        model_type = hf_config.model_type

        if modality == "audio":
            if model_type == "ultravox":
                return "<|reserved_special_token_0|>"
            raise TypeError(f"Unknown model type: {model_type}")

        if modality == "video":
            if model_type == "qwen2_vl":
                return "<|vision_start|><|video_pad|><|vision_end|>"
            raise TypeError(f"Unknown model type: {model_type}")

        if modality != "image":
            raise TypeError(f"Unknown modality: {modality}")

        # Image placeholders that embed the running item number.
        if model_type == "phi3_v":
            # Workaround since this token is not defined in the tokenizer
            return f"<|image_{current_count}|>"
        if model_type == "qwen":
            return f"Picture {current_count}: <img></img>"

        # These models do not use image tokens in the prompt
        if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
                          "pixtral"):
            return None

        if model_type.startswith("llava"):
            return self._cached_token_str(self._tokenizer,
                                          hf_config.image_token_index)

        fixed_image_placeholders = {
            "minicpmv": "(<image>./</image>)",
            "chameleon": "<image>",
            "internvl_chat": "<image>",
            "qwen2_vl": "<|vision_start|><|image_pad|><|vision_end|>",
            "molmo": "",
        }
        if model_type in fixed_image_placeholders:
            return fixed_image_placeholders[model_type]

        raise TypeError(f"Unknown model type: {model_type}")

    @staticmethod
    def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
        """Merge per-item data dicts into one dict keyed by modality.

        List values are concatenated and scalar values appended; a key that
        ends up with exactly one entry is unwrapped to that single entry.
        """
        merged: Dict[str, List[object]] = {}
        for mm_data in items:
            for key, value in mm_data.items():
                bucket = merged.setdefault(key, [])
                if isinstance(value, list):
                    bucket.extend(value)
                else:
                    bucket.append(value)

        # Unpack any single item lists for models that don't expect multiple.
        return {
            key: (values if len(values) != 1 else values[0])
            for key, values in merged.items()
        }

    def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
        """
        Register one multi-modal item for the current prompt.

        Returns the placeholder string for the item (``None`` if the model
        needs none). Raises ``ValueError`` when the per-modality limit
        (1 if unconfigured) would be exceeded.
        """
        limit = self._allowed_items.get(modality, 1)
        count = 1 + self._consumed_items.get(modality, 0)
        if count > limit:
            raise ValueError(
                f"At most {limit} {modality}(s) may be provided in "
                "one request.")

        self._consumed_items[modality] = count
        self._items.append(item)

        return self._placeholder_str(modality, count)

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
    """Tracker whose items are already-parsed multi-modal data dicts."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Merge all tracked items, or return ``None`` if none were added."""
        if not self._items:
            return None
        return self._combine(self._items)

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)
class AsyncMultiModalItemTracker(
        BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
    """Tracker whose items are awaitables resolving to multi-modal data."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await all tracked items concurrently and merge the results.

        Returns ``None`` when no item was added.
        """
        if not self._items:
            return None
        resolved = await asyncio.gather(*self._items)
        return self._combine(resolved)

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)
class BaseMultiModalContentParser(ABC):
    """Turns multi-modal content parts into tracked items while counting the
    placeholder strings those items need in the prompt text."""

    def __init__(self) -> None:
        super().__init__()

        # placeholder string -> number of items that produced it
        self._placeholder_counts: Dict[str, int] = defaultdict(int)

    def _add_placeholder(self, placeholder: Optional[str]):
        # ``None`` and ``""`` both mean "no placeholder text needed".
        if not placeholder:
            return
        self._placeholder_counts[placeholder] += 1

    def mm_placeholder_counts(self) -> Dict[str, int]:
        """Return a plain-dict snapshot of the placeholder counts."""
        return dict(self._placeholder_counts)

    @abstractmethod
    def parse_image(self, image_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError
class MultiModalContentParser(BaseMultiModalContentParser):
    """Synchronous parser: each URL is passed straight to
    ``get_and_parse_image`` / ``get_and_parse_audio`` and the result tracked."""

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()
        self._tracker = tracker

    def parse_image(self, image_url: str) -> None:
        self._add_placeholder(
            self._tracker.add("image", get_and_parse_image(image_url)))

    def parse_audio(self, audio_url: str) -> None:
        self._add_placeholder(
            self._tracker.add("audio", get_and_parse_audio(audio_url)))
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Asynchronous parser: tracks the coroutine returned by
    ``async_get_and_parse_image`` / ``async_get_and_parse_audio``; it is only
    awaited later by ``AsyncMultiModalItemTracker.all_mm_data``."""

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()
        self._tracker = tracker

    def parse_image(self, image_url: str) -> None:
        self._add_placeholder(
            self._tracker.add("image", async_get_and_parse_image(image_url)))

    def parse_audio(self, audio_url: str) -> None:
        self._add_placeholder(
            self._tracker.add("audio", async_get_and_parse_audio(audio_url)))
def load_chat_template(
        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
    """Resolve a chat template given either as a file path or inline text.

    Args:
        chat_template: ``None``, a path to a template file, or the template
            text itself. Backslash escapes in inline text (such as a literal
            ``\\n`` typed on a command line) are interpreted.

    Returns:
        The template text, or ``None`` when ``chat_template`` is ``None``.

    Raises:
        OSError: if ``chat_template`` is a ``Path`` that cannot be opened.
        ValueError: if ``chat_template`` is a string that cannot be opened and
            looks like a file path (contains none of ``{``, ``}`` or a
            newline).
    """
    if chat_template is None:
        return None
    try:
        # Read template files as UTF-8 independent of the platform locale.
        with open(chat_template, "r", encoding="utf-8") as f:
            resolved_chat_template = f.read()
    except OSError as e:
        if isinstance(chat_template, Path):
            raise

        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {e}")
            raise ValueError(msg) from e

        # If opening a file fails, treat the argument as the template itself
        # and interpret its backslash escapes so they behave as intended.
        # The unicode_escape codec reads its input bytes as Latin-1, so the
        # str is first encoded as Latin-1 (anything outside it becomes a
        # backslash escape) to keep non-ASCII characters from being mangled.
        resolved_chat_template = codecs.decode(
            chat_template.encode("latin-1", "backslashreplace"),
            "unicode_escape")

    logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
    return resolved_chat_template
- # TODO: Let user specify how to insert multimodal tokens into prompt
- # (similar to chat template)
- def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
- text_prompt: str) -> str:
- """Combine multimodal prompts for a multimodal language model."""
- # Look through the text prompt to check for missing placeholders
- missing_placeholders: List[str] = []
- for placeholder in placeholder_counts:
- # For any existing placeholder in the text prompt, we leave it as is
- placeholder_counts[placeholder] -= text_prompt.count(placeholder)
- if placeholder_counts[placeholder] < 0:
- raise ValueError(
- f"Found more '{placeholder}' placeholders in input prompt than "
- "actual multimodal data items.")
- missing_placeholders.extend([placeholder] *
- placeholder_counts[placeholder])
- # NOTE: For now we always add missing placeholders at the front of
- # the prompt. This may change to be customizable in the future.
- return "\n".join(missing_placeholders + [text_prompt])
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)


def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
    """Flatten a message's content parts into one conversation message.

    Text and refusal parts are joined with newlines. Image and audio parts go
    to a parser from ``mm_tracker``, and the placeholders they need are added
    to the text via ``_get_full_multimodal_text_prompt``.

    Raises:
        NotImplementedError: for an unrecognized part ``type``.
    """
    text_chunks: List[str] = []
    parser = mm_tracker.create_parser()

    for part in parts:
        kind = part["type"]

        if kind == "image_url":
            image_url = _ImageParser(part)["image_url"]
            if image_url.get("detail", "auto") != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported and "
                    "will be ignored.")
            parser.parse_image(image_url["url"])
        elif kind == "audio_url":
            parser.parse_audio(_AudioParser(part)["audio_url"]["url"])
        elif kind == "text":
            text_chunks.append(_TextParser(part)["text"])
        elif kind == "refusal":
            text_chunks.append(_RefusalParser(part)["refusal"])
        else:
            raise NotImplementedError(f"Unknown part type: {kind}")

    text_prompt = "\n".join(text_chunks)

    placeholder_counts = parser.mm_placeholder_counts()
    if placeholder_counts:
        text_prompt = _get_full_multimodal_text_prompt(placeholder_counts,
                                                       text_prompt)

    return [ConversationMessage(role=role, content=text_prompt)]
# No need to validate using Pydantic again
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
    """Convert one incoming chat message into conversation messages.

    String content becomes a single text part, and missing (``None``)
    content becomes an empty part list, before flattening by
    ``_parse_chat_message_content_parts``. Assistant ``tool_calls``, tool
    ``tool_call_id`` and a string ``name`` are then copied onto each result.
    """
    role = message["role"]
    raw_content = message.get("content")

    if isinstance(raw_content, str):
        parts = [
            ChatCompletionContentPartTextParam(type="text", text=raw_content)
        ]
    elif raw_content is None:
        parts = []
    else:
        parts = raw_content

    converted = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
    )

    # ``name`` is only copied when present and a string.
    name = message.get("name")

    for conv_msg in converted:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)
            if "tool_calls" in assistant_msg:
                conv_msg["tool_calls"] = list(assistant_msg["tool_calls"])
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                conv_msg["tool_call_id"] = tool_msg["tool_call_id"]

        if isinstance(name, str):
            conv_msg["name"] = name

    return converted
- def _postprocess_messages(messages: List[ConversationMessage]) -> None:
- # per the Transformers docs & maintainers, tool call arguments in
- # assistant-role messages with tool_calls need to be dicts not JSON str -
- # this is how tool-use chat templates will expect them moving forwards
- # so, for messages that have tool_calls, parse the string (which we get
- # from openAI format) to dict
- for message in messages:
- if (message["role"] == "assistant" and "tool_calls" in message
- and isinstance(message["tool_calls"], list)):
- for item in message["tool_calls"]:
- item["function"]["arguments"] = json.loads(
- item["function"]["arguments"])
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
    """Parse request messages, loading multi-modal data synchronously.

    Returns:
        The flattened conversation and the merged multi-modal data (``None``
        when the messages carry no multi-modal items).
    """
    tracker = MultiModalItemTracker(model_config, tokenizer)
    conversation: List[ConversationMessage] = [
        converted for message in messages
        for converted in _parse_chat_message_content(message, tracker)
    ]

    _postprocess_messages(conversation)

    return conversation, tracker.all_mm_data()
def parse_chat_messages_futures(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
    """Parse request messages, deferring multi-modal loading.

    Returns:
        The flattened conversation and an awaitable that, once awaited,
        loads and merges the multi-modal data (``None`` if there is none).
    """
    tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
    conversation: List[ConversationMessage] = [
        converted for message in messages
        for converted in _parse_chat_message_content(message, tracker)
    ]

    _postprocess_messages(conversation)

    return conversation, tracker.all_mm_data()
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
    *,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render ``conversation`` with a Hugging Face tokenizer's chat template.

    ``chat_template``, when given, is passed through to the tokenizer in
    place of its own template. Extra keyword arguments are forwarded.

    Raises:
        ValueError: if neither ``chat_template`` nor the tokenizer's
            ``chat_template`` is set.
    """
    has_template = (chat_template is not None
                    or tokenizer.chat_template is not None)
    if not has_template:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one.")

    render = tokenizer.apply_chat_template
    return render(
        conversation=conversation,  # type: ignore[arg-type]
        chat_template=chat_template,
        tokenize=tokenize,
        **kwargs,
    )
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: List[ChatCompletionMessageParam],
    chat_template: Optional[str] = None,
    **kwargs: Any,
) -> List[int]:
    """Encode ``messages`` with a Mistral tokenizer's own chat template.

    A custom ``chat_template`` cannot be applied to this tokenizer; if one is
    supplied it is ignored and a warning is logged. Extra keyword arguments
    are forwarded to ``tokenizer.apply_chat_template``.
    """
    template_overridden = chat_template is not None
    if template_overridden:
        logger.warning(
            "'chat_template' cannot be overridden for mistral tokenizer.")

    encode = tokenizer.apply_chat_template
    return encode(messages=messages, **kwargs)
|