chat_utils.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. import codecs
  2. import tempfile
  3. from dataclasses import dataclass
  4. from functools import lru_cache
  5. from pathlib import Path
  6. from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast
  7. import requests
  8. from loguru import logger
  9. # yapf conflicts with isort for this block
  10. # yapf: disable
  11. from openai.types.chat import ChatCompletionContentPartImageParam
  12. from openai.types.chat import (
  13. ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
  14. from openai.types.chat import ChatCompletionContentPartTextParam
  15. from openai.types.chat import (
  16. ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
  17. # yapf: enable
  18. # pydantic needs the TypedDict from typing_extensions
  19. from pydantic import ConfigDict
  20. from transformers import PreTrainedTokenizer
  21. from typing_extensions import Required, TypedDict
  22. from aphrodite.common.config import ModelConfig
  23. from aphrodite.multimodal import MultiModalDataDict
  24. from aphrodite.multimodal.utils import async_get_and_parse_image
  25. from aphrodite.transformers_utils.tokenizer import AnyTokenizer
  26. class CustomChatCompletionContentPartParam(TypedDict, total=False):
  27. __pydantic_config__ = ConfigDict(extra="allow") # type: ignore
  28. type: Required[str]
  29. """The type of the content part."""
  30. ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
  31. CustomChatCompletionContentPartParam]
  32. class CustomChatCompletionMessageParam(TypedDict, total=False):
  33. """Enables custom roles in the Chat Completion API."""
  34. role: Required[str]
  35. """The role of the message's author."""
  36. content: Union[str, List[ChatCompletionContentPartParam]]
  37. """The contents of the message."""
  38. name: str
  39. """An optional name for the participant.
  40. Provides the model information to differentiate between participants of the
  41. same role.
  42. """
  43. ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
  44. CustomChatCompletionMessageParam]
  45. # TODO: Make fields ReadOnly once mypy supports it
  46. class ConversationMessage(TypedDict):
  47. role: str
  48. content: str
  49. @dataclass(frozen=True)
  50. class ChatMessageParseResult:
  51. messages: List[ConversationMessage]
  52. mm_futures: List[Awaitable[MultiModalDataDict]]
  53. def load_chat_template(
  54. chat_template: Optional[Union[Path, str]]) -> Optional[str]:
  55. if chat_template is None:
  56. return None
  57. try:
  58. chat_template_str = str(chat_template)
  59. if chat_template_str.startswith(('http')):
  60. response = requests.get(chat_template_str)
  61. temp = tempfile.NamedTemporaryFile(delete=False)
  62. temp.write(response.content)
  63. temp.close()
  64. chat_template = temp.name
  65. with open(chat_template, "r") as f:
  66. resolved_chat_template = f.read()
  67. except OSError as e:
  68. if isinstance(chat_template, Path):
  69. raise
  70. JINJA_CHARS = "{}\n"
  71. if not any(c in chat_template for c in JINJA_CHARS):
  72. msg = (f"The supplied chat template ({chat_template}) "
  73. "looks like a file path, but it failed to be "
  74. f"opened. Reason: {e}")
  75. raise ValueError(msg) from e
  76. # If opening a file fails, set chat template to be args to
  77. # ensure we decode so our escape are interpreted correctly
  78. resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
  79. logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
  80. return resolved_chat_template
  81. @lru_cache(maxsize=None)
  82. def _image_token_str(model_config: ModelConfig,
  83. tokenizer: PreTrainedTokenizer) -> Optional[str]:
  84. # TODO: Let user specify how to insert image tokens into prompt
  85. # (similar to chat template)
  86. model_type = model_config.hf_config.model_type
  87. if model_type == "phi3_v":
  88. # Workaround since this token is not defined in the tokenizer
  89. return "<|image_1|>"
  90. if model_type == "minicpmv":
  91. return "(<image>./</image>)"
  92. if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
  93. # These models do not use image tokens in the prompt
  94. return None
  95. if model_type.startswith("llava"):
  96. return tokenizer.decode(model_config.hf_config.image_token_index)
  97. if model_type in ("chameleon", "internvl_chat"):
  98. return "<image>"
  99. raise TypeError(f"Unknown model type: {model_type}")
  100. # TODO: Let user specify how to insert image tokens into prompt
  101. # (similar to chat template)
  102. def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
  103. """Combine image and text prompts for vision language model"""
  104. # NOTE: For now we assume all model architectures use the same
  105. # image + text prompt format. This may change in the future.
  106. return f"{image_token_str}\n{text_prompt}"
  107. def _parse_chat_message_content_parts(
  108. role: str,
  109. parts: Iterable[ChatCompletionContentPartParam],
  110. model_config: ModelConfig,
  111. tokenizer: PreTrainedTokenizer,
  112. ) -> ChatMessageParseResult:
  113. texts: List[str] = []
  114. mm_futures: List[Awaitable[MultiModalDataDict]] = []
  115. for part in parts:
  116. part_type = part["type"]
  117. if part_type == "text":
  118. text = cast(ChatCompletionContentPartTextParam, part)["text"]
  119. texts.append(text)
  120. elif part_type == "image_url":
  121. if len(mm_futures) > 0:
  122. raise NotImplementedError(
  123. "Multiple 'image_url' input is currently not supported.")
  124. image_url = cast(ChatCompletionContentPartImageParam,
  125. part)["image_url"]
  126. if image_url.get("detail", "auto") != "auto":
  127. logger.warning(
  128. "'image_url.detail' is currently not supported and "
  129. "will be ignored.")
  130. image_future = async_get_and_parse_image(image_url["url"])
  131. mm_futures.append(image_future)
  132. else:
  133. raise NotImplementedError(f"Unknown part type: {part_type}")
  134. text_prompt = "\n".join(texts)
  135. if mm_futures:
  136. image_token_str = _image_token_str(model_config, tokenizer)
  137. if image_token_str is not None:
  138. if image_token_str in text_prompt:
  139. logger.warning(
  140. "Detected image token string in the text prompt. "
  141. "Skipping prompt formatting.")
  142. else:
  143. text_prompt = _get_full_image_text_prompt(
  144. image_token_str=image_token_str,
  145. text_prompt=text_prompt,
  146. )
  147. messages = [ConversationMessage(role=role, content=text_prompt)]
  148. return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
  149. def _parse_chat_message_content(
  150. message: ChatCompletionMessageParam,
  151. model_config: ModelConfig,
  152. tokenizer: PreTrainedTokenizer,
  153. ) -> ChatMessageParseResult:
  154. role = message["role"]
  155. content = message.get("content")
  156. if content is None:
  157. return ChatMessageParseResult(messages=[], mm_futures=[])
  158. if isinstance(content, str):
  159. messages = [ConversationMessage(role=role, content=content)]
  160. return ChatMessageParseResult(messages=messages, mm_futures=[])
  161. return _parse_chat_message_content_parts(role, content, model_config,
  162. tokenizer)
  163. def parse_chat_messages(
  164. messages: List[ChatCompletionMessageParam],
  165. model_config: ModelConfig,
  166. tokenizer: PreTrainedTokenizer,
  167. ) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
  168. conversation: List[ConversationMessage] = []
  169. mm_futures: List[Awaitable[MultiModalDataDict]] = []
  170. for msg in messages:
  171. parse_result = _parse_chat_message_content(msg, model_config,
  172. tokenizer)
  173. conversation.extend(parse_result.messages)
  174. mm_futures.extend(parse_result.mm_futures)
  175. return conversation, mm_futures
  176. def apply_chat_template(
  177. tokenizer: AnyTokenizer,
  178. conversation: List[ConversationMessage],
  179. chat_template: Optional[str],
  180. *,
  181. tokenize: bool = False, # Different from HF's default
  182. **kwargs: Any,
  183. ) -> str:
  184. if chat_template is None and tokenizer.chat_template is None:
  185. raise ValueError(
  186. "As of transformers v4.44, default chat template is no longer "
  187. "allowed, so you must provide a chat template if the tokenizer "
  188. "does not define one.")
  189. prompt = tokenizer.apply_chat_template(
  190. conversation=conversation,
  191. chat_template=chat_template,
  192. tokenize=tokenize,
  193. **kwargs,
  194. )
  195. assert isinstance(prompt, str)
  196. return prompt