serving_engine.py

import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union

from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated

from aphrodite.common.config import ModelConfig
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.endpoints.logger import RequestLogger
# yapf conflicts with isort here
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 EmbeddingRequest,
                                                 ErrorResponse, ModelCard,
                                                 ModelList, ModelPermission,
                                                 TokenizeChatRequest,
                                                 TokenizeCompletionRequest,
                                                 TokenizeRequest)
# yapf: enable
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.inputs import parse_and_batch_prompt
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.guided_decoding import (
    get_guided_decoding_logits_processor)
from aphrodite.prompt_adapter.request import PromptAdapterRequest


@dataclass
class PromptAdapterPath:
    name: str
    local_path: str


@dataclass
class LoRAModulePath:
    name: str
    path: str


AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
                   EmbeddingRequest, TokenizeRequest]

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


class TextTokensPrompt(TypedDict):
    prompt: str
    prompt_token_ids: List[int]


class OpenAIServing:

    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__()

        self.async_engine_client = async_engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.served_model_names = served_model_names

        self.lora_requests = []
        if lora_modules is not None:
            self.lora_requests = [
                LoRARequest(
                    lora_name=lora.name,
                    lora_int_id=i,
                    lora_path=lora.path,
                ) for i, lora in enumerate(lora_modules, start=1)
            ]

        self.prompt_adapter_requests = []
        if prompt_adapters is not None:
            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
                with pathlib.Path(prompt_adapter.local_path,
                                  "adapter_config.json").open() as f:
                    adapter_config = json.load(f)
                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
                    self.prompt_adapter_requests.append(
                        PromptAdapterRequest(
                            prompt_adapter_name=prompt_adapter.name,
                            prompt_adapter_id=i,
                            prompt_adapter_local_path=prompt_adapter.local_path,
                            prompt_adapter_num_virtual_tokens=num_virtual_tokens))

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
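
    # Illustrative sketch (not part of the original module): how a server
    # entrypoint might construct this class. `engine_client`, `model_config`,
    # and the model/adapter names are hypothetical placeholders obtained
    # elsewhere during startup.
    #
    #   serving = OpenAIServing(
    #       async_engine_client=engine_client,
    #       model_config=model_config,
    #       served_model_names=["my-base-model"],
    #       lora_modules=[LoRAModulePath(name="my-lora",
    #                                    path="/adapters/my-lora")],
    #       prompt_adapters=None,
    #       request_logger=None,
    #   )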

    async def show_available_models(self) -> ModelList:
        """Show available models. Right now we only have one model."""
        model_cards = [
            ModelCard(id=served_model_name,
                      max_model_len=self.max_model_len,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for served_model_name in self.served_model_names
        ]
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for lora in self.lora_requests
        ]
        prompt_adapter_cards = [
            ModelCard(id=prompt_adapter.prompt_adapter_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for prompt_adapter in self.prompt_adapter_requests
        ]
        model_cards.extend(lora_cards)
        model_cards.extend(prompt_adapter_cards)
        return ModelList(data=model_cards)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str
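
    # Rough shape of the streamed error string, assuming ErrorResponse
    # serializes at least its message/type/code fields (other fields may
    # also appear in the dump):
    #   {"error": {"message": "...", "type": "BadRequestError", "code": 400}}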

    async def _guided_decode_logits_processor(
            self, request: Union[ChatCompletionRequest, CompletionRequest],
            tokenizer: AnyTokenizer) -> Optional[LogitsProcessorFunc]:
        decoding_config = await self.async_engine_client.get_decoding_config()
        guided_decoding_backend = request.guided_decoding_backend \
            or decoding_config.guided_decoding_backend
        return await get_guided_decoding_logits_processor(
            guided_decoding_backend, request, tokenizer)

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        # Only check the model name if it's not a Tokenize/Detokenize request.
        if not isinstance(request, (TokenizeRequest, DetokenizeRequest)):
            if request.model in self.served_model_names:
                return None
            if request.model in [
                    lora.lora_name for lora in self.lora_requests
            ]:
                return None
            if request.model in [
                    prompt_adapter.prompt_adapter_name
                    for prompt_adapter in self.prompt_adapter_requests
            ]:
                return None
            return self.create_error_response(
                message=f"The model `{request.model}` does not exist.",
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
        if request.model in self.served_model_names:
            return None, None
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # If _check_model has been called earlier, this will be unreachable.
        raise ValueError(f"The model `{request.model}` does not exist.")
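
    # Return shapes, illustrated with hypothetical model names:
    #   request.model == "my-base-model"     -> (None, None)
    #   request.model == "my-lora"           -> (LoRARequest(...), None)
    #   request.model == "my-prompt-adapter" -> (None, PromptAdapterRequest(...))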

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)
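
    # Note that truncation keeps the *last* `truncate_prompt_tokens` ids.
    # Hypothetical example:
    #   prompt_ids = [1, 2, 3, 4, 5], truncate_prompt_tokens = 3
    #   -> input_ids = [3, 4, 5]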

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, EmbeddingRequest):
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for embedding "
                    f"generation. Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        if request.max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + request.max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
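
    # Worked example of the token budget above, with hypothetical numbers:
    #   max_model_len = 4096, prompt of 4000 tokens, max_tokens = 200
    #   -> 4000 + 200 = 4200 > 4096, so a ValueError is raised;
    #      with max_tokens = 96 the request fits exactly.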

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes a single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_,
        each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        for prompt_input in parse_and_batch_prompt(input_or_inputs):
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly.
            # "is True" is required for Pyright to perform type narrowing
            # See: https://github.com/microsoft/pyright/issues/7672
            if prompt_input["is_tokens"] is False:
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )
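
    # Input shapes accepted above (via parse_and_batch_prompt), shown with
    # hypothetical values; each entry yields one TextTokensPrompt:
    #   "a prompt"                   -> one tokenized text
    #   ["prompt one", "prompt two"] -> two tokenized texts
    #   [1, 2, 3]                    -> one pre-tokenized prompt
    #   [[1, 2, 3], [4, 5, 6]]       -> two pre-tokenized prompts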

    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[str, List[int], TextTokensPrompt],
        params: Optional[Union[SamplingParams, PoolingParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        if return_as_token_id:
            return f"token_id:{token_id}"

        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)
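
    # Illustrative behavior, assuming a Logprob whose decoded_token is "Hello"
    # and a hypothetical token_id of 9906:
    #   _get_decoded_token(lp, 9906, tokenizer)                          -> "Hello"
    #   _get_decoded_token(lp, 9906, tokenizer, return_as_token_id=True) -> "token_id:9906"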