# serving_completions.py

import asyncio
import time
from fastapi import Request
from typing import (AsyncGenerator, AsyncIterator, Callable, List, Optional,
                    Dict, Tuple)

from aphrodite.common.utils import random_uuid
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.endpoints.openai.protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    LogProbs,
    UsageInfo,
)
from aphrodite.common.outputs import RequestOutput
from aphrodite.endpoints.openai.serving_engine import OpenAIServing, LoRA
from aphrodite.modeling.outlines_decoding import (
    get_guided_decoding_logits_processor)

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]


def parse_prompt_format(prompt) -> Tuple[bool, list]:
    # get the prompt, openai supports the following:
    # "a string, array of strings, array of tokens, or array of token arrays."
    prompt_is_tokens = False
    prompts = [prompt]  # case 1: a string
    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")
        elif isinstance(prompt[0], str):
            prompt_is_tokens = False
            prompts = prompt  # case 2: array of strings
        elif isinstance(prompt[0], int):
            prompt_is_tokens = True
            prompts = [prompt]  # case 3: array of tokens
        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
            prompt_is_tokens = True
            prompts = prompt  # case 4: array of token arrays
        else:
            raise ValueError(
                "prompt must be a string, array of strings, array of tokens, "
                "or array of token arrays")
    return prompt_is_tokens, prompts
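

# For illustration only, parse_prompt_format normalizes the four accepted
# prompt shapes as follows (example values):
#   parse_prompt_format("Hello")            -> (False, ["Hello"])
#   parse_prompt_format(["Hello", "Hi"])    -> (False, ["Hello", "Hi"])
#   parse_prompt_format([1, 2, 3])          -> (True, [[1, 2, 3]])
#   parse_prompt_format([[1, 2], [3, 4]])   -> (True, [[1, 2], [3, 4]])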


def merge_async_iterators(*iterators):
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yields the item.
    """
    queue = asyncio.Queue()

    finished = [False] * len(iterators)

    async def producer(i, iterator):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        finished[i] = True

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        while not all(finished) or not queue.empty():
            item = await queue.get()
            if isinstance(item, Exception):
                raise item
            yield item
        await asyncio.gather(*_tasks)

    return consumer()
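

# Usage sketch (illustrative): create_completion below builds one
# engine.generate(...) iterator per prompt and fans them in through
# merge_async_iterators, so the consumer receives interleaved
# (prompt_index, RequestOutput) pairs as soon as any prompt makes progress:
#
#   merged = merge_async_iterators(gen_for_prompt_0, gen_for_prompt_1)
#   async for prompt_index, request_output in merged:
#       ...  # prompt_index identifies which prompt produced this output
#
# gen_for_prompt_0 / gen_for_prompt_1 are hypothetical names standing in for
# the per-prompt generators.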


class OpenAIServingCompletion(OpenAIServing):

    def __init__(self,
                 engine: AsyncAphrodite,
                 served_model: str,
                 lora_modules: Optional[List[LoRA]] = None):
        super().__init__(engine=engine,
                         served_model=served_model,
                         lora_modules=lora_modules)

    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
              suffix)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = request.model
        request_id = f"cmpl-{random_uuid()}"
        # The OpenAI `created` field is a Unix timestamp, so use time.time()
        # rather than time.monotonic().
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators = []
        try:
            sampling_params = request.to_sampling_params(
                self.tokenizer.vocab_size)
            lora_request = self._maybe_get_lora(request)
            guided_decode_logit_processor = (
                await get_guided_decoding_logits_processor(
                    request, self.engine.get_tokenizer()))
            if guided_decode_logit_processor is not None:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logit_processor)
            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

            for i, prompt in enumerate(prompts):
                if prompt_is_tokens:
                    input_ids = self._validate_prompt_and_tokenize(
                        request, prompt_ids=prompt)
                else:
                    input_ids = self._validate_prompt_and_tokenize(
                        request, prompt=prompt)

                generators.append(
                    self.engine.generate(prompt,
                                         sampling_params,
                                         f"{request_id}-{i}",
                                         prompt_token_ids=input_ids,
                                         lora_request=lora_request))
        except ValueError as e:
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, RequestOutput]] = merge_async_iterators(*generators)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when using beam
        # search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(request,
                                                    raw_request,
                                                    result_generator,
                                                    request_id,
                                                    created_time,
                                                    model_name,
                                                    num_prompts=len(prompts))

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects.
                    await self.engine.abort(f"{request_id}-{i}")
                    return self.create_error_response("Client disconnected")
                final_res_batch[i] = res
            response = self.request_output_to_completion_response(
                final_res_batch, request, request_id, created_time, model_name)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
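
    # completion_stream_generator below serves the streaming path chosen in
    # create_completion: request.stream is set, best_of is unset or equal to
    # n, and beam search is not in use. Every other request is answered via
    # request_output_to_completion_response (wrapped by fake_stream_generator
    # when request.stream was set despite the non-streaming path).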

    async def completion_stream_generator(
            self,
            request: CompletionRequest,
            raw_request: Request,
            result_generator: AsyncIterator[Tuple[int, RequestOutput]],
            request_id: str,
            created_time: int,
            model_name: str,
            num_prompts: int,
    ) -> AsyncGenerator[str, None]:
        previous_texts = [""] * request.n * num_prompts
        previous_num_tokens = [0] * request.n * num_prompts
        has_echoed = [False] * request.n * num_prompts

        try:
            async for prompt_idx, res in result_generator:

                # Abort the request if the client disconnects.
                if await raw_request.is_disconnected():
                    await self.engine.abort(f"{request_id}-{prompt_idx}")
                    raise StopAsyncIteration()

                for output in res.outputs:
                    i = output.index + prompt_idx * request.n
                    # TODO: optimize the performance by avoiding full text
                    # O(n^2) sending.

                    if request.echo and request.max_tokens == 0:
                        # only return the prompt
                        delta_text = res.prompt
                        delta_token_ids = res.prompt_token_ids
                        top_logprobs = res.prompt_logprobs
                        has_echoed[i] = True
                    elif (request.echo and request.max_tokens > 0
                          and not has_echoed[i]):
                        # echo the prompt and first token
                        delta_text = res.prompt + output.text
                        delta_token_ids = (res.prompt_token_ids +
                                           output.token_ids)
                        top_logprobs = res.prompt_logprobs + (output.logprobs
                                                              or [])
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text[len(previous_texts[i]):]
                        delta_token_ids = output.token_ids[
                            previous_num_tokens[i]:]
                        top_logprobs = output.logprobs[
                            previous_num_tokens[i]:] if output.logprobs \
                            else None

                    if request.logprobs is not None:
                        assert top_logprobs is not None, (
                            "top_logprobs must be provided when logprobs "
                            "is requested")
                        logprobs = self._create_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=top_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    finish_reason = output.finish_reason
                    response_json = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                            )
                        ]).model_dump_json()
                    yield f"data: {response_json}\n\n"

                    if output.finish_reason is not None:  # return final usage
                        logprobs = (LogProbs()
                                    if request.logprobs is not None else None)
                        prompt_tokens = len(res.prompt_token_ids)
                        completion_tokens = len(output.token_ids)
                        final_usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )
                        response_json = CompletionStreamResponse(
                            id=request_id,
                            created=created_time,
                            model=model_name,
                            choices=[
                                CompletionResponseStreamChoice(
                                    index=i,
                                    text="",
                                    logprobs=logprobs,
                                    finish_reason=output.finish_reason,
                                )
                            ],
                            usage=final_usage,
                        ).model_dump_json()
                        yield f"data: {response_json}\n\n"
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
            self,
            final_res_batch: List[RequestOutput],
            request: CompletionRequest,
            request_id: str,
            created_time: int,
            model_name: str,
    ) -> CompletionResponse:
        choices = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
        for final_res in final_res_batch:
            assert final_res is not None
            prompt_token_ids = final_res.prompt_token_ids
            prompt_logprobs = final_res.prompt_logprobs
            prompt_text = final_res.prompt

            for output in final_res.outputs:
                if request.echo and request.max_tokens == 0:
                    token_ids = prompt_token_ids
                    top_logprobs = prompt_logprobs
                    output_text = prompt_text
                elif request.echo and request.max_tokens > 0:
                    token_ids = prompt_token_ids + output.token_ids
                    top_logprobs = prompt_logprobs + output.logprobs
                    output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    top_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    logprobs = self._create_logprobs(
                        token_ids=token_ids,
                        top_logprobs=top_logprobs,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                )
                choices.append(choice_data)

            num_prompt_tokens += len(prompt_token_ids)
            num_generated_tokens += sum(
                len(output.token_ids) for output in final_res.outputs)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )
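

# Wire-format sketch (illustrative): completion_stream_generator and
# fake_stream_generator both emit Server-Sent Events, one JSON-serialized
# response per event (a CompletionStreamResponse for true streaming, the full
# CompletionResponse in the fake-stream fallback), terminated by a [DONE]
# sentinel:
#
#   data: {"id": "cmpl-...", "created": ..., "model": "...", "choices": [...]}
#
#   data: [DONE]
#
# Field values above are placeholders; exact payloads come from
# model_dump_json() on the response objects and depend on the request.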