serving_engine.py

import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union

from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated

from aphrodite.common.config import ModelConfig
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import Logprob
from aphrodite.endpoints.logger import RequestLogger
# yapf conflicts with isort here
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 EmbeddingRequest,
                                                 ErrorResponse, ModelCard,
                                                 ModelList, ModelPermission,
                                                 TokenizeChatRequest,
                                                 TokenizeCompletionRequest,
                                                 TokenizeRequest)
# yapf: enable
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.inputs import parse_and_batch_prompt
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest


@dataclass
class PromptAdapterPath:
    name: str
    local_path: str


@dataclass
class LoRAModulePath:
    name: str
    path: str


AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
                   EmbeddingRequest, TokenizeRequest]

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


class TextTokensPrompt(TypedDict):
    prompt: str
    prompt_token_ids: List[int]


class OpenAIServing:

    def __init__(
        self,
        engine: AsyncAphrodite,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
    ):
        super().__init__()

        self.engine = engine
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.served_model_names = served_model_names

        self.lora_requests = []
        if lora_modules is not None:
            self.lora_requests = [
                LoRARequest(
                    lora_name=lora.name,
                    lora_int_id=i,
                    lora_path=lora.path,
                ) for i, lora in enumerate(lora_modules, start=1)
            ]

        self.prompt_adapter_requests = []
        if prompt_adapters is not None:
            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
                with pathlib.Path(prompt_adapter.local_path,
                                  "adapter_config.json").open() as f:
                    adapter_config = json.load(f)
                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
                    self.prompt_adapter_requests.append(
                        PromptAdapterRequest(
                            prompt_adapter_name=prompt_adapter.name,
                            prompt_adapter_id=i,
                            prompt_adapter_local_path=prompt_adapter.local_path,
                            prompt_adapter_num_virtual_tokens=num_virtual_tokens))

        self.request_logger = request_logger
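
    # Illustrative usage sketch (not part of the original module): the engine,
    # model_config, and the model name "example-model" below are hypothetical
    # placeholders for whatever the server startup code actually provides.
    #
    #   serving = OpenAIServing(
    #       engine=engine,                      # an AsyncAphrodite instance
    #       model_config=model_config,          # its ModelConfig
    #       served_model_names=["example-model"],
    #       lora_modules=None,
    #       prompt_adapters=None,
    #       request_logger=None,
    #   )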

    async def show_available_models(self) -> ModelList:
        """Show available models. Right now we only have one model."""
        model_cards = [
            ModelCard(id=served_model_name,
                      max_model_len=self.max_model_len,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for served_model_name in self.served_model_names
        ]
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for lora in self.lora_requests
        ]
        prompt_adapter_cards = [
            ModelCard(id=prompt_adapter.prompt_adapter_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for prompt_adapter in self.prompt_adapter_requests
        ]
        model_cards.extend(lora_cards)
        model_cards.extend(prompt_adapter_cards)
        return ModelList(data=model_cards)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str
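
    # For reference, the streaming error helper returns a JSON string shaped
    # roughly like the following (the exact field set depends on the
    # ErrorResponse model in aphrodite.endpoints.openai.protocol; the values
    # shown are illustrative):
    #
    #   {"error": {"message": "The model `foo` does not exist.",
    #              "type": "NotFoundError", "code": 404}}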

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        # Only check the model name if this is not a Tokenize/Detokenize
        # request.
        if not isinstance(request, (TokenizeRequest, DetokenizeRequest)):
            if request.model in self.served_model_names:
                return None
            if request.model in [
                    lora.lora_name for lora in self.lora_requests
            ]:
                return None
            if request.model in [
                    prompt_adapter.prompt_adapter_name
                    for prompt_adapter in self.prompt_adapter_requests
            ]:
                return None
            return self.create_error_response(
                message=f"The model `{request.model}` does not exist.",
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
        if request.model in self.served_model_names:
            return None, None
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # If _check_model has been called earlier, this will be unreachable.
        raise ValueError(f"The model `{request.model}` does not exist.")
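
    # Resolution order sketch for _maybe_get_adapters (names are
    # hypothetical): a request for a served model name returns (None, None),
    # a request for a registered LoRA name such as "my-lora" returns
    # (lora_request, None), a registered prompt adapter name returns
    # (None, prompt_adapter_request), and anything else raises ValueError.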

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, EmbeddingRequest):
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for embedding "
                    f"generation. Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and don't require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        if request.max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
            request.max_tokens = self.max_model_len - token_num

        if token_num + request.max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
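
    # Worked example of the budget check above (numbers are hypothetical):
    # with max_model_len=4096, a 4000-token prompt with max_tokens=200 is
    # rejected because 4000 + 200 > 4096, while the same prompt with
    # max_tokens=None has max_tokens defaulted to 4096 - 4000 = 96.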

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes a single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_,
        each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        for prompt_input in parse_and_batch_prompt(input_or_inputs):
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly
            # "is True" is required for Pyright to perform type narrowing
            # See: https://github.com/microsoft/pyright/issues/7672
            if prompt_input["is_tokens"] is False:
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )
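
    # The accepted input shapes mirror the OpenAI API (values are
    # illustrative): a single string "hello", a list of strings
    # ["hello", "world"], a single token array [1, 2, 3], or a batch of
    # token arrays [[1, 2, 3], [4, 5]]; parse_and_batch_prompt normalizes
    # all four into a sequence of individual prompts.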

    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[str, List[int], TextTokensPrompt],
        params: Optional[Union[SamplingParams, PoolingParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
    ) -> str:
        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)