# chat_utils.py (8.6 KB)
  1. import codecs
  2. import tempfile
  3. from dataclasses import dataclass
  4. from functools import lru_cache
  5. from pathlib import Path
  6. from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
  7. cast, final)
  8. import requests
  9. from loguru import logger
  10. # yapf conflicts with isort for this block
  11. # yapf: disable
  12. from openai.types.chat import ChatCompletionContentPartImageParam
  13. from openai.types.chat import (
  14. ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
  15. from openai.types.chat import ChatCompletionContentPartTextParam
  16. from openai.types.chat import (
  17. ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
  18. # yapf: enable
  19. # pydantic needs the TypedDict from typing_extensions
  20. from pydantic import ConfigDict
  21. from transformers import PreTrainedTokenizer
  22. from typing_extensions import Required, TypedDict
  23. from aphrodite.common.config import ModelConfig
  24. from aphrodite.multimodal import MultiModalDataDict
  25. from aphrodite.multimodal.utils import async_get_and_parse_image
  26. from aphrodite.transformers_utils.tokenizer import AnyTokenizer
class CustomChatCompletionContentPartParam(TypedDict, total=False):
    """A content part whose ``type`` is not one of OpenAI's built-in parts.

    Only ``type`` is required. ``extra="allow"`` lets pydantic keep any
    additional keys instead of rejecting them.

    Note: ``_parse_chat_message_content_parts`` in this module currently
    raises ``NotImplementedError`` for any part type other than ``"text"``
    and ``"image_url"``.
    """
    # Pydantic picks this attribute up as the validation config of the
    # TypedDict.
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore
    type: Required[str]
    """The type of the content part."""
# A content part accepted inside a chat message: either one of OpenAI's typed
# parts or a custom part carrying an arbitrary ``type`` string.
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
                                       CustomChatCompletionContentPartParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    # total=False: only ``role`` (marked Required) must be present;
    # ``content`` and ``name`` may be omitted.
    role: Required[str]
    """The role of the message's author."""

    # Either plain text or a list of content parts (text, image_url, custom).
    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.
    Provides the model information to differentiate between participants of the
    same role.
    """
# A chat message accepted by the parsers below: an OpenAI-typed message or a
# message with a custom role.
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]
@final  # So that it should be compatible with Dict[str, str]
class ConversationMessage(TypedDict):
    """One flattened conversation turn passed to ``apply_chat_template``.

    Both fields are plain strings. The parsers in this module merge any list
    of content parts into a single text ``content`` before building one.
    """
    role: str
    content: str
@dataclass(frozen=True)
class ChatMessageParseResult:
    """Immutable result of parsing a single chat message."""

    # Conversation entries produced from the message. The parsers here emit
    # an empty list when the message has no content and exactly one entry
    # otherwise.
    messages: List[ConversationMessage]
    # Awaitables that load multimodal data referenced by the message
    # (currently images, created via ``async_get_and_parse_image``).
    mm_futures: List[Awaitable[MultiModalDataDict]]
  54. def load_chat_template(
  55. chat_template: Optional[Union[Path, str]]) -> Optional[str]:
  56. if chat_template is None:
  57. return None
  58. try:
  59. chat_template_str = str(chat_template)
  60. if chat_template_str.startswith(('http')):
  61. response = requests.get(chat_template_str)
  62. temp = tempfile.NamedTemporaryFile(delete=False)
  63. temp.write(response.content)
  64. temp.close()
  65. chat_template = temp.name
  66. with open(chat_template, "r") as f:
  67. resolved_chat_template = f.read()
  68. except OSError as e:
  69. if isinstance(chat_template, Path):
  70. raise
  71. JINJA_CHARS = "{}\n"
  72. if not any(c in chat_template for c in JINJA_CHARS):
  73. msg = (f"The supplied chat template ({chat_template}) "
  74. "looks like a file path, but it failed to be "
  75. f"opened. Reason: {e}")
  76. raise ValueError(msg) from e
  77. # If opening a file fails, set chat template to be args to
  78. # ensure we decode so our escape are interpreted correctly
  79. resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
  80. logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
  81. return resolved_chat_template
  82. @lru_cache(maxsize=None)
  83. def _image_token_str(model_config: ModelConfig,
  84. tokenizer: PreTrainedTokenizer) -> Optional[str]:
  85. # TODO: Let user specify how to insert image tokens into prompt
  86. # (similar to chat template)
  87. model_type = model_config.hf_config.model_type
  88. if model_type == "phi3_v":
  89. # Workaround since this token is not defined in the tokenizer
  90. return "<|image_1|>"
  91. if model_type == "minicpmv":
  92. return "(<image>./</image>)"
  93. if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
  94. # These models do not use image tokens in the prompt
  95. return None
  96. if model_type.startswith("llava"):
  97. return tokenizer.decode(model_config.hf_config.image_token_index)
  98. if model_type in ("chameleon", "internvl_chat"):
  99. return "<image>"
  100. raise TypeError(f"Unknown model type: {model_type}")
  101. # TODO: Let user specify how to insert image tokens into prompt
  102. # (similar to chat template)
  103. def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
  104. """Combine image and text prompts for vision language model"""
  105. # NOTE: For now we assume all model architectures use the same
  106. # image + text prompt format. This may change in the future.
  107. return f"{image_token_str}\n{text_prompt}"
  108. def _parse_chat_message_content_parts(
  109. role: str,
  110. parts: Iterable[ChatCompletionContentPartParam],
  111. model_config: ModelConfig,
  112. tokenizer: PreTrainedTokenizer,
  113. ) -> ChatMessageParseResult:
  114. texts: List[str] = []
  115. mm_futures: List[Awaitable[MultiModalDataDict]] = []
  116. for part in parts:
  117. part_type = part["type"]
  118. if part_type == "text":
  119. text = cast(ChatCompletionContentPartTextParam, part)["text"]
  120. texts.append(text)
  121. elif part_type == "image_url":
  122. if len(mm_futures) > 0:
  123. raise NotImplementedError(
  124. "Multiple 'image_url' input is currently not supported.")
  125. image_url = cast(ChatCompletionContentPartImageParam,
  126. part)["image_url"]
  127. if image_url.get("detail", "auto") != "auto":
  128. logger.warning(
  129. "'image_url.detail' is currently not supported and "
  130. "will be ignored.")
  131. image_future = async_get_and_parse_image(image_url["url"])
  132. mm_futures.append(image_future)
  133. else:
  134. raise NotImplementedError(f"Unknown part type: {part_type}")
  135. text_prompt = "\n".join(texts)
  136. if mm_futures:
  137. image_token_str = _image_token_str(model_config, tokenizer)
  138. if image_token_str is not None:
  139. if image_token_str in text_prompt:
  140. logger.warning(
  141. "Detected image token string in the text prompt. "
  142. "Skipping prompt formatting.")
  143. else:
  144. text_prompt = _get_full_image_text_prompt(
  145. image_token_str=image_token_str,
  146. text_prompt=text_prompt,
  147. )
  148. messages = [ConversationMessage(role=role, content=text_prompt)]
  149. return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
  150. def _parse_chat_message_content(
  151. message: ChatCompletionMessageParam,
  152. model_config: ModelConfig,
  153. tokenizer: PreTrainedTokenizer,
  154. ) -> ChatMessageParseResult:
  155. role = message["role"]
  156. content = message.get("content")
  157. if content is None:
  158. return ChatMessageParseResult(messages=[], mm_futures=[])
  159. if isinstance(content, str):
  160. messages = [ConversationMessage(role=role, content=content)]
  161. return ChatMessageParseResult(messages=messages, mm_futures=[])
  162. return _parse_chat_message_content_parts(role, content, model_config,
  163. tokenizer)
  164. def parse_chat_messages(
  165. messages: List[ChatCompletionMessageParam],
  166. model_config: ModelConfig,
  167. tokenizer: PreTrainedTokenizer,
  168. ) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
  169. conversation: List[ConversationMessage] = []
  170. mm_futures: List[Awaitable[MultiModalDataDict]] = []
  171. for msg in messages:
  172. parse_result = _parse_chat_message_content(msg, model_config,
  173. tokenizer)
  174. conversation.extend(parse_result.messages)
  175. mm_futures.extend(parse_result.mm_futures)
  176. return conversation, mm_futures
  177. def apply_chat_template(
  178. tokenizer: AnyTokenizer,
  179. conversation: List[ConversationMessage],
  180. chat_template: Optional[str],
  181. *,
  182. tokenize: bool = False, # Different from HF's default
  183. **kwargs: Any,
  184. ) -> str:
  185. if chat_template is None and tokenizer.chat_template is None:
  186. raise ValueError(
  187. "As of transformers v4.44, default chat template is no longer "
  188. "allowed, so you must provide a chat template if the tokenizer "
  189. "does not define one.")
  190. prompt = tokenizer.apply_chat_template(
  191. conversation=conversation,
  192. chat_template=chat_template,
  193. tokenize=tokenize,
  194. **kwargs,
  195. )
  196. assert isinstance(prompt, str)
  197. return prompt