chat_utils.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import codecs
  2. import tempfile
  3. from dataclasses import dataclass, field
  4. from functools import lru_cache
  5. from typing import Awaitable, Iterable, List, Optional, Union, cast, final
  6. import requests
  7. from loguru import logger
  8. # yapf conflicts with isort for this block
  9. # yapf: disable
  10. from openai.types.chat import ChatCompletionContentPartImageParam
  11. from openai.types.chat import \
  12. ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
  13. from openai.types.chat import ChatCompletionContentPartTextParam
  14. from openai.types.chat import \
  15. ChatCompletionMessageParam as OpenAIChatCompletionMessageParam
  16. # yapf: enable
  17. # pydantic needs the TypedDict from typing_extensions
  18. from pydantic import ConfigDict
  19. from transformers import PreTrainedTokenizer
  20. from typing_extensions import Required, TypedDict
  21. from aphrodite.common.config import ModelConfig
  22. from aphrodite.multimodal import MultiModalDataDict
  23. from aphrodite.multimodal.utils import async_get_and_parse_image
# A chat content part whose ``type`` may be any string. Only ``type`` is
# required (``total=False`` plus ``Required``); any extra keys are kept because
# the pydantic config below sets ``extra="allow"``.
class CustomChatCompletionContentPartParam(TypedDict, total=False):
    # Tell pydantic to accept arbitrary additional keys on this TypedDict.
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

    type: Required[str]
    """The type of the content part."""


# A content part is either one of OpenAI's typed parts or the permissive
# custom part defined above.
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
                                       CustomChatCompletionContentPartParam]
# Message schema whose ``role`` is an arbitrary string rather than a fixed set.
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    # total=False: every key is optional except ``role``, which is marked
    # Required.
    role: Required[str]
    """The role of the message's author."""
    # Either plain text or a list of typed content parts (text, images, ...).
    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""
    name: str
    """An optional name for the participant.
    Provides the model information to differentiate between participants of the
    same role.
    """


# A request message is either one of OpenAI's standard message params or the
# custom-role variant above.
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]
# One flattened chat turn as produced by the parsers below: the author's role
# and the message text (multimodal parts are carried separately).
@final  # So that it should be compatible with Dict[str, str]
class ConversationMessage(TypedDict):
    role: str  # role of the message's author
    content: str  # text of the message
@dataclass(frozen=True)
class ChatMessageParseResult:
    """Immutable result of parsing one chat request message.

    ``messages`` holds the conversation messages that were produced.
    ``mm_futures`` holds pending multimodal data loads (e.g. images);
    presumably the caller awaits them. It defaults to an empty list.
    """
    messages: List[ConversationMessage]
    # default_factory avoids sharing one mutable list between instances.
    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
        default_factory=list)
  52. def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
  53. if chat_template is None:
  54. return None
  55. try:
  56. if chat_template.startswith(('http')):
  57. response = requests.get(chat_template)
  58. temp = tempfile.NamedTemporaryFile(delete=False)
  59. temp.write(response.content)
  60. temp.close()
  61. chat_template = temp.name
  62. with open(chat_template, "r") as f:
  63. resolved_chat_template = f.read()
  64. except OSError as e:
  65. JINJA_CHARS = "{}\n"
  66. if not any(c in chat_template for c in JINJA_CHARS):
  67. msg = (f"The supplied chat template ({chat_template}) "
  68. "looks like a file path, but it failed to be "
  69. f"opened. Reason: {e}")
  70. raise ValueError(msg) from e
  71. # If opening a file fails, set chat template to be args to
  72. # ensure we decode so our escape are interpreted correctly
  73. resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
  74. logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
  75. return resolved_chat_template
  76. @lru_cache(maxsize=None)
  77. def _image_token_str(model_config: ModelConfig,
  78. tokenizer: PreTrainedTokenizer) -> Optional[str]:
  79. # TODO: Let user specify how to insert image tokens into prompt
  80. # (similar to chat template)
  81. model_type = model_config.hf_config.model_type
  82. if model_type == "phi3_v":
  83. # Workaround since this token is not defined in the tokenizer
  84. return "<|image_1|>"
  85. if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"):
  86. # These models do not use image tokens in the prompt
  87. return None
  88. if model_type.startswith("llava"):
  89. return tokenizer.decode(model_config.hf_config.image_token_index)
  90. raise TypeError(f"Unknown model type: {model_type}")
  91. # TODO: Let user specify how to insert image tokens into prompt
  92. # (similar to chat template)
  93. def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
  94. """Combine image and text prompts for vision language model"""
  95. # NOTE: For now we assume all model architectures use the same
  96. # image + text prompt format. This may change in the future.
  97. return f"{image_token_str}\n{text_prompt}"
  98. def _parse_chat_message_content_parts(
  99. role: str,
  100. parts: Iterable[ChatCompletionContentPartParam],
  101. model_config: ModelConfig,
  102. tokenizer: PreTrainedTokenizer,
  103. ) -> ChatMessageParseResult:
  104. texts: List[str] = []
  105. mm_futures: List[Awaitable[MultiModalDataDict]] = []
  106. for part in parts:
  107. part_type = part["type"]
  108. if part_type == "text":
  109. text = cast(ChatCompletionContentPartTextParam, part)["text"]
  110. texts.append(text)
  111. elif part_type == "image_url":
  112. if len(mm_futures) > 0:
  113. raise NotImplementedError(
  114. "Multiple 'image_url' input is currently not supported.")
  115. image_url = cast(ChatCompletionContentPartImageParam,
  116. part)["image_url"]
  117. if image_url.get("detail", "auto") != "auto":
  118. logger.warning(
  119. "'image_url.detail' is currently not supported and "
  120. "will be ignored.")
  121. image_future = async_get_and_parse_image(image_url["url"])
  122. mm_futures.append(image_future)
  123. else:
  124. raise NotImplementedError(f"Unknown part type: {part_type}")
  125. text_prompt = "\n".join(texts)
  126. if mm_futures:
  127. image_token_str = _image_token_str(model_config, tokenizer)
  128. if image_token_str is not None:
  129. if image_token_str in text_prompt:
  130. logger.warning(
  131. "Detected image token string in the text prompt. "
  132. "Skipping prompt formatting.")
  133. else:
  134. text_prompt = _get_full_image_text_prompt(
  135. image_token_str=image_token_str,
  136. text_prompt=text_prompt,
  137. )
  138. messages = [ConversationMessage(role=role, content=text_prompt)]
  139. return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
  140. def parse_chat_message_content(
  141. message: ChatCompletionMessageParam,
  142. model_config: ModelConfig,
  143. tokenizer: PreTrainedTokenizer,
  144. ) -> ChatMessageParseResult:
  145. role = message["role"]
  146. content = message.get("content")
  147. if content is None:
  148. return ChatMessageParseResult(messages=[], mm_futures=[])
  149. if isinstance(content, str):
  150. messages = [ConversationMessage(role=role, content=content)]
  151. return ChatMessageParseResult(messages=messages, mm_futures=[])
  152. return _parse_chat_message_content_parts(role, content, model_config,
  153. tokenizer)