1
0

chat_utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. import asyncio
  2. import codecs
  3. import json
  4. from abc import ABC, abstractmethod
  5. from collections import defaultdict
  6. from functools import lru_cache, partial
  7. from pathlib import Path
  8. from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
  9. Mapping, Optional, Tuple, TypeVar, Union, cast)
  10. from loguru import logger
  11. # yapf conflicts with isort for this block
  12. # yapf: disable
  13. from openai.types.chat import (ChatCompletionAssistantMessageParam,
  14. ChatCompletionContentPartImageParam)
  15. from openai.types.chat import (
  16. ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
  17. from openai.types.chat import (ChatCompletionContentPartRefusalParam,
  18. ChatCompletionContentPartTextParam)
  19. from openai.types.chat import (
  20. ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
  21. from openai.types.chat import (ChatCompletionMessageToolCallParam,
  22. ChatCompletionToolMessageParam)
  23. # yapf: enable
  24. # pydantic needs the TypedDict from typing_extensions
  25. from pydantic import ConfigDict
  26. from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
  27. from typing_extensions import Required, TypeAlias, TypedDict
  28. from aphrodite.common.config import ModelConfig
  29. from aphrodite.multimodal import MultiModalDataDict
  30. from aphrodite.multimodal.utils import (async_get_and_parse_audio,
  31. async_get_and_parse_image,
  32. get_and_parse_audio,
  33. get_and_parse_image)
  34. from aphrodite.transformers_utils.tokenizer import (AnyTokenizer,
  35. MistralTokenizer)
  36. class AudioURL(TypedDict, total=False):
  37. url: Required[str]
  38. """
  39. Either a URL of the audio or a data URL with base64 encoded audio data.
  40. """
  41. class ChatCompletionContentPartAudioParam(TypedDict, total=False):
  42. audio_url: Required[AudioURL]
  43. type: Required[Literal["audio_url"]]
  44. """The type of the content part."""
  45. class CustomChatCompletionContentPartParam(TypedDict, total=False):
  46. __pydantic_config__ = ConfigDict(extra="allow") # type: ignore
  47. type: Required[str]
  48. """The type of the content part."""
# Every shape a single entry of a message's ``content`` list may take:
# OpenAI's own content parts, the audio part defined in this module, OpenAI's
# refusal part, and the free-form custom part.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPartParam]
  53. class CustomChatCompletionMessageParam(TypedDict, total=False):
  54. """Enables custom roles in the Chat Completion API."""
  55. role: Required[str]
  56. """The role of the message's author."""
  57. content: Union[str, List[ChatCompletionContentPartParam]]
  58. """The contents of the message."""
  59. name: str
  60. """An optional name for the participant.
  61. Provides the model information to differentiate between participants of the
  62. same role.
  63. """
  64. tool_call_id: Optional[str]
  65. """Tool call that this message is responding to."""
  66. tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
  67. """The tool calls generated by the model, such as function calls."""
# Any chat message accepted by the parsers in this module: a standard OpenAI
# message, or one with a custom role.
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]
  70. # TODO: Make fields ReadOnly once mypy supports it
  71. class ConversationMessage(TypedDict, total=False):
  72. role: Required[str]
  73. """The role of the message's author."""
  74. content: Optional[str]
  75. """The contents of the message"""
  76. tool_call_id: Optional[str]
  77. """Tool call that this message is responding to."""
  78. name: Optional[str]
  79. """The name of the function to call"""
  80. tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
  81. """The tool calls generated by the model, such as function calls."""
# The modalities a multi-modal item may be registered under.
ModalityStr = Literal["image", "audio", "video"]
# Type of the items stored by a multi-modal item tracker.
_T = TypeVar("_T")
  84. class BaseMultiModalItemTracker(ABC, Generic[_T]):
  85. """
  86. Tracks multi-modal items in a given request and ensures that the number
  87. of multi-modal items in a given request does not exceed the configured
  88. maximum per prompt.
  89. """
  90. def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
  91. super().__init__()
  92. self._model_config = model_config
  93. self._tokenizer = tokenizer
  94. self._allowed_items = (model_config.multimodal_config.limit_per_prompt
  95. if model_config.multimodal_config else {})
  96. self._consumed_items = {k: 0 for k in self._allowed_items}
  97. self._items: List[_T] = []
  98. @staticmethod
  99. @lru_cache(maxsize=None)
  100. def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
  101. return tokenizer.decode(token_index)
  102. def _placeholder_str(self, modality: ModalityStr,
  103. current_count: int) -> Optional[str]:
  104. # TODO: Let user specify how to insert image tokens into prompt
  105. # (similar to chat template)
  106. hf_config = self._model_config.hf_config
  107. model_type = hf_config.model_type
  108. if modality == "image":
  109. if model_type == "phi3_v":
  110. # Workaround since this token is not defined in the tokenizer
  111. return f"<|image_{current_count}|>"
  112. if model_type == "minicpmv":
  113. return "(<image>./</image>)"
  114. if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
  115. "pixtral"):
  116. # These models do not use image tokens in the prompt
  117. return None
  118. if model_type == "qwen":
  119. return f"Picture {current_count}: <img></img>"
  120. if model_type.startswith("llava"):
  121. return self._cached_token_str(self._tokenizer,
  122. hf_config.image_token_index)
  123. if model_type in ("chameleon", "internvl_chat"):
  124. return "<image>"
  125. if model_type == "qwen2_vl":
  126. return "<|vision_start|><|image_pad|><|vision_end|>"
  127. if model_type == "molmo":
  128. return ""
  129. raise TypeError(f"Unknown model type: {model_type}")
  130. elif modality == "audio":
  131. if model_type == "ultravox":
  132. return "<|reserved_special_token_0|>"
  133. raise TypeError(f"Unknown model type: {model_type}")
  134. elif modality == "video":
  135. if model_type == "qwen2_vl":
  136. return "<|vision_start|><|video_pad|><|vision_end|>"
  137. raise TypeError(f"Unknown model type: {model_type}")
  138. else:
  139. raise TypeError(f"Unknown modality: {modality}")
  140. @staticmethod
  141. def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
  142. mm_lists: Mapping[str, List[object]] = defaultdict(list)
  143. # Merge all the multi-modal items
  144. for single_mm_data in items:
  145. for mm_key, mm_item in single_mm_data.items():
  146. if isinstance(mm_item, list):
  147. mm_lists[mm_key].extend(mm_item)
  148. else:
  149. mm_lists[mm_key].append(mm_item)
  150. # Unpack any single item lists for models that don't expect multiple.
  151. return {
  152. mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
  153. for mm_key, mm_list in mm_lists.items()
  154. }
  155. def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
  156. """
  157. Add a multi-modal item to the current prompt and returns the
  158. placeholder string to use, if any.
  159. """
  160. allowed_count = self._allowed_items.get(modality, 1)
  161. current_count = self._consumed_items.get(modality, 0) + 1
  162. if current_count > allowed_count:
  163. raise ValueError(
  164. f"At most {allowed_count} {modality}(s) may be provided in "
  165. "one request.")
  166. self._consumed_items[modality] = current_count
  167. self._items.append(item)
  168. return self._placeholder_str(modality, current_count)
  169. @abstractmethod
  170. def create_parser(self) -> "BaseMultiModalContentParser":
  171. raise NotImplementedError
  172. class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
  173. def all_mm_data(self) -> Optional[MultiModalDataDict]:
  174. return self._combine(self._items) if self._items else None
  175. def create_parser(self) -> "BaseMultiModalContentParser":
  176. return MultiModalContentParser(self)
  177. class AsyncMultiModalItemTracker(
  178. BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
  179. async def all_mm_data(self) -> Optional[MultiModalDataDict]:
  180. if self._items:
  181. items = await asyncio.gather(*self._items)
  182. return self._combine(items)
  183. return None
  184. def create_parser(self) -> "BaseMultiModalContentParser":
  185. return AsyncMultiModalContentParser(self)
  186. class BaseMultiModalContentParser(ABC):
  187. def __init__(self) -> None:
  188. super().__init__()
  189. # multimodal placeholder_string : count
  190. self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
  191. def _add_placeholder(self, placeholder: Optional[str]):
  192. if placeholder:
  193. self._placeholder_counts[placeholder] += 1
  194. def mm_placeholder_counts(self) -> Dict[str, int]:
  195. return dict(self._placeholder_counts)
  196. @abstractmethod
  197. def parse_image(self, image_url: str) -> None:
  198. raise NotImplementedError
  199. @abstractmethod
  200. def parse_audio(self, audio_url: str) -> None:
  201. raise NotImplementedError
  202. class MultiModalContentParser(BaseMultiModalContentParser):
  203. def __init__(self, tracker: MultiModalItemTracker) -> None:
  204. super().__init__()
  205. self._tracker = tracker
  206. def parse_image(self, image_url: str) -> None:
  207. image = get_and_parse_image(image_url)
  208. placeholder = self._tracker.add("image", image)
  209. self._add_placeholder(placeholder)
  210. def parse_audio(self, audio_url: str) -> None:
  211. audio = get_and_parse_audio(audio_url)
  212. placeholder = self._tracker.add("audio", audio)
  213. self._add_placeholder(placeholder)
  214. class AsyncMultiModalContentParser(BaseMultiModalContentParser):
  215. def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
  216. super().__init__()
  217. self._tracker = tracker
  218. def parse_image(self, image_url: str) -> None:
  219. image_coro = async_get_and_parse_image(image_url)
  220. placeholder = self._tracker.add("image", image_coro)
  221. self._add_placeholder(placeholder)
  222. def parse_audio(self, audio_url: str) -> None:
  223. audio_coro = async_get_and_parse_audio(audio_url)
  224. placeholder = self._tracker.add("audio", audio_coro)
  225. self._add_placeholder(placeholder)
  226. def load_chat_template(
  227. chat_template: Optional[Union[Path, str]]) -> Optional[str]:
  228. if chat_template is None:
  229. return None
  230. try:
  231. with open(chat_template, "r") as f:
  232. resolved_chat_template = f.read()
  233. except OSError as e:
  234. if isinstance(chat_template, Path):
  235. raise
  236. JINJA_CHARS = "{}\n"
  237. if not any(c in chat_template for c in JINJA_CHARS):
  238. msg = (f"The supplied chat template ({chat_template}) "
  239. f"looks like a file path, but it failed to be "
  240. f"opened. Reason: {e}")
  241. raise ValueError(msg) from e
  242. # If opening a file fails, set chat template to be args to
  243. # ensure we decode so our escape are interpreted correctly
  244. resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
  245. logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
  246. return resolved_chat_template
  247. # TODO: Let user specify how to insert multimodal tokens into prompt
  248. # (similar to chat template)
  249. def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
  250. text_prompt: str) -> str:
  251. """Combine multimodal prompts for a multimodal language model."""
  252. # Look through the text prompt to check for missing placeholders
  253. missing_placeholders: List[str] = []
  254. for placeholder in placeholder_counts:
  255. # For any existing placeholder in the text prompt, we leave it as is
  256. placeholder_counts[placeholder] -= text_prompt.count(placeholder)
  257. if placeholder_counts[placeholder] < 0:
  258. raise ValueError(
  259. f"Found more '{placeholder}' placeholders in input prompt than "
  260. "actual multimodal data items.")
  261. missing_placeholders.extend([placeholder] *
  262. placeholder_counts[placeholder])
  263. # NOTE: For now we always add missing placeholders at the front of
  264. # the prompt. This may change to be customizable in the future.
  265. return "\n".join(missing_placeholders + [text_prompt])
  266. # No need to validate using Pydantic again
  267. _TextParser = partial(cast, ChatCompletionContentPartTextParam)
  268. _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
  269. _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
  270. _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
  271. def _parse_chat_message_content_parts(
  272. role: str,
  273. parts: Iterable[ChatCompletionContentPartParam],
  274. mm_tracker: BaseMultiModalItemTracker,
  275. ) -> List[ConversationMessage]:
  276. texts: List[str] = []
  277. mm_parser = mm_tracker.create_parser()
  278. for part in parts:
  279. part_type = part["type"]
  280. if part_type == "text":
  281. text = _TextParser(part)["text"]
  282. texts.append(text)
  283. elif part_type == "image_url":
  284. image_url = _ImageParser(part)["image_url"]
  285. if image_url.get("detail", "auto") != "auto":
  286. logger.warning(
  287. "'image_url.detail' is currently not supported and "
  288. "will be ignored.")
  289. mm_parser.parse_image(image_url["url"])
  290. elif part_type == "audio_url":
  291. audio_url = _AudioParser(part)["audio_url"]
  292. mm_parser.parse_audio(audio_url["url"])
  293. elif part_type == "refusal":
  294. text = _RefusalParser(part)["refusal"]
  295. texts.append(text)
  296. else:
  297. raise NotImplementedError(f"Unknown part type: {part_type}")
  298. text_prompt = "\n".join(texts)
  299. mm_placeholder_counts = mm_parser.mm_placeholder_counts()
  300. if mm_placeholder_counts:
  301. text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
  302. text_prompt)
  303. return [ConversationMessage(role=role, content=text_prompt)]
  304. # No need to validate using Pydantic again
  305. _AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
  306. _ToolParser = partial(cast, ChatCompletionToolMessageParam)
  307. def _parse_chat_message_content(
  308. message: ChatCompletionMessageParam,
  309. mm_tracker: BaseMultiModalItemTracker,
  310. ) -> List[ConversationMessage]:
  311. role = message["role"]
  312. content = message.get("content")
  313. if content is None:
  314. content = []
  315. elif isinstance(content, str):
  316. content = [
  317. ChatCompletionContentPartTextParam(type="text", text=content)
  318. ]
  319. result = _parse_chat_message_content_parts(
  320. role,
  321. content, # type: ignore
  322. mm_tracker,
  323. )
  324. for result_msg in result:
  325. if role == 'assistant':
  326. parsed_msg = _AssistantParser(message)
  327. if "tool_calls" in parsed_msg:
  328. result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
  329. elif role == "tool":
  330. parsed_msg = _ToolParser(message)
  331. if "tool_call_id" in parsed_msg:
  332. result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
  333. if "name" in message and isinstance(message["name"], str):
  334. result_msg["name"] = message["name"]
  335. return result
  336. def _postprocess_messages(messages: List[ConversationMessage]) -> None:
  337. # per the Transformers docs & maintainers, tool call arguments in
  338. # assistant-role messages with tool_calls need to be dicts not JSON str -
  339. # this is how tool-use chat templates will expect them moving forwards
  340. # so, for messages that have tool_calls, parse the string (which we get
  341. # from openAI format) to dict
  342. for message in messages:
  343. if (message["role"] == "assistant" and "tool_calls" in message
  344. and isinstance(message["tool_calls"], list)):
  345. for item in message["tool_calls"]:
  346. item["function"]["arguments"] = json.loads(
  347. item["function"]["arguments"])
  348. def parse_chat_messages(
  349. messages: List[ChatCompletionMessageParam],
  350. model_config: ModelConfig,
  351. tokenizer: AnyTokenizer,
  352. ) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
  353. conversation: List[ConversationMessage] = []
  354. mm_tracker = MultiModalItemTracker(model_config, tokenizer)
  355. for msg in messages:
  356. sub_messages = _parse_chat_message_content(msg, mm_tracker)
  357. conversation.extend(sub_messages)
  358. _postprocess_messages(conversation)
  359. return conversation, mm_tracker.all_mm_data()
  360. def parse_chat_messages_futures(
  361. messages: List[ChatCompletionMessageParam],
  362. model_config: ModelConfig,
  363. tokenizer: AnyTokenizer,
  364. ) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
  365. conversation: List[ConversationMessage] = []
  366. mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
  367. for msg in messages:
  368. sub_messages = _parse_chat_message_content(msg, mm_tracker)
  369. conversation.extend(sub_messages)
  370. _postprocess_messages(conversation)
  371. return conversation, mm_tracker.all_mm_data()
  372. def apply_hf_chat_template(
  373. tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
  374. conversation: List[ConversationMessage],
  375. chat_template: Optional[str],
  376. *,
  377. tokenize: bool = False, # Different from HF's default
  378. **kwargs: Any,
  379. ) -> str:
  380. if chat_template is None and tokenizer.chat_template is None:
  381. raise ValueError(
  382. "As of transformers v4.44, default chat template is no longer "
  383. "allowed, so you must provide a chat template if the tokenizer "
  384. "does not define one.")
  385. return tokenizer.apply_chat_template(
  386. conversation=conversation, # type: ignore[arg-type]
  387. chat_template=chat_template,
  388. tokenize=tokenize,
  389. **kwargs,
  390. )
  391. def apply_mistral_chat_template(
  392. tokenizer: MistralTokenizer,
  393. messages: List[ChatCompletionMessageParam],
  394. chat_template: Optional[str] = None,
  395. **kwargs: Any,
  396. ) -> List[int]:
  397. if chat_template is not None:
  398. logger.warning(
  399. "'chat_template' cannot be overridden for mistral tokenizer.")
  400. return tokenizer.apply_chat_template(
  401. messages=messages,
  402. **kwargs,
  403. )