serving_engine.py

import asyncio
import json
from dataclasses import dataclass
from http import HTTPStatus
from typing import Dict, List, Optional, Union

from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.endpoints.openai.protocol import (CompletionRequest,
                                                 ChatCompletionRequest,
                                                 ErrorResponse, LogProbs,
                                                 ModelCard, ModelList,
                                                 ModelPermission, Prompt)
from aphrodite.lora.request import LoRARequest
from aphrodite.common.sequence import Logprob


@dataclass
class LoRA:
    name: str
    local_path: str


class OpenAIServing:

    def __init__(self,
                 engine: AsyncAphrodite,
                 served_model: str,
                 lora_modules: Optional[List[LoRA]] = None):
        self.engine = engine
        self.served_model = served_model
        if lora_modules is None:
            self.lora_requests = []
        else:
            self.lora_requests = [
                LoRARequest(
                    lora_name=lora.name,
                    lora_int_id=i,
                    lora_local_path=lora.local_path,
                ) for i, lora in enumerate(lora_modules, start=1)
            ]

        self.max_model_len = 0
        self.tokenizer = None

        try:
            event_loop = asyncio.get_running_loop()
        except RuntimeError:
            event_loop = None

        if event_loop is not None and event_loop.is_running():
            # If this instance is created by Ray Serve, there is already a
            # running event loop, so schedule the async init on it.
            event_loop.create_task(self._post_init())
        else:
            # When using a single Aphrodite instance without engine_use_ray.
            asyncio.run(self._post_init())
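
    # NOTE: when _post_init runs as a background task (the Ray Serve path
    # above), max_model_len and tokenizer keep their placeholder values
    # (0 / None) until that task completes.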

    async def _post_init(self):
        engine_model_config = await self.engine.get_model_config()
        self.max_model_len = engine_model_config.max_model_len

        # A separate tokenizer to map token IDs to strings.
        self.tokenizer = get_tokenizer(
            engine_model_config.tokenizer,
            tokenizer_mode=engine_model_config.tokenizer_mode,
            trust_remote_code=engine_model_config.trust_remote_code)

    async def show_available_models(self) -> ModelList:
        """Show available models. Right now we only have one model."""
        model_cards = [
            ModelCard(id=self.served_model,
                      root=self.served_model,
                      permission=[ModelPermission()])
        ]
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=self.served_model,
                      permission=[ModelPermission()])
            for lora in self.lora_requests
        ]
        model_cards.extend(lora_cards)
        return ModelList(data=model_cards)

    async def tokenize(self, prompt: Prompt):
        """Tokenize a given prompt."""
        tokenized_prompt = self.tokenizer.tokenize(prompt.prompt)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_prompt)
        return {"value": len(tokenized_prompt), "ids": token_ids}

    async def detokenize(self, token_ids: List[int]):
        """Detokenize a given list of token IDs."""
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
        detokenized_text = self.tokenizer.convert_tokens_to_string(tokens)
        return {"value": detokenized_text}

    def _create_logprobs(
        self,
        token_ids: List[int],
        top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
        num_output_top_logprobs: Optional[int] = None,
        initial_text_offset: int = 0,
    ) -> LogProbs:
        """Create OpenAI-style logprobs."""
        logprobs = LogProbs()
        last_token_len = 0
        if num_output_top_logprobs:
            logprobs.top_logprobs = []
        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is not None:
                token = step_top_logprobs[token_id].decoded_token
                token_logprob = step_top_logprobs[token_id].logprob
            else:
                # No logprob information for this position; fall back to
                # decoding the token ID directly.
                token = self.tokenizer.decode(token_id)
                token_logprob = None
            logprobs.tokens.append(token)
            logprobs.token_logprobs.append(token_logprob)
            if len(logprobs.text_offset) == 0:
                logprobs.text_offset.append(initial_text_offset)
            else:
                logprobs.text_offset.append(logprobs.text_offset[-1] +
                                            last_token_len)
            last_token_len = len(token)

            if num_output_top_logprobs:
                logprobs.top_logprobs.append({
                    p.decoded_token: p.logprob
                    for p in step_top_logprobs.values()
                } if step_top_logprobs else None)
        return logprobs
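
    # The returned LogProbs object follows the OpenAI completions layout:
    # parallel lists `tokens`, `token_logprobs`, and `text_offset`, plus an
    # optional `top_logprobs` list of {decoded token: logprob} dicts (with
    # None entries where no alternatives were recorded).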

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str
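
    # The streaming variant serializes the same ErrorResponse as
    # {"error": {"message": ..., "type": ..., "code": ...}} so it can be sent
    # verbatim as a chunk on a streaming response.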

    async def _check_model(self, request) -> Optional[ErrorResponse]:
        if request.model == self.served_model:
            return
        if request.model in [lora.lora_name for lora in self.lora_requests]:
            return
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)
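
    # _check_model returns None when the requested model matches the served
    # model or one of the registered LoRA adapters, and a NotFoundError
    # ErrorResponse otherwise.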

    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
        if request.model == self.served_model:
            return
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return lora
        # If _check_model has been called earlier, this is unreachable.
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _validate_prompt_and_tokenize(
            self,
            request: Union[ChatCompletionRequest, CompletionRequest],
            prompt: Optional[str] = None,
            prompt_ids: Optional[List[int]] = None) -> List[int]:
        if not (prompt or prompt_ids):
            raise ValueError("Either prompt or prompt_ids should be provided.")
        if prompt and prompt_ids:
            raise ValueError(
                "Only one of prompt or prompt_ids should be provided.")

        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
            prompt).input_ids
        token_num = len(input_ids)

        if request.max_tokens is None:
            request.max_tokens = self.max_model_len - token_num

        if token_num + request.max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is {self.max_model_len} "
                f"tokens. However, you requested "
                f"{request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                "Please reduce the length of the messages or completion.")
        else:
            return input_ids
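

# Usage sketch (hypothetical names; assumes an AsyncAphrodite engine is built
# elsewhere, e.g. from engine args, before the server starts):
#
#     engine = AsyncAphrodite.from_engine_args(engine_args)
#     serving = OpenAIServing(
#         engine,
#         served_model="my-org/my-model",
#         lora_modules=[LoRA(name="my-adapter", local_path="/path/to/adapter")])
#     models = await serving.show_available_models()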