# serving_completions.py

import time
from typing import (
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
)

from fastapi import Request

from aphrodite.common.outputs import RequestOutput
from aphrodite.common.utils import merge_async_iterators, random_uuid
from aphrodite.endpoints.openai.protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    LogProbs,
    UsageInfo,
)
from aphrodite.endpoints.openai.serving_engine import LoRA, OpenAIServing
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.modeling.guided_decoding import (
    get_guided_decoding_logits_processor, )

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
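
# Note (added for readability, not in the upstream source): this alias
# presumably mirrors the positional signature used when calling
# `self._create_logprobs` below:
# (token_ids, top_logprobs, num_output_top_logprobs, initial_text_offset)
# -> LogProbs.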


def parse_prompt_format(prompt) -> Tuple[bool, list]:
    # get the prompt, openai supports the following
    # "a string, array of strings, array of tokens, or array of token arrays."
    prompt_is_tokens = False
    prompts = [prompt]  # case 1: a string
    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")
        elif isinstance(prompt[0], str):
            prompt_is_tokens = False
            prompts = prompt  # case 2: array of strings
        elif isinstance(prompt[0], int):
            prompt_is_tokens = True
            prompts = [prompt]  # case 3: array of tokens
        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
            prompt_is_tokens = True
            prompts = prompt  # case 4: array of token arrays
        else:
            raise ValueError("prompt must be a string, array of strings, "
                             "array of tokens, or array of token arrays")
    return prompt_is_tokens, prompts
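
# Illustrative examples (not part of the upstream source) of the four prompt
# shapes accepted above and what parse_prompt_format returns for each:
#   parse_prompt_format("Hello")            -> (False, ["Hello"])
#   parse_prompt_format(["Hi", "Bye"])      -> (False, ["Hi", "Bye"])
#   parse_prompt_format([1, 2, 3])          -> (True, [[1, 2, 3]])
#   parse_prompt_format([[1, 2], [3, 4]])   -> (True, [[1, 2], [3, 4]])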


class OpenAIServingCompletion(OpenAIServing):

    def __init__(self,
                 engine: AsyncAphrodite,
                 served_model_names: List[str],
                 lora_modules: Optional[List[LoRA]] = None):
        super().__init__(engine=engine,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules)

    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
              suffix)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = self.served_model_names[0]
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators = []
        try:
            sampling_params = request.to_sampling_params(
                self.tokenizer.vocab_size)
            lora_request = self._maybe_get_lora(request)
            decoding_config = self.engine.engine.decoding_config
            guided_decoding_backend = request.guided_decoding_backend \
                or decoding_config.guided_decoding_backend
            guided_decode_logit_processor = (
                await get_guided_decoding_logits_processor(
                    guided_decoding_backend, request,
                    await self.engine.get_tokenizer()))
            if guided_decode_logit_processor is not None:
                sampling_params.logits_processors.append(
                    guided_decode_logit_processor)
            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

            for i, prompt in enumerate(prompts):
                if prompt_is_tokens:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request,
                        prompt_ids=prompt,
                        truncate_prompt_tokens=sampling_params.
                        truncate_prompt_tokens)
                else:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request,
                        prompt=prompt,
                        truncate_prompt_tokens=sampling_params.
                        truncate_prompt_tokens)
                prompt_ids, prompt_text = prompt_formats

                generator = self.engine.generate(
                    {
                        "prompt": prompt_text,
                        "prompt_token_ids": prompt_ids
                    },
                    sampling_params,
                    f"{request_id}-{i}",
                    lora_request=lora_request,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, RequestOutput]] = merge_async_iterators(*generators)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when we use
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(request,
                                                    raw_request,
                                                    result_generator,
                                                    request_id,
                                                    created_time,
                                                    model_name,
                                                    num_prompts=len(prompts))

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects.
                    await self.engine.abort(f"{request_id}-{i}")
                    return self.create_error_response("Client disconnected")
                final_res_batch[i] = res
            response = self.request_output_to_completion_response(
                final_res_batch, request, request_id, created_time, model_name)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
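
    # Illustrative request body handled by create_completion above (field
    # names follow the OpenAI Completions API; values are made up):
    #   POST /v1/completions
    #   {"model": "my-model", "prompt": "Once upon a time", "max_tokens": 16,
    #    "stream": true, "echo": false, "logprobs": 2}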

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        raw_request: Request,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
    ) -> AsyncGenerator[str, None]:
        previous_texts = [""] * request.n * num_prompts
        previous_num_tokens = [0] * request.n * num_prompts
        has_echoed = [False] * request.n * num_prompts

        try:
            async for prompt_idx, res in result_generator:

                # Abort the request if the client disconnects.
                if await raw_request.is_disconnected():
                    await self.engine.abort(f"{request_id}-{prompt_idx}")
                    raise StopAsyncIteration()

                for output in res.outputs:
                    i = output.index + prompt_idx * request.n
                    # TODO: optimize the performance by avoiding full
                    # text O(n^2) sending.

                    if request.echo and request.max_tokens == 0:
                        # only return the prompt
                        delta_text = res.prompt
                        delta_token_ids = res.prompt_token_ids
                        top_logprobs = res.prompt_logprobs
                        has_echoed[i] = True
                    elif (request.echo and request.max_tokens > 0
                          and not has_echoed[i]):
                        # echo the prompt and first token
                        delta_text = res.prompt + output.text
                        delta_token_ids = (res.prompt_token_ids +
                                           output.token_ids)
                        top_logprobs = res.prompt_logprobs + (output.logprobs
                                                              or [])
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text[len(previous_texts[i]):]
                        delta_token_ids = output.token_ids[
                            previous_num_tokens[i]:]
                        top_logprobs = output.logprobs[previous_num_tokens[
                            i]:] if output.logprobs else None

                    if request.logprobs is not None:
                        logprobs = self._create_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=top_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    if output.finish_reason is not None:  # return final usage
                        prompt_tokens = len(res.prompt_token_ids)
                        completion_tokens = len(output.token_ids)
                        final_usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )
                    else:
                        final_usage = None

                    response_json = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ],
                        usage=final_usage,
                    ).model_dump_json(exclude_unset=True)
                    yield f"data: {response_json}\n\n"
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
    ) -> CompletionResponse:
        choices = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
        for final_res in final_res_batch:
            assert final_res is not None
            prompt_token_ids = final_res.prompt_token_ids
            prompt_logprobs = final_res.prompt_logprobs
            prompt_text = final_res.prompt

            for output in final_res.outputs:
                if request.echo and request.max_tokens == 0:
                    token_ids = prompt_token_ids
                    top_logprobs = prompt_logprobs
                    output_text = prompt_text
                elif request.echo and request.max_tokens > 0:
                    token_ids = prompt_token_ids + output.token_ids
                    top_logprobs = (prompt_logprobs + output.logprobs
                                    if request.logprobs is not None else None)
                    output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    top_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert top_logprobs is not None, (
                        "top_logprobs must be provided when logprobs "
                        "is requested")
                    logprobs = self._create_logprobs(
                        token_ids=token_ids,
                        top_logprobs=top_logprobs,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                )
                choices.append(choice_data)

            num_prompt_tokens += len(prompt_token_ids)
            num_generated_tokens += sum(
                len(output.token_ids) for output in final_res.outputs)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )
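
# Minimal usage sketch (hypothetical, not part of this module): one way the
# class above is typically wired into a FastAPI app. `engine_args`, `app`,
# and the route shown are assumptions for illustration only.
#
#   engine = AsyncAphrodite.from_engine_args(engine_args)
#   openai_serving_completion = OpenAIServingCompletion(
#       engine, served_model_names=["my-model"])
#
#   @app.post("/v1/completions")
#   async def create_completion(request: CompletionRequest,
#                               raw_request: Request):
#       # Returns a CompletionResponse, an async generator of SSE strings
#       # (when streaming), or an error response.
#       return await openai_serving_completion.create_completion(
#           request, raw_request)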