# serving_engine.py

import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union

from loguru import logger
from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated

from aphrodite.common.config import ModelConfig
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.endpoints.logger import RequestLogger
# yapf conflicts with isort here
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 EmbeddingRequest,
                                                 ErrorResponse, ModelCard,
                                                 ModelList, ModelPermission,
                                                 TokenizeChatRequest,
                                                 TokenizeCompletionRequest,
                                                 TokenizeRequest)
# yapf: enable
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.inputs.parse import parse_and_batch_prompt
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.guided_decoding import (
    get_guided_decoding_logits_processor)
from aphrodite.prompt_adapter.request import PromptAdapterRequest


@dataclass
class PromptAdapterPath:
    name: str
    local_path: str


@dataclass
class LoRAModulePath:
    name: str
    path: str


AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
                   EmbeddingRequest, TokenizeRequest]

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


class TextTokensPrompt(TypedDict):
    prompt: str
    prompt_token_ids: List[int]


class OpenAIServing:

    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__()

        self.async_engine_client = async_engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.served_model_names = served_model_names

        self.lora_requests = []
        if lora_modules is not None:
            self.lora_requests = [
                LoRARequest(
                    lora_name=lora.name,
                    lora_int_id=i,
                    lora_path=lora.path,
                ) for i, lora in enumerate(lora_modules, start=1)
            ]

        self.prompt_adapter_requests = []
        if prompt_adapters is not None:
            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
                with pathlib.Path(prompt_adapter.local_path,
                                  "adapter_config.json").open() as f:
                    adapter_config = json.load(f)
                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
                self.prompt_adapter_requests.append(
                    PromptAdapterRequest(
                        prompt_adapter_name=prompt_adapter.name,
                        prompt_adapter_id=i,
                        prompt_adapter_local_path=prompt_adapter.local_path,
                        prompt_adapter_num_virtual_tokens=num_virtual_tokens))

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
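
    # A rough construction sketch; `engine_client`, `engine_model_config`, and
    # the adapter path below are placeholders, not names defined in this module:
    #
    #   serving = OpenAIServing(
    #       async_engine_client=engine_client,
    #       model_config=engine_model_config,
    #       served_model_names=["my-model"],
    #       lora_modules=[LoRAModulePath(name="sql-lora", path="/adapters/sql")],
    #       prompt_adapters=None,
    #       request_logger=None,
    #   )
    #
    # Each LoRA module is assigned a 1-based `lora_int_id` in the order it was
    # passed, and each prompt adapter reads `num_virtual_tokens` from its
    # `adapter_config.json`.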

    async def show_available_models(self) -> ModelList:
        """Show available models.

        The base served model names are listed first, followed by any
        registered LoRA modules and prompt adapters.
        """
        model_cards = [
            ModelCard(id=served_model_name,
                      max_model_len=self.max_model_len,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for served_model_name in self.served_model_names
        ]
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for lora in self.lora_requests
        ]
        prompt_adapter_cards = [
            ModelCard(id=prompt_adapter.prompt_adapter_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for prompt_adapter in self.prompt_adapter_requests
        ]
        model_cards.extend(lora_cards)
        model_cards.extend(prompt_adapter_cards)
        return ModelList(data=model_cards)
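
    # The returned list keeps a fixed order: base served model names first,
    # then LoRA modules, then prompt adapters, with `root` always pointing at
    # the first served model name. A sketch of the shape (IDs are examples,
    # only a subset of fields shown):
    #
    #   ModelList(data=[
    #       ModelCard(id="my-model", root="my-model", ...),
    #       ModelCard(id="sql-lora", root="my-model", ...),
    #   ])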

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str
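
    # The streaming variant returns a plain JSON string so it can be forwarded
    # directly as an SSE data frame. With the defaults above the payload looks
    # roughly like this (the dumped ErrorResponse may carry extra fields):
    #
    #   {"error": {"message": "...", "type": "BadRequestError", "code": 400}}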

    async def _guided_decode_logits_processor(
            self, request: Union[ChatCompletionRequest, CompletionRequest],
            tokenizer: AnyTokenizer) -> Optional[LogitsProcessorFunc]:
        decoding_config = await self.async_engine_client.get_decoding_config()
        guided_decoding_backend = request.guided_decoding_backend \
            or decoding_config.guided_decoding_backend
        return await get_guided_decoding_logits_processor(
            guided_decoding_backend, request, tokenizer)

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        # Only check the model name if this is not a Tokenize/Detokenize
        # request; those are accepted regardless of the requested model.
        if not isinstance(request, (TokenizeRequest, DetokenizeRequest)):
            if request.model in self.served_model_names:
                return None
            if request.model in [
                    lora.lora_name for lora in self.lora_requests
            ]:
                return None
            if request.model in [
                    prompt_adapter.prompt_adapter_name
                    for prompt_adapter in self.prompt_adapter_requests
            ]:
                return None
            return self.create_error_response(
                message=f"The model `{request.model}` does not exist.",
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
        if request.model in self.served_model_names:
            return None, None
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # If _check_model has been called earlier, this will be unreachable.
        raise ValueError(f"The model `{request.model}` does not exist.")
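
    # Resolution order used above: the base model maps to (None, None), a
    # matching LoRA name to (lora_request, None), and a matching prompt
    # adapter name to (None, prompt_adapter_request). For example, a request
    # whose `model` is "sql-lora" (assuming such a LoRA was registered)
    # yields (LoRARequest(lora_name="sql-lora", ...), None).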

    def add_lora(self, lora: LoRAModulePath):
        if lora.name in [
                existing.lora_name for existing in self.lora_requests
        ]:
            logger.error(f"LoRA module {lora.name} already exists.")
            return
        self.lora_requests.append(
            LoRARequest(
                lora_name=lora.name,
                lora_int_id=len(self.lora_requests) + 1,
                lora_path=lora.path,
            ))

    def remove_lora(self, lora_name: str):
        self.lora_requests = [
            lora for lora in self.lora_requests if lora.lora_name != lora_name
        ]

    def add_prompt_adapter(self, prompt_adapter: PromptAdapterPath):
        if prompt_adapter.name in [
                existing.prompt_adapter_name
                for existing in self.prompt_adapter_requests
        ]:
            logger.error(
                f"Prompt adapter {prompt_adapter.name} already exists.")
            return
        with pathlib.Path(prompt_adapter.local_path,
                          "adapter_config.json").open() as f:
            adapter_config = json.load(f)
            num_virtual_tokens = adapter_config["num_virtual_tokens"]
        self.prompt_adapter_requests.append(
            PromptAdapterRequest(
                prompt_adapter_name=prompt_adapter.name,
                prompt_adapter_id=len(self.prompt_adapter_requests) + 1,
                prompt_adapter_local_path=prompt_adapter.local_path,
                prompt_adapter_num_virtual_tokens=num_virtual_tokens))

    def remove_prompt_adapter(self, prompt_adapter_name: str):
        self.prompt_adapter_requests = [
            prompt_adapter for prompt_adapter in self.prompt_adapter_requests
            if prompt_adapter.prompt_adapter_name != prompt_adapter_name
        ]

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)
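
    # Truncation of token prompts keeps the *tail* of the sequence, e.g. with
    # truncate_prompt_tokens=3:
    #
    #   prompt_ids = [1, 2, 3, 4, 5]
    #   prompt_ids[-3:]  # -> [3, 4, 5]
    #
    # Text prompts instead rely on the tokenizer's own `truncation=True` /
    # `max_length=` handling, which with the default truncation side keeps
    # the head of the text.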

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, EmbeddingRequest):
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for embedding "
                    f"generation. Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and don't require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        if request.max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + request.max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
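
    # Worked example of the length check: with max_model_len=4096, a prompt of
    # 4000 tokens plus max_tokens=200 asks for 4200 tokens and is rejected,
    # while max_tokens=96 (4000 + 96 == 4096) is still accepted.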

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes a single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to the `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_,
        each input can be a string or an array of tokens. Note that each
        request can pass one or more inputs.
        """
        for prompt_input in parse_and_batch_prompt(input_or_inputs):
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly.
            # "is True" is required for Pyright to perform type narrowing
            # See: https://github.com/microsoft/pyright/issues/7672
            if prompt_input["is_tokens"] is False:
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )
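
    # All four OpenAI-style prompt shapes are accepted by the method above;
    # parse_and_batch_prompt normalizes each entry into a dict with
    # "is_tokens" and "content" keys. A sketch of the accepted inputs:
    #
    #   "a single string"
    #   ["a batch", "of strings"]
    #   [1, 2, 3]                    # a single token array
    #   [[1, 2, 3], [4, 5, 6]]       # a batch of token arrays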

    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[str, List[int], TextTokensPrompt],
        params: Optional[Union[SamplingParams, PoolingParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        if return_as_token_id:
            return f"token_id:{token_id}"

        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)
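
    # Usage sketch for the helper above (`tok` stands for any loaded tokenizer,
    # and the sketch assumes Logprob only needs the logprob value and leaves
    # `decoded_token` unset):
    #
    #   OpenAIServing._get_decoded_token(Logprob(logprob=-0.25), token_id=42,
    #                                    tokenizer=tok, return_as_token_id=True)
    #   # -> "token_id:42"
    #
    # With return_as_token_id=False the cached `decoded_token` is preferred and
    # `tokenizer.decode(42)` is only used as a fallback.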