serving_completions.py

import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                    Optional)
from typing import Sequence as GenericSequence
from typing import Tuple, Union, cast

from fastapi import Request

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import merge_async_iterators, random_uuid
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.protocol import (
    CompletionLogProbs, CompletionRequest, CompletionResponse,
    CompletionResponseChoice, CompletionResponseStreamChoice,
    CompletionStreamResponse, ErrorResponse, UsageInfo)
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       OpenAIServing,
                                                       PromptAdapterPath)
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.transformers_utils.tokenizer import AnyTokenizer

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]


class OpenAIServingCompletion(OpenAIServing):

    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request,
    ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
              suffix)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = self.served_model_names[0]
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            guided_decode_logits_processor = (
                await self._guided_decode_logits_processor(request, tokenizer))

            prompts = list(
                self._tokenize_prompt_input_or_inputs(
                    request,
                    tokenizer,
                    request.prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                ))

            for i, prompt_inputs in enumerate(prompts):
                sampling_params = request.to_sampling_params(
                    tokenizer,
                    guided_decode_logits_processor,
                    default_max_tokens=self.max_model_len -
                    len(prompt_inputs["prompt_token_ids"]))

                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                generator = self.async_engine_client.generate(
                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                    sampling_params,
                    request_id_item,
                    lora_request=lora_request,
                    prompt_adapter_request=prompt_adapter_request,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(
            *generators, is_cancelled=raw_request.is_disconnected)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when using
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(request,
                                                    result_generator,
                                                    request_id,
                                                    created_time,
                                                    model_name,
                                                    num_prompts=len(prompts),
                                                    tokenizer=tokenizer)

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into Aphrodite engine to avoid being
                # redundant with the input token IDs
                if final_res.prompt is None:
                    final_res.prompt = prompts[i]["prompt"]

            final_res_batch_checked = cast(List[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        # When the user requests streaming but we don't stream, we still need
        # to return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
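
    # Illustrative sketch only (not part of the original module): a client-side
    # call that exercises create_completion() above. It assumes the surrounding
    # OpenAI-compatible API server is running and exposes the standard
    # /v1/completions route; base_url, port, and the served model name below
    # are placeholders.
    #
    #     from openai import OpenAI
    #
    #     client = OpenAI(base_url="http://localhost:<port>/v1",
    #                     api_key="EMPTY")
    #     completion = client.completions.create(
    #         model="<served-model-name>",
    #         prompt="Hello, my name is",
    #         max_tokens=8,
    #         logprobs=2,    # exercises _create_completion_logprobs()
    #         stream=False,  # stream=True returns completion_stream_generator()
    #     )
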
    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: AnyTokenizer,
    ) -> AsyncGenerator[str, None]:
        num_choices = 1 if request.n is None else request.n
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs
                prompt_text = res.prompt

                # Prompt details are excluded from later streamed outputs
                if res.prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: Optional[GenericSequence[Optional[Dict[
                    int, Logprob]]]]

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices
                    # TODO: optimize the performance by avoiding full
                    # text O(n^2) sending.

                    assert request.max_tokens is not None
                    if request.echo and request.max_tokens == 0:
                        assert prompt_token_ids is not None
                        assert prompt_text is not None
                        # only return the prompt
                        delta_text = prompt_text
                        delta_token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        has_echoed[i] = True
                    elif (request.echo and request.max_tokens > 0
                          and not has_echoed[i]):
                        assert prompt_token_ids is not None
                        assert prompt_text is not None
                        assert prompt_logprobs is not None
                        # echo the prompt and first token
                        delta_text = prompt_text + output.text
                        delta_token_ids = [
                            *prompt_token_ids, *output.token_ids
                        ]
                        out_logprobs = [
                            *prompt_logprobs,
                            *(output.logprobs or []),
                        ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                        )
                    else:
                        logprobs = None

                    previous_text_lens[i] += len(output.text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if (request.stream_options
                            and request.stream_options.include_usage):
                        if (request.stream_options.continuous_usage_stats
                                or output.finish_reason is not None):
                            prompt_tokens = num_prompt_tokens[prompt_idx]
                            completion_tokens = previous_num_tokens[i]
                            usage = UsageInfo(
                                prompt_tokens=prompt_tokens,
                                completion_tokens=completion_tokens,
                                total_tokens=prompt_tokens + completion_tokens,
                            )
                        if request.stream_options.continuous_usage_stats:
                            chunk.usage = usage
                        else:
                            chunk.usage = None

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            if (request.stream_options
                    and request.stream_options.include_usage):
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=usage,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"
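
    # Illustrative sketch only: each chunk yielded above is one Server-Sent
    # Events frame. Field values below are hypothetical; the exact field set
    # comes from the CompletionStreamResponse model.
    #
    #     data: {"id": "cmpl-...", "created": 1700000000, "model": "...",
    #            "choices": [{"index": 0, "text": " world", "logprobs": null,
    #                         "finish_reason": null, "stop_reason": null}]}
    #
    #     data: [DONE]
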
    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: AnyTokenizer,
    ) -> CompletionResponse:
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = final_res.prompt_logprobs
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: Optional[GenericSequence[Optional[Dict[int,
                                                                 Logprob]]]]

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo and request.max_tokens == 0:
                    assert prompt_text is not None
                    token_ids = prompt_token_ids
                    out_logprobs = prompt_logprobs
                    output_text = prompt_text
                elif request.echo and request.max_tokens > 0:
                    assert prompt_text is not None
                    token_ids = [*prompt_token_ids, *output.token_ids]

                    if request.logprobs is None:
                        out_logprobs = None
                    else:
                        assert prompt_logprobs is not None
                        assert output.logprobs is not None
                        out_logprobs = [
                            *prompt_logprobs,
                            *output.logprobs,
                        ]

                    output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )
    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: AnyTokenizer,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API."""
        out_text_offset: List[int] = []
        out_token_logprobs: List[Optional[float]] = []
        out_tokens: List[str] = []
        out_top_logprobs: List[Optional[Dict[str, float]]] = []

        last_token_len = 0

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=self.return_tokens_as_token_ids,
                )
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    self._get_decoded_token(
                        top_lp[1],
                        top_lp[0],
                        tokenizer,
                        return_as_token_id=self.return_tokens_as_token_ids):
                    max(top_lp[1].logprob, -9999.0)
                    for i, top_lp in enumerate(step_top_logprobs.items())
                    if num_output_top_logprobs >= i
                })

            if len(out_text_offset) == 0:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )
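
    # Illustrative sketch only: for two generated tokens decoding to "Hello"
    # and " world" with initial_text_offset=0, the parallel arrays built above
    # line up as follows (logprob values are hypothetical, clamped to
    # >= -9999.0, and each top_logprobs entry may hold up to
    # num_output_top_logprobs + 1 candidates):
    #
    #     text_offset    = [0, 5]
    #     tokens         = ["Hello", " world"]
    #     token_logprobs = [-0.12, -1.37]
    #     top_logprobs   = [{"Hello": -0.12, ",": -2.4},
    #                       {" world": -1.37, "!": -3.1}]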