serving_completions.py

import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                    Optional)
from typing import Sequence as GenericSequence
from typing import Tuple, cast

from fastapi import Request
from transformers import PreTrainedTokenizer

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import merge_async_iterators, random_uuid
from aphrodite.endpoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from aphrodite.endpoints.openai.protocol import (
    CompletionLogProbs, CompletionRequest, CompletionResponse,
    CompletionResponseChoice, CompletionResponseStreamChoice,
    CompletionStreamResponse, UsageInfo)
# yapf: enable
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                        OpenAIServing,
                                                        PromptAdapterPath)
from aphrodite.engine.protocol import AsyncEngineClient

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]


class OpenAIServingCompletion(OpenAIServing):

    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
              suffix)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = self.served_model_names[0]
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.time())

        if request.prompt_logprobs is not None:
            if request.stream and request.prompt_logprobs > 0:
                return self.create_error_response(
                    "Prompt_logprobs are not available when stream is enabled")
            elif request.prompt_logprobs < 0:
                return self.create_error_response(
                    f"Prompt_logprobs set to invalid negative "
                    f"value: {request.prompt_logprobs}")

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            guided_decode_logits_processor = (
                await self._guided_decode_logits_processor(request, tokenizer))

            prompts = list(
                self._tokenize_prompt_input_or_inputs(
                    request,
                    tokenizer,
                    request.prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                ))

            for i, prompt_inputs in enumerate(prompts):
                sampling_params = request.to_sampling_params(
                    tokenizer,
                    guided_decode_logits_processor,
                    default_max_tokens=self.max_model_len -
                    len(prompt_inputs["prompt_token_ids"]))

                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                generator = self.async_engine_client.generate(
                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                    sampling_params,
                    request_id_item,
                    lora_request=lora_request,
                    prompt_adapter_request=prompt_adapter_request,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, RequestOutput]] = merge_async_iterators(
                *generators, is_cancelled=raw_request.is_disconnected)
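        # `result_generator` interleaves outputs from every per-prompt
        # generator, yielding (prompt_index, RequestOutput) pairs so the
        # consumer knows which prompt each update belongs to.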
        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when we use
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(request,
                                                     result_generator,
                                                     request_id,
                                                     created_time,
                                                     model_name,
                                                     num_prompts=len(prompts),
                                                     tokenizer=tokenizer)

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text.
                # We did not pass it into the Aphrodite engine to avoid being
                # redundant with the input token IDs.
                if final_res.prompt is None:
                    final_res.prompt = prompts[i]["prompt"]

            final_res_batch_checked = cast(List[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        # When the user requests streaming but we don't stream, we still need
        # to return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: PreTrainedTokenizer,
    ) -> AsyncGenerator[str, None]:
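        """Stream completion results as Server-Sent Events.

        Each yielded string is a single ``data: <json chunk>`` frame followed
        by a blank line; the stream is terminated with ``data: [DONE]``.
        """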
        num_choices = 1 if request.n is None else request.n
        previous_texts = [""] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
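        # Per-choice streaming state, flattened across prompts: the entry for
        # choice `output.index` of prompt `prompt_idx` lives at index
        # `output.index + prompt_idx * num_choices`.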
        try:
            async for prompt_idx, res in result_generator:

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices
                    # TODO: optimize the performance by avoiding full
                    # text O(n^2) sending.

                    assert request.max_tokens is not None
                    if request.echo and request.max_tokens == 0:
                        # only return the prompt
                        delta_text = res.prompt
                        delta_token_ids = res.prompt_token_ids
                        out_logprobs = res.prompt_logprobs
                        has_echoed[i] = True
                    elif (request.echo and request.max_tokens > 0
                          and not has_echoed[i]):
                        # echo the prompt and first token
                        delta_text = res.prompt + output.text
                        delta_token_ids = (res.prompt_token_ids +
                                           output.token_ids)
                        out_logprobs = res.prompt_logprobs + (output.logprobs
                                                              or [])
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text[len(previous_texts[i]):]
                        delta_token_ids = output.token_ids[
                            previous_num_tokens[i]:]
                        out_logprobs = output.logprobs[previous_num_tokens[
                            i]:] if output.logprobs else None

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
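                    # Attach token usage to this chunk only when the client
                    # opted in via stream_options.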
                    if (request.stream_options
                            and request.stream_options.include_usage):
                        if (request.stream_options.continuous_usage_stats
                                or output.finish_reason is not None):
                            prompt_tokens = len(res.prompt_token_ids)
                            completion_tokens = len(output.token_ids)
                            usage = UsageInfo(
                                prompt_tokens=prompt_tokens,
                                completion_tokens=completion_tokens,
                                total_tokens=prompt_tokens + completion_tokens,
                            )
                        if request.stream_options.continuous_usage_stats:
                            chunk.usage = usage
                        else:
                            chunk.usage = None

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"
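            # Once every choice has finished, optionally emit one final chunk
            # that carries only the aggregated usage statistics.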
            if (request.stream_options
                    and request.stream_options.include_usage):
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=usage,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: PreTrainedTokenizer,
    ) -> CompletionResponse:
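        """Convert the finished ``RequestOutput`` batch into a single
        ``CompletionResponse`` with one choice per prompt/output pair and
        aggregated token usage."""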
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            prompt_logprobs = final_res.prompt_logprobs
            prompt_text = final_res.prompt

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo and request.max_tokens == 0:
                    token_ids = prompt_token_ids
                    out_logprobs = prompt_logprobs
                    output_text = prompt_text
                elif request.echo and request.max_tokens > 0:
                    token_ids = prompt_token_ids + list(output.token_ids)
                    out_logprobs = (prompt_logprobs + output.logprobs
                                    if request.logprobs is not None else None)
                    output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                )
                choices.append(choice_data)

            num_prompt_tokens += len(prompt_token_ids)
            num_generated_tokens += sum(
                len(output.token_ids) for output in final_res.outputs)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: PreTrainedTokenizer,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API."""
        out_text_offset: List[int] = []
        out_token_logprobs: List[Optional[float]] = []
        out_tokens: List[str] = []
        out_top_logprobs: List[Optional[Dict[str, float]]] = []

        last_token_len = 0
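        # `text_offset[i]` is the character position at which token i starts
        # in the generated text; each offset advances by the length of the
        # previously decoded token.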
        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                token = self._get_decoded_token(
                    step_top_logprobs[token_id],
                    token_id,
                    tokenizer,
                    return_as_token_id=self.return_tokens_as_token_ids)
                token_logprob = max(step_top_logprobs[token_id].logprob,
                                    -9999.0)
                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    self._get_decoded_token(
                        top_lp[1],
                        top_lp[0],
                        tokenizer,
                        return_as_token_id=self.return_tokens_as_token_ids):
                    max(top_lp[1].logprob, -9999.0)
                    for i, top_lp in enumerate(step_top_logprobs.items())
                    if num_output_top_logprobs >= i
                })

            if len(out_text_offset) == 0:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )
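# ---------------------------------------------------------------------------
# Illustrative client-side sketch (not part of this module): one way to
# exercise the Completions endpoint served by the class above, assuming an
# Aphrodite OpenAI-compatible server is already running and the `openai`
# Python client is installed. The base URL, port, API key, and model name
# below are assumptions for the example, not values defined in this file.
#
#     from openai import OpenAI
#
#     client = OpenAI(base_url="http://localhost:2242/v1", api_key="EMPTY")
#     completion = client.completions.create(
#         model="my-served-model",          # one of `served_model_names`
#         prompt="The capital of France is",
#         max_tokens=16,
#         logprobs=2,      # rendered by _create_completion_logprobs above
#         stream=False,
#     )
#     print(completion.choices[0].text)
# ---------------------------------------------------------------------------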