serving_completions.py 14 KB


  1. import time
  2. from typing import (
  3. AsyncGenerator,
  4. AsyncIterator,
  5. Callable,
  6. Dict,
  7. List,
  8. Optional,
  9. Tuple,
  10. )
  11. from fastapi import Request
  12. from aphrodite.common.outputs import RequestOutput
  13. from aphrodite.common.utils import merge_async_iterators, random_uuid
  14. from aphrodite.endpoints.openai.protocol import (
  15. CompletionRequest,
  16. CompletionResponse,
  17. CompletionResponseChoice,
  18. CompletionResponseStreamChoice,
  19. CompletionStreamResponse,
  20. LogProbs,
  21. UsageInfo,
  22. )
  23. from aphrodite.endpoints.openai.serving_engine import LoRA, OpenAIServing
  24. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  25. from aphrodite.modeling.guided_decoding import (
  26. get_guided_decoding_logits_processor)
  # A flat sequence of integer token ids.
  TypeTokenIDs = List[int]

  # One entry per position; each entry is either None or a mapping of
  # int -> float (presumably token id -> log-probability; confirm against
  # the engine's output types).
  TypeTopLogProbs = List[Optional[Dict[int, float]]]

  # Signature of a callable that turns
  # (token ids, top logprobs, optional int, int) into a LogProbs object.
  TypeCreateLogProbsFn = Callable[
      [TypeTokenIDs, TypeTopLogProbs, Optional[int], int],
      LogProbs,
  ]
  def parse_prompt_format(prompt) -> Tuple[bool, list]:
      """Normalize an OpenAI-style ``prompt`` value into a list of prompts.

      OpenAI accepts "a string, array of strings, array of tokens, or array
      of token arrays". For list inputs the shape is decided by inspecting
      the first element only (and, for nested lists, the first item of that
      first element); the remaining elements are not validated here.

      Args:
          prompt: The raw ``prompt`` value from the request. Any non-list
              value is wrapped unchanged into a single-element list.

      Returns:
          A ``(prompt_is_tokens, prompts)`` tuple. ``prompt_is_tokens`` is
          True when the prompts are token-id lists rather than text, and
          ``prompts`` always holds one entry per individual prompt.

      Raises:
          ValueError: If ``prompt`` is an empty list, or a list whose first
              element is not a string, an int, or a non-empty list that
              starts with an int.
      """
      if not isinstance(prompt, list):
          # case 1: a string
          return False, [prompt]

      if len(prompt) == 0:
          raise ValueError("please provide at least one prompt")

      first = prompt[0]
      if isinstance(first, str):
          # case 2: array of strings
          return False, prompt
      if isinstance(first, int):
          # case 3: array of tokens
          return True, [prompt]
      # The emptiness check keeps ``[[]]`` from raising IndexError on
      # ``first[0]``; it falls through to the ValueError below instead.
      if isinstance(first, list) and first and isinstance(first[0], int):
          # case 4: array of token arrays
          return True, prompt

      raise ValueError("prompt must be a string, array of strings, "
                       "array of tokens, or array of token arrays")
  class OpenAIServingCompletion(OpenAIServing):
      """Handler for OpenAI-style text completion requests.

      Builds on ``OpenAIServing`` and adds the completion flow: scheduling
      one generation per prompt on the engine and formatting the results as
      either a single response object or a stream of server-sent events.
      """

      def __init__(self,
                   engine: AsyncAphrodite,
                   served_model: str,
                   lora_modules: Optional[List[LoRA]] = None):
          """Forward all arguments unchanged to ``OpenAIServing``."""
          super().__init__(engine=engine,
                           served_model=served_model,
                           lora_modules=lora_modules)

      async def create_completion(self, request: CompletionRequest,
                                  raw_request: Request):
          """Completion API similar to OpenAI's API.

          See https://platform.openai.com/docs/api-reference/completions/create
          for the API specification. This API mimics the OpenAI Completion API.

          NOTE: Currently we do not support the following feature:
              - suffix (the language models we currently support do not
                support suffix)

          Args:
              request: The parsed completion request.
              raw_request: The underlying HTTP request; only used to detect
                  client disconnects.

          Returns:
              One of:
              - an error response (model check failed, ``suffix`` was set,
                a ``ValueError`` was raised while preparing or formatting
                the request, or the client disconnected),
              - an async generator of ``"data: ...\\n\\n"`` strings when
                ``request.stream`` is set,
              - a ``CompletionResponse`` otherwise.
          """
          error_check_ret = await self._check_model(request)
          if error_check_ret is not None:
              return error_check_ret

          # Return error for unsupported features.
          if request.suffix is not None:
              return self.create_error_response(
                  "suffix is not currently supported")

          model_name = request.model
          request_id = f"cmpl-{random_uuid()}"
          created_time = int(time.time())

          # Schedule the request and get the result generator.
          generators = []
          try:
              sampling_params = request.to_sampling_params(
                  self.tokenizer.vocab_size)
              lora_request = self._maybe_get_lora(request)
              decoding_config = self.engine.engine.decoding_config
              # A backend named on the request wins over the engine default.
              guided_decoding_backend = (
                  request.guided_decoding_backend
                  or decoding_config.guided_decoding_backend)
              guided_decode_logit_processor = (
                  await get_guided_decoding_logits_processor(
                      guided_decoding_backend, request,
                      await self.engine.get_tokenizer()))
              if guided_decode_logit_processor is not None:
                  sampling_params.logits_processors.append(
                      guided_decode_logit_processor)

              prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

              # One engine request per prompt; the sub-request id is the
              # completion id suffixed with the prompt index.
              for i, prompt in enumerate(prompts):
                  if prompt_is_tokens:
                      prompt_formats = self._validate_prompt_and_tokenize(
                          request,
                          prompt_ids=prompt,
                          truncate_prompt_tokens=sampling_params.
                          truncate_prompt_tokens)
                  else:
                      prompt_formats = self._validate_prompt_and_tokenize(
                          request,
                          prompt=prompt,
                          truncate_prompt_tokens=sampling_params.
                          truncate_prompt_tokens)
                  prompt_ids, prompt_text = prompt_formats

                  generators.append(
                      self.engine.generate(prompt_text,
                                           sampling_params,
                                           f"{request_id}-{i}",
                                           prompt_token_ids=prompt_ids,
                                           lora_request=lora_request))
          except ValueError as e:
              # TODO: Use an aphrodite-specific Validation Error
              return self.create_error_response(str(e))

          # Each item is (index of the originating prompt, its output).
          result_generator: AsyncIterator[Tuple[
              int, RequestOutput]] = merge_async_iterators(*generators)

          # Similar to the OpenAI API, when n != best_of, we do not stream the
          # results. In addition, we do not stream the results when use
          # beam search.
          stream = (request.stream
                    and (request.best_of is None
                         or request.n == request.best_of)
                    and not request.use_beam_search)

          # Streaming response
          if stream:
              return self.completion_stream_generator(
                  request,
                  raw_request,
                  result_generator,
                  request_id,
                  created_time,
                  model_name,
                  num_prompts=len(prompts))

          # Non-streaming response: keep only the latest output per prompt.
          final_res_batch: List[Optional[RequestOutput]] = (
              [None] * len(prompts))
          try:
              async for i, res in result_generator:
                  if await raw_request.is_disconnected():
                      # Abort the request if the client disconnects.
                      await self.engine.abort(f"{request_id}-{i}")
                      return self.create_error_response(
                          "Client disconnected")
                  final_res_batch[i] = res
              response = self.request_output_to_completion_response(
                  final_res_batch, request, request_id, created_time,
                  model_name)
          except ValueError as e:
              # TODO: Use an aphrodite-specific Validation Error
              return self.create_error_response(str(e))

          # When user requests streaming but we don't stream, we still need to
          # return a streaming response with a single event.
          if request.stream:
              response_json = response.model_dump_json()

              async def fake_stream_generator() -> AsyncGenerator[str, None]:
                  yield f"data: {response_json}\n\n"
                  yield "data: [DONE]\n\n"

              return fake_stream_generator()

          return response

      async def completion_stream_generator(
          self,
          request: CompletionRequest,
          raw_request: Request,
          result_generator: AsyncIterator[Tuple[int, RequestOutput]],
          request_id: str,
          created_time: int,
          model_name: str,
          num_prompts: int,
      ) -> AsyncGenerator[str, None]:
          """Yield the completion as server-sent-event strings.

          Every yielded item has the form ``"data: <json>\\n\\n"``. One event
          is produced per output of every result received, each carrying
          only the text generated since the previous event for that choice.
          The stream ends with ``"data: [DONE]\\n\\n"``.

          If the client disconnects, the sub-request that produced the
          current result is aborted and the generator finishes immediately
          without emitting the ``[DONE]`` event.

          Args:
              request: The completion request being served.
              raw_request: HTTP request, polled for client disconnects.
              result_generator: Async iterator of
                  ``(prompt index, RequestOutput)`` pairs.
              request_id: Completion id; sub-requests are named
                  ``f"{request_id}-{prompt index}"``.
              created_time: Timestamp placed in every event.
              model_name: Model name placed in every event.
              num_prompts: Number of prompts in the request.

          Yields:
              Server-sent-event formatted strings.
          """
          # Per-choice state; a choice is one (prompt, sample) pair, so
          # there are request.n * num_prompts of them.
          num_choices = request.n * num_prompts
          previous_texts = [""] * num_choices
          previous_num_tokens = [0] * num_choices
          has_echoed = [False] * num_choices

          try:
              async for prompt_idx, res in result_generator:

                  # Abort the request if the client disconnects.
                  if await raw_request.is_disconnected():
                      await self.engine.abort(f"{request_id}-{prompt_idx}")
                      # Use a plain return: raising StopAsyncIteration from
                      # inside an async generator is turned into a
                      # RuntimeError by the interpreter.
                      return

                  for output in res.outputs:
                      # Flat choice index across all prompts.
                      i = output.index + prompt_idx * request.n
                      # TODO: optimize the performance by avoiding full
                      # text O(n^2) sending.

                      if request.echo and request.max_tokens == 0:
                          # only return the prompt
                          delta_text = res.prompt
                          delta_token_ids = res.prompt_token_ids
                          top_logprobs = res.prompt_logprobs
                          has_echoed[i] = True
                      elif (request.echo and request.max_tokens > 0
                            and not has_echoed[i]):
                          # echo the prompt and first token
                          delta_text = res.prompt + output.text
                          delta_token_ids = (res.prompt_token_ids +
                                             output.token_ids)
                          # Treat a missing logprobs list on either side as
                          # empty so the concatenation cannot fail on None.
                          top_logprobs = ((res.prompt_logprobs or []) +
                                          (output.logprobs or []))
                          has_echoed[i] = True
                      else:
                          # return just the delta
                          delta_text = output.text[len(previous_texts[i]):]
                          delta_token_ids = output.token_ids[
                              previous_num_tokens[i]:]
                          top_logprobs = output.logprobs[previous_num_tokens[
                              i]:] if output.logprobs else None

                      if request.logprobs is not None:
                          logprobs = self._create_logprobs(
                              token_ids=delta_token_ids,
                              top_logprobs=top_logprobs,
                              num_output_top_logprobs=request.logprobs,
                              initial_text_offset=len(previous_texts[i]),
                          )
                      else:
                          logprobs = None

                      previous_texts[i] = output.text
                      previous_num_tokens[i] = len(output.token_ids)
                      finish_reason = output.finish_reason
                      stop_reason = output.stop_reason
                      if output.finish_reason is not None:
                          # return final usage
                          prompt_tokens = len(res.prompt_token_ids)
                          completion_tokens = len(output.token_ids)
                          final_usage = UsageInfo(
                              prompt_tokens=prompt_tokens,
                              completion_tokens=completion_tokens,
                              total_tokens=prompt_tokens + completion_tokens,
                          )
                      else:
                          final_usage = None
                      response_json = CompletionStreamResponse(
                          id=request_id,
                          created=created_time,
                          model=model_name,
                          choices=[
                              CompletionResponseStreamChoice(
                                  index=i,
                                  text=delta_text,
                                  logprobs=logprobs,
                                  finish_reason=finish_reason,
                                  stop_reason=stop_reason,
                              )
                          ],
                          usage=final_usage,
                      ).model_dump_json(exclude_unset=True)
                      yield f"data: {response_json}\n\n"
          except ValueError as e:
              # TODO: Use an aphrodite-specific Validation Error
              data = self.create_streaming_error_response(str(e))
              yield f"data: {data}\n\n"
          yield "data: [DONE]\n\n"

      def request_output_to_completion_response(
          self,
          final_res_batch: List[RequestOutput],
          request: CompletionRequest,
          request_id: str,
          created_time: int,
          model_name: str,
      ) -> CompletionResponse:
          """Assemble the non-streaming response from finished outputs.

          Every output of every entry in ``final_res_batch`` becomes one
          choice, indexed in iteration order. Usage counts are summed over
          all entries.

          Args:
              final_res_batch: One final ``RequestOutput`` per prompt; no
                  entry may be None.
              request: The originating request; ``echo``, ``max_tokens``
                  and ``logprobs`` select what each choice contains.
              request_id: Id placed in the response.
              created_time: Timestamp placed in the response.
              model_name: Model name placed in the response.

          Returns:
              The populated ``CompletionResponse``.
          """
          choices = []
          num_prompt_tokens = 0
          num_generated_tokens = 0
          for final_res in final_res_batch:
              assert final_res is not None
              prompt_token_ids = final_res.prompt_token_ids
              prompt_logprobs = final_res.prompt_logprobs
              prompt_text = final_res.prompt

              for output in final_res.outputs:
                  if request.echo and request.max_tokens == 0:
                      # Prompt only.
                      token_ids = prompt_token_ids
                      top_logprobs = prompt_logprobs
                      output_text = prompt_text
                  elif request.echo and request.max_tokens > 0:
                      # Prompt followed by the generated text.
                      token_ids = prompt_token_ids + output.token_ids
                      # Compare against None, not truthiness: logprobs=0 is
                      # a valid request and must not yield None here, or
                      # the assertion below fires.
                      top_logprobs = (prompt_logprobs + output.logprobs
                                      if request.logprobs is not None
                                      else None)
                      output_text = prompt_text + output.text
                  else:
                      # Generated text only.
                      token_ids = output.token_ids
                      top_logprobs = output.logprobs
                      output_text = output.text

                  if request.logprobs is not None:
                      assert top_logprobs is not None, (
                          "top_logprobs must be provided when logprobs "
                          "is requested")
                      logprobs = self._create_logprobs(
                          token_ids=token_ids,
                          top_logprobs=top_logprobs,
                          num_output_top_logprobs=request.logprobs,
                      )
                  else:
                      logprobs = None

                  choice_data = CompletionResponseChoice(
                      index=len(choices),
                      text=output_text,
                      logprobs=logprobs,
                      finish_reason=output.finish_reason,
                      stop_reason=output.stop_reason,
                  )
                  choices.append(choice_data)

              num_prompt_tokens += len(prompt_token_ids)
              num_generated_tokens += sum(
                  len(output.token_ids) for output in final_res.outputs)

          usage = UsageInfo(
              prompt_tokens=num_prompt_tokens,
              completion_tokens=num_generated_tokens,
              total_tokens=num_prompt_tokens + num_generated_tokens,
          )

          return CompletionResponse(
              id=request_id,
              created=created_time,
              model=model_name,
              choices=choices,
              usage=usage,
          )
  311. )