chat_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. import codecs
  2. import json
  3. import tempfile
  4. from dataclasses import dataclass
  5. from functools import lru_cache, partial
  6. from pathlib import Path
  7. from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Optional,
  8. Tuple, Union, cast)
  9. import requests
  10. from loguru import logger
  11. # yapf conflicts with isort for this block
  12. # yapf: disable
  13. from openai.types.chat import (ChatCompletionAssistantMessageParam,
  14. ChatCompletionContentPartImageParam)
  15. from openai.types.chat import (
  16. ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
  17. from openai.types.chat import (ChatCompletionContentPartRefusalParam,
  18. ChatCompletionContentPartTextParam)
  19. from openai.types.chat import (
  20. ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
  21. from openai.types.chat import (ChatCompletionMessageToolCallParam,
  22. ChatCompletionToolMessageParam)
  23. # yapf: enable
  24. # pydantic needs the TypedDict from typing_extensions
  25. from pydantic import ConfigDict
  26. from typing_extensions import Required, TypeAlias, TypedDict
  27. from aphrodite.common.config import ModelConfig
  28. from aphrodite.multimodal import MultiModalDataDict
  29. from aphrodite.multimodal.utils import (async_get_and_parse_audio,
  30. async_get_and_parse_image)
  31. from aphrodite.transformers_utils.tokenizer import AnyTokenizer
class AudioURL(TypedDict, total=False):
    """Location of an audio clip referenced by an ``audio_url`` content part."""
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Chat content part that carries an audio clip.

    Defined in this module rather than imported from ``openai.types.chat``;
    both keys are required.
    """
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""
class CustomChatCompletionContentPartParam(TypedDict, total=False):
    """Content part with an arbitrary ``type`` string.

    Lets request validation accept part types beyond the OpenAI ones; whether
    a given type is actually handled is decided later by the parser.
    """
    # extra="allow" tells pydantic to keep keys not declared below instead of
    # rejecting them.
    __pydantic_config__ = ConfigDict(extra="allow") # type: ignore

    type: Required[str]
    """The type of the content part."""
# Every content-part shape accepted inside a chat message's ``content`` list.
# NOTE(review): the permissive custom part is listed last, presumably so that
# pydantic prefers the specific shapes first; keep the member order unless
# that is verified to be irrelevant.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPartParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API.

    ``role`` may be any string. It is the only required key (``total=False``);
    the remaining keys mirror the optional fields of OpenAI message params.
    """
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.
    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
# Any accepted incoming chat message: a standard OpenAI message param or a
# message with a custom role.
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
    """A normalized chat message produced by this module's parsers.

    Built from a parsed request message (content flattened to a single
    string) and passed to ``apply_chat_template``. Only ``role`` is required.
    """
    role: Required[str]
    """The role of the message's author."""

    content: Optional[str]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
@dataclass(frozen=True)
class ChatMessageParseResult:
    """Immutable result of parsing one incoming chat message."""
    # Normalized messages derived from the incoming message.
    messages: List[ConversationMessage]
    # Pending multimodal loads (image/audio) referenced by the message;
    # presumably awaited by the caller once the whole request is parsed.
    mm_futures: List[Awaitable[MultiModalDataDict]]
  82. def load_chat_template(
  83. chat_template: Optional[Union[Path, str]]) -> Optional[str]:
  84. if chat_template is None:
  85. return None
  86. try:
  87. chat_template_str = str(chat_template)
  88. if chat_template_str.startswith(('http')):
  89. response = requests.get(chat_template_str)
  90. temp = tempfile.NamedTemporaryFile(delete=False)
  91. temp.write(response.content)
  92. temp.close()
  93. chat_template = temp.name
  94. with open(chat_template, "r") as f:
  95. resolved_chat_template = f.read()
  96. except OSError as e:
  97. if isinstance(chat_template, Path):
  98. raise
  99. JINJA_CHARS = "{}\n"
  100. if not any(c in chat_template for c in JINJA_CHARS):
  101. msg = (f"The supplied chat template ({chat_template}) "
  102. "looks like a file path, but it failed to be "
  103. f"opened. Reason: {e}")
  104. raise ValueError(msg) from e
  105. # If opening a file fails, set chat template to be args to
  106. # ensure we decode so our escape are interpreted correctly
  107. resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
  108. logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
  109. return resolved_chat_template
  110. @lru_cache(maxsize=None)
  111. def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
  112. modality: Literal["image", "audio"]) -> Optional[str]:
  113. # TODO: Let user specify how to insert image tokens into prompt
  114. # (similar to chat template)
  115. if modality == "image":
  116. model_type = model_config.hf_config.model_type
  117. if model_type == "phi3_v":
  118. # Workaround since this token is not defined in the tokenizer
  119. return "<|image_1|>"
  120. if model_type == "minicpmv":
  121. return "(<image>./</image>)"
  122. if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
  123. # These models do not use image tokens in the prompt
  124. return None
  125. if model_type.startswith("llava"):
  126. return tokenizer.decode(model_config.hf_config.image_token_index)
  127. if model_type in ("chameleon", "internvl_chat"):
  128. return "<image>"
  129. raise TypeError(f"Unknown model type: {model_type}")
  130. elif modality == "audio":
  131. raise TypeError("No audio models are supported yet.")
  132. else:
  133. raise TypeError(f"Unknown modality: {modality}")
  134. # TODO: Let user specify how to insert multimodal tokens into prompt
  135. # (similar to chat template)
  136. def _get_full_multimodal_text_prompt(placeholder_token_str: str,
  137. text_prompt: str) -> str:
  138. """Combine multimodal prompts for a multimodal language model"""
  139. # NOTE: For now we assume all model architectures use the same
  140. # placeholder + text prompt format. This may change in the future.
  141. return f"{placeholder_token_str}\n{text_prompt}"
  142. def _parse_chat_message_content_parts(
  143. role: str,
  144. parts: Iterable[ChatCompletionContentPartParam],
  145. model_config: ModelConfig,
  146. tokenizer: AnyTokenizer,
  147. ) -> ChatMessageParseResult:
  148. texts: List[str] = []
  149. mm_futures: List[Awaitable[MultiModalDataDict]] = []
  150. modality: Literal["image", "audio"] = "image"
  151. for part in parts:
  152. part_type = part["type"]
  153. if part_type == "text":
  154. text = partial(cast(ChatCompletionContentPartTextParam, part)
  155. )["text"]
  156. texts.append(text)
  157. elif part_type == "image_url":
  158. modality = "image"
  159. if len(mm_futures) > 0:
  160. raise NotImplementedError(
  161. "Multiple multimodal inputs is currently not supported.")
  162. image_url = cast(ChatCompletionContentPartImageParam,
  163. part)["image_url"]
  164. if image_url.get("detail", "auto") != "auto":
  165. logger.warning(
  166. "'image_url.detail' is currently not supported and "
  167. "will be ignored.")
  168. image_future = async_get_and_parse_image(image_url["url"])
  169. mm_futures.append(image_future)
  170. elif part_type == "audio_url":
  171. modality = "audio"
  172. if len(mm_futures) > 0:
  173. raise NotImplementedError(
  174. "Multiple multimodal inputs is currently not supported.")
  175. audio_url = cast(ChatCompletionContentPartAudioParam,
  176. part)["audio_url"]
  177. audio_future = async_get_and_parse_audio(audio_url["url"])
  178. mm_futures.append(audio_future)
  179. else:
  180. raise NotImplementedError(f"Unknown part type: {part_type}")
  181. text_prompt = "\n".join(texts)
  182. if mm_futures:
  183. placeholder_token_str = _mm_token_str(model_config, tokenizer,
  184. modality)
  185. if placeholder_token_str is not None:
  186. if placeholder_token_str in text_prompt:
  187. logger.warning(
  188. "Detected multi-modal token string in the text prompt. "
  189. "Skipping prompt formatting.")
  190. else:
  191. text_prompt = _get_full_multimodal_text_prompt(
  192. placeholder_token_str=placeholder_token_str,
  193. text_prompt=text_prompt,
  194. )
  195. messages = [ConversationMessage(role=role, content=text_prompt)]
  196. return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
  197. # No need to validate using Pydantic again
  198. _AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
  199. _ToolParser = partial(cast, ChatCompletionToolMessageParam)
  200. def _parse_chat_message_content(
  201. message: ChatCompletionMessageParam,
  202. model_config: ModelConfig,
  203. tokenizer: AnyTokenizer,
  204. ) -> ChatMessageParseResult:
  205. role = message["role"]
  206. content = message.get("content")
  207. if content is None:
  208. content = []
  209. elif isinstance(content, str):
  210. content = [
  211. ChatCompletionContentPartTextParam(type="text", text=content)
  212. ]
  213. result = _parse_chat_message_content_parts(role, content,
  214. model_config, tokenizer)
  215. for result_msg in result:
  216. if role == 'assistant':
  217. parsed_msg = _AssistantParser(message)
  218. if "tool_calls" in parsed_msg:
  219. result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
  220. elif role == "tool":
  221. parsed_msg = _ToolParser(message)
  222. if "tool_call_id" in parsed_msg:
  223. result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
  224. if "name" in message and isinstance(message["name"], str):
  225. result_msg["name"] = message["name"]
  226. return result
  227. def parse_chat_messages(
  228. messages: List[ChatCompletionMessageParam],
  229. model_config: ModelConfig,
  230. tokenizer: AnyTokenizer,
  231. ) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
  232. conversation: List[ConversationMessage] = []
  233. mm_futures: List[Awaitable[MultiModalDataDict]] = []
  234. for msg in messages:
  235. parse_result = _parse_chat_message_content(msg, model_config,
  236. tokenizer)
  237. conversation.extend(parse_result.messages)
  238. mm_futures.extend(parse_result.mm_futures)
  239. return conversation, mm_futures
  240. def apply_chat_template(
  241. tokenizer: AnyTokenizer,
  242. conversation: List[ConversationMessage],
  243. chat_template: Optional[str],
  244. *,
  245. tokenize: bool = False, # Different from HF's default
  246. **kwargs: Any,
  247. ) -> Union[str, List[int]]:
  248. if chat_template is None and tokenizer.chat_template is None:
  249. raise ValueError(
  250. "As of transformers v4.44, default chat template is no longer "
  251. "allowed, so you must provide a chat template if the tokenizer "
  252. "does not define one.")
  253. # per the Transformers docs & maintainers, tool call arguments in
  254. # assistant-role messages with tool_calls need to be dicts not JSON str -
  255. # this is how tool-use chat templates will expect them moving forwards
  256. # so, for messages that have tool_calls, parse the string (which we get
  257. # from openAI format) to dict
  258. for message in conversation:
  259. if (message["role"] == "assistant" and "tool_calls" in message
  260. and isinstance(message["tool_calls"], list)):
  261. for i in range(len(message["tool_calls"])):
  262. args: str = message["tool_calls"][i]["function"]["arguments"]
  263. parsed_args: Dict = json.loads(args)
  264. message["tool_calls"][i]["function"]["arguments"] = parsed_args
  265. prompt = tokenizer.apply_chat_template(
  266. conversation=conversation,
  267. chat_template=chat_template,
  268. tokenize=tokenize,
  269. **kwargs,
  270. )
  271. return prompt