# serving_engine.py
  1. import asyncio
  2. import json
  3. from dataclasses import dataclass
  4. from http import HTTPStatus
  5. from typing import Dict, List, Optional, Union
  6. from loguru import logger
  7. from aphrodite.transformers_utils.tokenizer import get_tokenizer
  8. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  9. from aphrodite.endpoints.openai.protocol import (CompletionRequest,
  10. ChatCompletionRequest,
  11. ErrorResponse, LogProbs,
  12. ModelCard, ModelList,
  13. ModelPermission, Prompt)
  14. from aphrodite.lora.request import LoRARequest
  15. from aphrodite.common.sequence import Logprob
  16. @dataclass
  17. class LoRA:
  18. name: str
  19. local_path: str
  20. class OpenAIServing:
  21. def __init__(self,
  22. engine: AsyncAphrodite,
  23. served_model: str,
  24. lora_modules=Optional[List[LoRA]]):
  25. self.engine = engine
  26. self.served_model = served_model
  27. if lora_modules is None:
  28. self.lora_requests = []
  29. else:
  30. self.lora_requests = [
  31. LoRARequest(
  32. lora_name=lora.name,
  33. lora_int_id=i,
  34. lora_local_path=lora.local_path,
  35. ) for i, lora in enumerate(lora_modules, start=1)
  36. ]
  37. self.max_model_len = 0
  38. self.tokenizer = None
  39. try:
  40. event_loop = asyncio.get_running_loop()
  41. except RuntimeError:
  42. event_loop = None
  43. if event_loop is not None and event_loop.is_running(
  44. ): # If the current is instanced by Ray Serve, there is already
  45. # a running event loop
  46. event_loop.create_task(self._post_init())
  47. else: # When using single Aphrodite without engine_use_ray
  48. asyncio.run(self._post_init())
  49. async def _post_init(self):
  50. engine_model_config = await self.engine.get_model_config()
  51. self.max_model_len = engine_model_config.max_model_len
  52. # A separate tokenizer to map token IDs to strings.
  53. self.tokenizer = get_tokenizer(
  54. engine_model_config.tokenizer,
  55. tokenizer_mode=engine_model_config.tokenizer_mode,
  56. trust_remote_code=engine_model_config.trust_remote_code)
  57. async def show_available_models(self) -> ModelList:
  58. """Show available models. Right now we only have one model."""
  59. model_cards = [
  60. ModelCard(id=self.served_model,
  61. root=self.served_model,
  62. permission=[ModelPermission()])
  63. ]
  64. lora_cards = [
  65. ModelCard(id=lora.lora_name,
  66. root=self.served_model,
  67. permission=[ModelPermission()])
  68. for lora in self.lora_requests
  69. ]
  70. model_cards.extend(lora_cards)
  71. return ModelList(data=model_cards)
  72. async def tokenize(self, prompt: Prompt):
  73. """Tokenize a given prompt."""
  74. tokenized_prompt = self.tokenizer.tokenize(prompt.prompt)
  75. token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_prompt)
  76. return {"value": len(tokenized_prompt), "ids": token_ids}
  77. async def detokenize(self, token_ids: List[int]):
  78. """Detokenize a given list of token IDs."""
  79. tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
  80. detokenized_text = self.tokenizer.convert_tokens_to_string(tokens)
  81. return {"value": detokenized_text}
  82. def _create_logprobs(
  83. self,
  84. token_ids: List[int],
  85. top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
  86. num_output_top_logprobs: Optional[int] = None,
  87. initial_text_offset: int = 0,
  88. ) -> LogProbs:
  89. """Create OpenAI-style logprobs."""
  90. logprobs = LogProbs()
  91. last_token_len = 0
  92. if num_output_top_logprobs:
  93. logprobs.top_logprobs = []
  94. for i, token_id in enumerate(token_ids):
  95. step_top_logprobs = top_logprobs[i]
  96. if step_top_logprobs is not None:
  97. token_logprob = step_top_logprobs[token_id].logprob
  98. else:
  99. token_logprob = None
  100. token = step_top_logprobs[token_id].decoded_token
  101. logprobs.tokens.append(token)
  102. logprobs.token_logprobs.append(token_logprob)
  103. if len(logprobs.text_offset) == 0:
  104. logprobs.text_offset.append(initial_text_offset)
  105. else:
  106. logprobs.text_offset.append(logprobs.text_offset[-1] +
  107. last_token_len)
  108. last_token_len = len(token)
  109. if num_output_top_logprobs:
  110. logprobs.top_logprobs.append({
  111. p.decoded_token: p.logprob
  112. for i, p in step_top_logprobs.items()
  113. } if step_top_logprobs else None)
  114. logprobs.top_logprobs = [{
  115. k: v if v > -1000 else -1000
  116. for k, v in top_logprob.items()
  117. } for top_logprob in logprobs.top_logprobs if top_logprob is not None]
  118. return logprobs
  119. def create_error_response(
  120. self,
  121. message: str,
  122. err_type: str = "BadRequestError",
  123. status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
  124. return ErrorResponse(message=message,
  125. type=err_type,
  126. code=status_code.value)
  127. def create_streaming_error_response(
  128. self,
  129. message: str,
  130. err_type: str = "BadRequestError",
  131. status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
  132. json_str = json.dumps({
  133. "error":
  134. self.create_error_response(message=message,
  135. err_type=err_type,
  136. status_code=status_code).model_dump()
  137. })
  138. return json_str
  139. async def _check_model(self, request) -> Optional[ErrorResponse]:
  140. if request.model == self.served_model:
  141. return
  142. if request.model in [lora.lora_name for lora in self.lora_requests]:
  143. return
  144. return self.create_error_response(
  145. message=f"The model `{request.model}` does not exist.",
  146. err_type="NotFoundError",
  147. status_code=HTTPStatus.NOT_FOUND)
  148. def add_lora(self, lora: LoRA):
  149. if lora.name in [
  150. existing_lora.lora_name for existing_lora in self.lora_requests
  151. ]:
  152. logger.error(f"LoRA with name {lora.name} already exists.")
  153. return
  154. self.lora_requests.append(
  155. LoRARequest(
  156. lora_name=lora.name,
  157. lora_int_id=len(self.lora_requests) + 1,
  158. lora_local_path=lora.local_path,
  159. ))
  160. def remove_lora(self, lora_name: str):
  161. self.lora_requests = [
  162. lora for lora in self.lora_requests if lora.lora_name != lora_name
  163. ]
  164. def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
  165. if request.model == self.served_model:
  166. return
  167. for lora in self.lora_requests:
  168. if request.model == lora.lora_name:
  169. return lora
  170. # if _check_model has been called earlier, this will be unreachable
  171. raise ValueError("The model `{request.model}` does not exist.")
  172. def _validate_prompt_and_tokenize(
  173. self,
  174. request: Union[ChatCompletionRequest, CompletionRequest],
  175. prompt: Optional[str] = None,
  176. prompt_ids: Optional[List[int]] = None) -> List[int]:
  177. if not (prompt or prompt_ids):
  178. raise ValueError("Either prompt or prompt_ids should be provided.")
  179. if (prompt and prompt_ids):
  180. raise ValueError(
  181. "Only one of prompt or prompt_ids should be provided.")
  182. input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
  183. prompt).input_ids
  184. token_num = len(input_ids)
  185. if request.max_tokens is None:
  186. request.max_tokens = self.max_model_len - token_num
  187. if token_num + request.max_tokens > self.max_model_len:
  188. raise ValueError(
  189. f"This model's maximum context length is {self.max_model_len} "
  190. "tokens. "
  191. f"However, you requested {request.max_tokens + token_num} "
  192. f"tokens ({token_num} in the messages, "
  193. f"{request.max_tokens} in the completion). "
  194. f"Please reduce the length of the messages or completion.", )
  195. else:
  196. return input_ids