serving_engine.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. import asyncio
  2. import json
  3. from dataclasses import dataclass
  4. from http import HTTPStatus
  5. from typing import Dict, List, Optional, Tuple, Union
  6. from loguru import logger
  7. from pydantic import conint
  8. from aphrodite.common.sequence import Logprob
  9. from aphrodite.endpoints.openai.protocol import (
  10. ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ErrorResponse,
  11. LogProbs, ModelCard, ModelList, ModelPermission, Prompt)
  12. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  13. from aphrodite.lora.request import LoRARequest
  14. from aphrodite.transformers_utils.tokenizer import get_tokenizer
@dataclass
class LoRA:
    """A LoRA adapter to be served alongside the base model."""
    # Name under which the adapter is exposed to clients (model id).
    name: str
    # Filesystem path to the adapter weights.
    local_path: str
  19. class OpenAIServing:
  20. def __init__(self,
  21. engine: AsyncAphrodite,
  22. served_model_names: List[str],
  23. lora_modules=Optional[List[LoRA]]):
  24. self.engine = engine
  25. self.served_model_names = served_model_names
  26. if lora_modules is None:
  27. self.lora_requests = []
  28. else:
  29. self.lora_requests = [
  30. LoRARequest(
  31. lora_name=lora.name,
  32. lora_int_id=i,
  33. lora_local_path=lora.local_path,
  34. ) for i, lora in enumerate(lora_modules, start=1)
  35. ]
  36. self.max_model_len = 0
  37. self.tokenizer = None
  38. try:
  39. event_loop = asyncio.get_running_loop()
  40. except RuntimeError:
  41. event_loop = None
  42. if event_loop is not None and event_loop.is_running():
  43. # If the current is instanced by Ray Serve,
  44. # there is already a running event loop
  45. event_loop.create_task(self._post_init())
  46. else:
  47. # When using single Aphrodite without engine_use_ray
  48. asyncio.run(self._post_init())
  49. async def _post_init(self):
  50. engine_model_config = await self.engine.get_model_config()
  51. self.max_model_len = engine_model_config.max_model_len
  52. # A separate tokenizer to map token IDs to strings.
  53. self.tokenizer = get_tokenizer(
  54. engine_model_config.tokenizer,
  55. tokenizer_mode=engine_model_config.tokenizer_mode,
  56. trust_remote_code=engine_model_config.trust_remote_code,
  57. revision=engine_model_config.revision,
  58. truncation_side="left")
  59. async def show_available_models(self) -> ModelList:
  60. """Show available models. Right now we only have one model."""
  61. model_cards = [
  62. ModelCard(id=served_model_name,
  63. max_model_len=self.max_model_len,
  64. root=self.served_model_names[0],
  65. permission=[ModelPermission()])
  66. for served_model_name in self.served_model_names
  67. ]
  68. lora_cards = [
  69. ModelCard(id=lora.lora_name,
  70. root=self.served_model_names[0],
  71. permission=[ModelPermission()])
  72. for lora in self.lora_requests
  73. ]
  74. model_cards.extend(lora_cards)
  75. return ModelList(data=model_cards)
  76. async def tokenize(self, prompt: Prompt):
  77. """Tokenize a given prompt."""
  78. tokenized_prompt = self.tokenizer.tokenize(prompt.prompt)
  79. token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_prompt)
  80. return {"value": len(tokenized_prompt), "ids": token_ids}
  81. async def detokenize(self, token_ids: List[int]):
  82. """Detokenize a given list of token IDs."""
  83. tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
  84. detokenized_text = self.tokenizer.convert_tokens_to_string(tokens)
  85. return {"value": detokenized_text}
    def _create_logprobs(
        self,
        token_ids: List[int],
        top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
        num_output_top_logprobs: Optional[int] = None,
        initial_text_offset: int = 0,
    ) -> LogProbs:
        """Create OpenAI-style logprobs.

        Args:
            token_ids: IDs of the generated tokens, one per step.
            top_logprobs: Per-step mapping of candidate token ID to
                ``Logprob``; a step's entry may be ``None`` when no
                logprob info exists for it.
            num_output_top_logprobs: When truthy, also populate the
                ``top_logprobs`` field of the result.
            initial_text_offset: Starting value for the cumulative
                ``text_offset`` list.

        Returns:
            A ``LogProbs`` with ``tokens``, ``token_logprobs``,
            ``text_offset`` and (optionally) ``top_logprobs`` filled in.
        """
        logprobs = LogProbs()
        # Length of the previous token's text, used to accumulate offsets.
        last_token_len = 0
        if num_output_top_logprobs:
            logprobs.top_logprobs = []
        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob info for this step: decode the token
                # ourselves and record placeholders.
                token = self.tokenizer.decode(token_id)
                logprobs.tokens.append(token)
                logprobs.token_logprobs.append(None)
                logprobs.top_logprobs.append(None)
            else:
                token_logprob = step_top_logprobs[token_id].logprob
                token = step_top_logprobs[token_id].decoded_token
                logprobs.tokens.append(token)
                # Clamp so float("-inf") serializes as a finite JSON number.
                token_logprob = max(token_logprob, -9999.0)
                logprobs.token_logprobs.append(token_logprob)
                if num_output_top_logprobs:
                    logprobs.top_logprobs.append({
                        # Convert float("-inf") to the
                        # JSON-serializable float that OpenAI uses
                        p.decoded_token: max(p.logprob, -9999.0)
                        for i, p in step_top_logprobs.items()
                    } if step_top_logprobs else None)
            # TODO: Check if this is still needed
            # NOTE(review): this pass clamps entries below -1000 and drops
            # None placeholders from top_logprobs; dropping them can
            # desync top_logprobs from tokens when some steps lacked
            # logprob info — confirm intended.
            if logprobs.top_logprobs:
                logprobs.top_logprobs = [{
                    k: v if v > -1000 else -1000
                    for k, v in top_logprob.items()
                } for top_logprob in logprobs.top_logprobs
                                         if top_logprob is not None
                                         ]  # noqa: E501
            # Cumulative character offset of each token within the text.
            if len(logprobs.text_offset) == 0:
                logprobs.text_offset.append(initial_text_offset)
            else:
                logprobs.text_offset.append(logprobs.text_offset[-1] +
                                            last_token_len)
            last_token_len = len(token)
        return logprobs
  133. def create_error_response(
  134. self,
  135. message: str,
  136. err_type: str = "BadRequestError",
  137. status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
  138. return ErrorResponse(message=message,
  139. type=err_type,
  140. code=status_code.value)
  141. def create_streaming_error_response(
  142. self,
  143. message: str,
  144. err_type: str = "BadRequestError",
  145. status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
  146. json_str = json.dumps({
  147. "error":
  148. self.create_error_response(message=message,
  149. err_type=err_type,
  150. status_code=status_code).model_dump()
  151. })
  152. return json_str
  153. async def _check_model(
  154. self, request: Union[CompletionRequest, ChatCompletionRequest,
  155. EmbeddingRequest]
  156. ) -> Optional[ErrorResponse]:
  157. if request.model in self.served_model_names:
  158. return
  159. if request.model in [lora.lora_name for lora in self.lora_requests]:
  160. return
  161. return self.create_error_response(
  162. message=f"The model `{request.model}` does not exist.",
  163. err_type="NotFoundError",
  164. status_code=HTTPStatus.NOT_FOUND)
  165. def add_lora(self, lora: LoRA):
  166. if lora.name in [
  167. existing_lora.lora_name for existing_lora in self.lora_requests
  168. ]:
  169. logger.error(f"LoRA with name {lora.name} already exists.")
  170. return
  171. self.lora_requests.append(
  172. LoRARequest(
  173. lora_name=lora.name,
  174. lora_int_id=len(self.lora_requests) + 1,
  175. lora_local_path=lora.local_path,
  176. ))
  177. def remove_lora(self, lora_name: str):
  178. self.lora_requests = [
  179. lora for lora in self.lora_requests if lora.lora_name != lora_name
  180. ]
  181. def _maybe_get_lora(
  182. self, request: Union[CompletionRequest, ChatCompletionRequest,
  183. EmbeddingRequest]
  184. ) -> Optional[LoRARequest]:
  185. if request.model in self.served_model_names:
  186. return
  187. for lora in self.lora_requests:
  188. if request.model == lora.lora_name:
  189. return lora
  190. # if _check_model has been called earlier, this will be unreachable
  191. raise ValueError("The model `{request.model}` does not exist.")
  192. def _validate_prompt_and_tokenize(
  193. self,
  194. request: Union[ChatCompletionRequest, CompletionRequest,
  195. EmbeddingRequest],
  196. prompt: Optional[str] = None,
  197. prompt_ids: Optional[List[int]] = None,
  198. truncate_prompt_tokens: Optional[conint(ge=1)] = None
  199. ) -> Tuple[List[int], str]:
  200. if not (prompt or prompt_ids):
  201. raise ValueError("Either prompt or prompt_ids should be provided.")
  202. if (prompt and prompt_ids):
  203. raise ValueError(
  204. "Only one of prompt or prompt_ids should be provided.")
  205. if prompt_ids is None:
  206. tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
  207. "truncation": True,
  208. "max_length": truncate_prompt_tokens,
  209. }
  210. input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
  211. elif truncate_prompt_tokens is not None:
  212. input_ids = prompt_ids[-truncate_prompt_tokens:]
  213. else:
  214. input_ids = prompt_ids
  215. input_text = prompt if prompt is not None else self.tokenizer.decode(
  216. prompt_ids)
  217. token_num = len(input_ids)
  218. # Note: EmbeddingRequest doesn't have max_tokens
  219. if isinstance(request, EmbeddingRequest):
  220. if token_num > self.max_model_len:
  221. raise ValueError(
  222. f"This model's maximum context length is "
  223. f"{self.max_model_len} tokens. However, you requested "
  224. f"{token_num} tokens in the input for embedding "
  225. f"generation. Please reduce the length of the input.", )
  226. return input_ids, input_text
  227. if request.max_tokens is None:
  228. request.max_tokens = self.max_model_len - token_num
  229. if token_num + request.max_tokens > self.max_model_len:
  230. raise ValueError(
  231. f"This model's maximum context length is "
  232. f"{self.max_model_len} tokens. However, you requested "
  233. f"{request.max_tokens + token_num} tokens "
  234. f"({token_num} in the messages, "
  235. f"{request.max_tokens} in the completion). "
  236. f"Please reduce the length of the messages or completion.", )
  237. else:
  238. return input_ids, input_text