chat_utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. import codecs
  2. import tempfile
  3. from dataclasses import dataclass
  4. from functools import lru_cache
  5. from pathlib import Path
  6. from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
  7. Union, cast)
  8. import requests
  9. from loguru import logger
  10. # yapf conflicts with isort for this block
  11. # yapf: disable
  12. from openai.types.chat import ChatCompletionContentPartImageParam
  13. from openai.types.chat import (
  14. ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
  15. from openai.types.chat import ChatCompletionContentPartTextParam
  16. from openai.types.chat import (
  17. ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
  18. # yapf: enable
  19. # pydantic needs the TypedDict from typing_extensions
  20. from pydantic import ConfigDict
  21. from transformers import PreTrainedTokenizer
  22. from typing_extensions import Required, TypedDict
  23. from aphrodite.common.config import ModelConfig
  24. from aphrodite.multimodal import MultiModalDataDict
  25. from aphrodite.multimodal.utils import (async_get_and_parse_audio,
  26. async_get_and_parse_image)
  27. from aphrodite.transformers_utils.tokenizer import AnyTokenizer
  28. class AudioURL(TypedDict, total=False):
  29. url: Required[str]
  30. """
  31. Either a URL of the audio or a data URL with base64 encoded audio data.
  32. """
  class ChatCompletionContentPartAudioParam(TypedDict, total=False):
      """An ``audio_url`` content part of a chat message."""

      # Where to fetch the audio from.
      audio_url: Required[AudioURL]
      # The type of the content part; always "audio_url".
      type: Required[Literal["audio_url"]]
  class CustomChatCompletionContentPartParam(TypedDict, total=False):
      """Content part of arbitrary ``type`` that may carry extra keys."""

      # Ask pydantic to accept keys beyond the ones declared here.
      __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

      # The type of the content part.
      type: Required[str]
  # A single content part of a chat message: one of OpenAI's standard parts,
  # an audio part, or a custom part with an arbitrary ``type``.
  ChatCompletionContentPartParam = Union[
      OpenAIChatCompletionContentPartParam,
      ChatCompletionContentPartAudioParam,
      CustomChatCompletionContentPartParam,
  ]
  class CustomChatCompletionMessageParam(TypedDict, total=False):
      """Enables custom roles in the Chat Completion API."""

      # The role of the message's author.
      role: Required[str]
      # The contents of the message: plain text or a list of content parts.
      content: Union[str, List[ChatCompletionContentPartParam]]
      # An optional name for the participant. Provides the model information
      # to differentiate between participants of the same role.
      name: str
  # A chat message: either a standard OpenAI message or one with a custom role.
  ChatCompletionMessageParam = Union[
      OpenAIChatCompletionMessageParam,
      CustomChatCompletionMessageParam,
  ]
  # TODO: Make fields ReadOnly once mypy supports it
  class ConversationMessage(TypedDict):
      """A chat turn whose content has been reduced to a single string."""

      # The role of the message's author.
      role: str
      # The text of the turn.
      content: str
  @dataclass(frozen=True)
  class ChatMessageParseResult:
      """Outcome of parsing a single chat message.

      ``messages`` holds the resulting conversation turns (empty when the
      message had no content); ``mm_futures`` holds the pending loads of any
      multimodal inputs the message referenced.
      """

      messages: List[ConversationMessage]
      mm_futures: List[Awaitable[MultiModalDataDict]]
  def load_chat_template(
          chat_template: Optional[Union[Path, str]]) -> Optional[str]:
      """Resolve ``chat_template`` to the text of a Jinja chat template.

      The argument is interpreted, in order, as:

      * ``None``: nothing to load, ``None`` is returned;
      * an ``http://`` or ``https://`` URL: the response body is downloaded
        and decoded as UTF-8;
      * a path to a local file: the file is read as UTF-8;
      * otherwise (a ``str`` that could not be opened but contains ``{``,
        ``}`` or a newline): the template text itself, with backslash escape
        sequences interpreted.

      Args:
          chat_template: A URL, a file path, or the template text.

      Returns:
          The resolved template text, or ``None`` if ``chat_template`` is
          ``None``.

      Raises:
          OSError: If ``chat_template`` is a ``Path`` that cannot be read.
          ValueError: If ``chat_template`` is a ``str`` containing none of
              ``{``, ``}`` or a newline (so it looks like a path or URL) and
              it could not be read or downloaded.
      """
      if chat_template is None:
          return None

      # Seconds to wait on the remote server; without a timeout
      # ``requests.get`` can block forever.
      download_timeout = 30

      try:
          chat_template_str = str(chat_template)
          # Match the scheme exactly so local files whose names merely start
          # with "http" are still read from disk.
          if chat_template_str.startswith(("http://", "https://")):
              response = requests.get(chat_template_str,
                                      timeout=download_timeout)
              # Do not use an HTTP error page as the template.
              response.raise_for_status()
              # Decode in memory: no temporary file to leak, and no
              # dependence on the locale's default encoding.
              resolved_chat_template = response.content.decode("utf-8")
          else:
              with open(chat_template, "r", encoding="utf-8") as f:
                  resolved_chat_template = f.read()
      except OSError as e:
          # requests' exceptions derive from IOError (an alias of OSError),
          # so failed downloads end up here as well.
          if isinstance(chat_template, Path):
              raise

          JINJA_CHARS = "{}\n"
          if not any(c in chat_template for c in JINJA_CHARS):
              msg = (f"The supplied chat template ({chat_template}) "
                     "looks like a file path, but it failed to be "
                     f"opened. Reason: {e}")
              raise ValueError(msg) from e

          # Opening failed, so the argument is the template itself; decode
          # it so its escape sequences are interpreted. Encoding to Latin-1
          # with backslashreplace first keeps non-Latin-1 characters intact
          # (as escapes that decode back to themselves) instead of turning
          # them into mojibake.
          resolved_chat_template = codecs.decode(
              chat_template.encode("latin-1", "backslashreplace"),
              "unicode_escape")

      logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
      return resolved_chat_template
  @lru_cache(maxsize=None)
  def _mm_token_str(
          model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
          modality: Literal["image", "audio", "video"]) -> Optional[str]:
      """Return the placeholder string for a multimodal input of ``modality``.

      Args:
          model_config: Config of the served model. ``hf_config.model_type``
              is read, and ``hf_config.image_token_index`` as well for
              LLaVA-family models.
          tokenizer: Decodes the image token id for LLaVA-family models.
          modality: The kind of multimodal input the placeholder is for.

      Returns:
          The placeholder string, or ``None`` for image models that take no
          image token in the prompt.

      Raises:
          TypeError: If the model type is unknown for ``modality``, or the
              modality itself is unknown.

      Results are memoized per argument tuple by ``lru_cache``, so all
      arguments must be hashable.
      """
      # TODO: Let user specify how to insert image tokens into prompt
      # (similar to chat template)
      model_type = model_config.hf_config.model_type

      if modality == "image":
          if model_type == "phi3_v":
              # Workaround since this token is not defined in the tokenizer
              return "<|image_1|>"
          if model_type == "minicpmv":
              return "(<image>./</image>)"
          if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
              # These models do not use image tokens in the prompt
              return None
          if model_type.startswith("llava"):
              return tokenizer.decode(model_config.hf_config.image_token_index)
          if model_type in ("chameleon", "internvl_chat"):
              return "<image>"
          raise TypeError(f"Unknown model type: {model_type}")

      if modality == "audio":
          if model_type == "ultravox":
              return "<|reserved_special_token_0|>"
          raise TypeError(f"Unknown model type: {model_type}")

      if modality == "video":
          if model_type == "qwen2_vl":
              return "<|vision_start|><|video_pad|><|vision_end|>"
          raise TypeError(f"Unknown model type: {model_type}")

      raise TypeError(f"Unknown modality: {modality}")
  123. # TODO: Let user specify how to insert multimodal tokens into prompt
  124. # (similar to chat template)
  125. def _get_full_multimodal_text_prompt(placeholder_token_str: str,
  126. text_prompt: str) -> str:
  127. """Combine multimodal prompts for a multimodal language model"""
  128. # NOTE: For now we assume all model architectures use the same
  129. # placeholder + text prompt format. This may change in the future.
  130. return f"{placeholder_token_str}\n{text_prompt}"
  def _parse_chat_message_content_parts(
      role: str,
      parts: Iterable[ChatCompletionContentPartParam],
      model_config: ModelConfig,
      tokenizer: PreTrainedTokenizer,
  ) -> ChatMessageParseResult:
      """Collapse a list of content parts into a single conversation turn.

      Text parts are joined with newlines. At most one ``image_url`` or
      ``audio_url`` part is accepted. An awaitable that loads it is returned
      in ``mm_futures``, and the model's placeholder token (if any) is
      prepended to the text unless the text already contains it.

      Raises:
          NotImplementedError: On a second multimodal part or on a part of
              unknown type.
      """
      text_chunks: List[str] = []
      pending: List[Awaitable[MultiModalDataDict]] = []
      modality: Literal["image", "audio", "video"] = "image"

      for part in parts:
          kind = part["type"]

          if kind == "text":
              text_chunks.append(
                  cast(ChatCompletionContentPartTextParam, part)["text"])
              continue

          if kind not in ("image_url", "audio_url"):
              raise NotImplementedError(f"Unknown part type: {kind}")

          # Only a single multimodal input per message is supported.
          if pending:
              raise NotImplementedError(
                  "Multiple multimodal inputs is currently not supported.")

          if kind == "image_url":
              modality = "image"
              image_url = cast(ChatCompletionContentPartImageParam,
                               part)["image_url"]
              if image_url.get("detail", "auto") != "auto":
                  logger.warning(
                      "'image_url.detail' is currently not supported and "
                      "will be ignored.")
              pending.append(async_get_and_parse_image(image_url["url"]))
          else:
              modality = "audio"
              audio_url = cast(ChatCompletionContentPartAudioParam,
                               part)["audio_url"]
              pending.append(async_get_and_parse_audio(audio_url["url"]))

      text_prompt = "\n".join(text_chunks)

      # The placeholder is only looked up when there is media to refer to.
      placeholder = (_mm_token_str(model_config, tokenizer, modality)
                     if pending else None)
      if placeholder is not None:
          if placeholder in text_prompt:
              logger.warning(
                  "Detected multi-modal token string in the text prompt. "
                  "Skipping prompt formatting.")
          else:
              text_prompt = _get_full_multimodal_text_prompt(
                  placeholder_token_str=placeholder,
                  text_prompt=text_prompt,
              )

      turn = ConversationMessage(role=role, content=text_prompt)
      return ChatMessageParseResult(messages=[turn], mm_futures=pending)
  def _parse_chat_message_content(
      message: ChatCompletionMessageParam,
      model_config: ModelConfig,
      tokenizer: PreTrainedTokenizer,
  ) -> ChatMessageParseResult:
      """Convert one chat message into conversation turns plus media loads.

      A message with missing or ``None`` content yields no turns. String
      content is used as-is, and a list of content parts is handed to
      ``_parse_chat_message_content_parts``.
      """
      role = message["role"]
      content = message.get("content")

      if isinstance(content, str):
          return ChatMessageParseResult(
              messages=[ConversationMessage(role=role, content=content)],
              mm_futures=[],
          )
      if content is not None:
          return _parse_chat_message_content_parts(role, content,
                                                   model_config, tokenizer)
      return ChatMessageParseResult(messages=[], mm_futures=[])
  def parse_chat_messages(
      messages: List[ChatCompletionMessageParam],
      model_config: ModelConfig,
      tokenizer: PreTrainedTokenizer,
  ) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
      """Parse every message of a chat request, in order.

      Returns:
          A ``(conversation, mm_futures)`` pair. ``conversation`` holds the
          turns of all messages concatenated, and ``mm_futures`` holds the
          multimodal loads they reference, concatenated.
      """
      results = [
          _parse_chat_message_content(message, model_config, tokenizer)
          for message in messages
      ]
      conversation: List[ConversationMessage] = [
          turn for result in results for turn in result.messages
      ]
      mm_futures: List[Awaitable[MultiModalDataDict]] = [
          future for result in results for future in result.mm_futures
      ]
      return conversation, mm_futures
  def apply_chat_template(
      tokenizer: AnyTokenizer,
      conversation: List[ConversationMessage],
      chat_template: Optional[str],
      *,
      tokenize: bool = False,  # Different from HF's default
      **kwargs: Any,
  ) -> Union[str, List[int]]:
      """Render ``conversation`` through ``tokenizer.apply_chat_template``.

      ``chat_template``, ``tokenize`` and any extra keyword arguments are
      forwarded unchanged, and the tokenizer's result is returned.

      Raises:
          ValueError: If ``chat_template`` is ``None`` and the tokenizer
              defines no ``chat_template`` of its own.
      """
      template_available = (chat_template is not None
                            or tokenizer.chat_template is not None)
      if not template_available:
          raise ValueError(
              "As of transformers v4.44, default chat template is no longer "
              "allowed, so you must provide a chat template if the tokenizer "
              "does not define one.")

      return tokenizer.apply_chat_template(
          conversation=conversation,
          chat_template=chat_template,
          tokenize=tokenize,
          **kwargs,
      )