# chat_utils.py

import asyncio
import codecs
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping,
                    Optional, Tuple, Union)

from loguru import logger
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import ChatCompletionContentPartImageParam
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import ChatCompletionContentPartTextParam
from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict, TypeAdapter
from typing_extensions import Required, TypeAlias, TypedDict

from aphrodite.common.config import ModelConfig
from aphrodite.multimodal import MultiModalDataDict
from aphrodite.multimodal.utils import (async_get_and_parse_audio,
                                        async_get_and_parse_image)
from aphrodite.transformers_utils.tokenizer import AnyTokenizer

class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


class CustomChatCompletionContentPartParam(TypedDict, total=False):
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

    type: Required[str]
    """The type of the content part."""


ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    CustomChatCompletionContentPartParam,
]


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""

    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of
    the same role.
    """


ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]

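# A sketch of how these TypedDicts compose: a message with a custom role
# carrying a text part plus an audio part. The role and URL below are
# illustrative, not values used anywhere in this module.
#
#   example_message: ChatCompletionMessageParam = {
#       "role": "narrator",  # non-OpenAI roles are allowed via
#                            # CustomChatCompletionMessageParam
#       "content": [
#           {"type": "text", "text": "Transcribe this clip:"},
#           {"type": "audio_url",
#            "audio_url": {"url": "https://example.com/clip.wav"}},
#       ],
#   }
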
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict):
    role: str
    content: str


class MultiModalItemTracker:
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        self._model_config = model_config
        self._tokenizer = tokenizer
        self._allowed_items = (model_config.multimodal_config.limit_per_prompt
                               if model_config.multimodal_config else {})
        self._consumed_items = {k: 0 for k in self._allowed_items}
        self._futures: List[Awaitable[MultiModalDataDict]] = []

    @staticmethod
    @lru_cache(maxsize=None)
    def _cached_token_str(tokenizer: AnyTokenizer, token_index: int):
        return tokenizer.decode(token_index)

    def add(self, modality: Literal["image", "audio"],
            mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]:
        """
        Adds the multi-modal item to the current prompt and returns the
        placeholder string to use, if any.
        """
        allowed_count = self._allowed_items.get(modality, 1)
        current_count = self._consumed_items.get(modality, 0) + 1
        if current_count > allowed_count:
            raise ValueError(
                f"At most {allowed_count} {modality}(s) may be provided in "
                "one request.")

        self._consumed_items[modality] = current_count
        self._futures.append(mm_future)

        # TODO: Let user specify how to insert image tokens into prompt
        # (similar to chat template)
        model_type = self._model_config.hf_config.model_type
        if modality == "image":
            if model_type == "phi3_v":
                # Workaround since this token is not defined in the tokenizer
                return f"<|image_{current_count}|>"
            if model_type == "minicpmv":
                return "(<image>./</image>)"
            if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
                # These models do not use image tokens in the prompt
                return None
            if model_type.startswith("llava"):
                return MultiModalItemTracker._cached_token_str(
                    self._tokenizer,
                    self._model_config.hf_config.image_token_index)
            if model_type in ("chameleon", "internvl_chat"):
                return "<image>"

            raise TypeError(f"Unknown model type: {model_type}")
        elif modality == "audio":
            if model_type == "ultravox":
                return "<|reserved_special_token_0|>"
            raise TypeError(f"Unknown model type: {model_type}")
        else:
            raise TypeError(f"Unknown modality: {modality}")

    @staticmethod
    async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
        mm_lists: Mapping[str, List[object]] = defaultdict(list)

        # Merge all the multi-modal items
        for single_mm_data in (await asyncio.gather(*futures)):
            for mm_key, mm_item in single_mm_data.items():
                if isinstance(mm_item, list):
                    mm_lists[mm_key].extend(mm_item)
                else:
                    mm_lists[mm_key].append(mm_item)

        # Unpack any single item lists for models that don't expect multiple.
        return {
            mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
            for mm_key, mm_list in mm_lists.items()
        }

    def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]:
        return MultiModalItemTracker._combine(
            self._futures) if self._futures else None

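# Usage sketch for the tracker (model_config and tokenizer are assumed to
# already exist; the URL is illustrative):
#
#   tracker = MultiModalItemTracker(model_config, tokenizer)
#   coro = async_get_and_parse_image("https://example.com/cat.png")
#   placeholder = tracker.add("image", coro)
#   # placeholder is e.g. "<image>" for chameleon/internvl_chat, or None for
#   # models whose prompts carry no image token.
#   mm_future = tracker.all_mm_data()  # awaitable of merged data, or None
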
def load_chat_template(
        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
    if chat_template is None:
        return None
    try:
        with open(chat_template, "r") as f:
            resolved_chat_template = f.read()
    except OSError as e:
        if isinstance(chat_template, Path):
            raise

        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {e}")
            raise ValueError(msg) from e

        # If opening the file fails, treat the argument itself as the
        # template string; decode it so that escape sequences are
        # interpreted correctly.
        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")

    # loguru formats messages with str.format-style braces, not %-style
    logger.info("Using supplied chat template:\n{}", resolved_chat_template)
    return resolved_chat_template

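# Illustrative inputs (the path below is hypothetical):
#
#   load_chat_template(None)                      # -> None
#   load_chat_template("templates/chatml.jinja")  # read template from disk
#   load_chat_template("{{ messages[0].content }}")
#   # The last call fails to open as a file but contains Jinja characters,
#   # so the string itself is decoded and used as the template.
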
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
                                     text_prompt: str) -> str:
    """Combine multimodal prompts for a multimodal language model."""

    # Look through the text prompt to check for missing placeholders
    missing_placeholders = []
    for placeholder in placeholder_counts:

        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt "
                "than actual multimodal data items.")

        missing_placeholders.extend([placeholder] *
                                    placeholder_counts[placeholder])

    # NOTE: For now we always add missing placeholders at the front of
    # the prompt. This may change to be customizable in the future.
    return "\n".join(missing_placeholders + [text_prompt])

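# Worked example: with placeholder_counts == {"<image>": 2} and
# text_prompt == "<image>\nDescribe both images", one placeholder is already
# present in the text, so exactly one missing "<image>" is prepended:
#
#   _get_full_multimodal_text_prompt({"<image>": 2},
#                                    "<image>\nDescribe both images")
#   # -> "<image>\n<image>\nDescribe both images"
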
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)


def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: MultiModalItemTracker,
) -> List[ConversationMessage]:
    texts: List[str] = []

    # multimodal placeholder_string : count
    mm_placeholder_counts: Dict[str, int] = {}

    for part in parts:
        part_type = part["type"]
        if part_type == "text":
            text = _TextParser.validate_python(part)["text"]
            texts.append(text)
        elif part_type == "image_url":
            image_url = _ImageParser.validate_python(part)["image_url"]

            if image_url.get("detail", "auto") != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported and "
                    "will be ignored.")

            image_coro = async_get_and_parse_image(image_url["url"])
            placeholder = mm_tracker.add("image", image_coro)
            if placeholder:
                mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
                    placeholder, 0) + 1
        elif part_type == "audio_url":
            audio_url = _AudioParser.validate_python(part)["audio_url"]

            audio_coro = async_get_and_parse_audio(audio_url["url"])
            placeholder = mm_tracker.add("audio", audio_coro)
            if placeholder:
                mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
                    placeholder, 0) + 1
        else:
            raise NotImplementedError(f"Unknown part type: {part_type}")

    text_prompt = "\n".join(texts)
    if mm_placeholder_counts:
        text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
                                                       text_prompt)

    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content(
        message: ChatCompletionMessageParam,
        mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]:
    role = message["role"]
    content = message.get("content")

    if content is None:
        return []
    if isinstance(content, str):
        return [ConversationMessage(role=role, content=content)]

    return _parse_chat_message_content_parts(
        role,
        content,  # type: ignore
        mm_tracker,
    )


def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage],
           Optional[Awaitable[MultiModalDataDict]]]:
    conversation: List[ConversationMessage] = []
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    for msg in messages:
        sub_messages = _parse_chat_message_content(msg, mm_tracker)
        conversation.extend(sub_messages)

    return conversation, mm_tracker.all_mm_data()

def apply_chat_template(
    tokenizer: AnyTokenizer,
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
    *,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> Union[str, List[int]]:
    if chat_template is None and tokenizer.chat_template is None:
        raise ValueError(
            "As of transformers v4.44, the default chat template is no "
            "longer allowed, so you must provide a chat template if the "
            "tokenizer does not define one.")

    prompt = tokenizer.apply_chat_template(
        conversation=conversation,
        chat_template=chat_template,
        tokenize=tokenize,
        **kwargs,
    )
    return prompt

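# End-to-end sketch tying the helpers together (messages, model_config and
# tokenizer are assumed to exist; the template path is hypothetical; run
# inside an async function so the multimodal future can be awaited):
#
#   conversation, mm_future = parse_chat_messages(
#       messages, model_config, tokenizer)
#   prompt = apply_chat_template(
#       tokenizer, conversation,
#       chat_template=load_chat_template("templates/chatml.jinja"))
#   mm_data = await mm_future if mm_future is not None else None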