async_aphrodite.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215
  1. import asyncio
  2. import time
  3. from functools import partial
  4. from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
  5. Optional, Set, Tuple, Type, Union)
  6. from loguru import logger
  7. from typing_extensions import assert_never
  8. import aphrodite.common.envs as envs
  9. from aphrodite.common.config import (DecodingConfig, EngineConfig, LoRAConfig,
  10. ModelConfig, ParallelConfig,
  11. SchedulerConfig)
  12. from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
  13. from aphrodite.common.pooling_params import PoolingParams
  14. from aphrodite.common.sampling_params import SamplingParams
  15. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  16. from aphrodite.common.utils import print_warning_once
  17. from aphrodite.engine.aphrodite_engine import (AphroditeEngine,
  18. DecoderPromptComponents,
  19. PromptComponents,
  20. SchedulerOutputState)
  21. from aphrodite.engine.args_tools import AsyncEngineArgs
  22. from aphrodite.engine.async_timeout import asyncio_timeout
  23. from aphrodite.engine.metrics_types import StatLoggerBase
  24. from aphrodite.executor.executor_base import ExecutorAsyncBase
  25. from aphrodite.executor.ray_utils import initialize_ray_cluster, ray
  26. from aphrodite.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
  27. SingletonPromptInputs)
  28. from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
  29. from aphrodite.lora.request import LoRARequest
  30. from aphrodite.processing.scheduler import SchedulerOutputs
  31. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  32. from aphrodite.transformers_utils.tokenizer import AnyTokenizer
# Per-iteration engine timeout taken from `aphrodite.common.envs`
# (presumably in seconds, per the `_S` suffix; its consumer is outside this
# view — TODO confirm it is used with `asyncio_timeout`).
ENGINE_ITERATION_TIMEOUT_S = envs.APHRODITE_ENGINE_ITERATION_TIMEOUT_S
  34. class AsyncEngineDeadError(RuntimeError):
  35. pass
  36. def _log_task_completion(task: asyncio.Task,
  37. error_callback: Callable[[Exception], None]) -> None:
  38. """This function is only intended for the `engine.run_engine_loop()` task.
  39. In particular, that task runs a `while True` loop that can only exit if
  40. there is an exception.
  41. """
  42. exception = None
  43. try:
  44. return_value = task.result()
  45. raise AssertionError(
  46. f"The engine background task should never finish without an "
  47. f"exception. {return_value}")
  48. except asyncio.exceptions.CancelledError:
  49. # We assume that if the task is cancelled, we are gracefully shutting
  50. # down. This should only happen on program exit.
  51. logger.info("Engine is gracefully shutting down.")
  52. except Exception as e:
  53. exception = e
  54. logger.error("Engine background task failed", exc_info=e)
  55. error_callback(exception)
  56. raise AsyncEngineDeadError(
  57. "Task finished unexpectedly. This should never happen! "
  58. "Please open an issue on Github. See stack trace above for the "
  59. "actual cause.") from e
# Sentinel queued by `AsyncStream.finish()` when no exception is supplied;
# `AsyncStream.generator()` returns (ends iteration) when it dequeues it.
STOP_ITERATION = Exception()  # Sentinel
  61. class AsyncStream:
  62. """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
  63. that can be iterated over asynchronously via an async generator."""
  64. def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
  65. self.request_id = request_id
  66. self._cancel = cancel
  67. self._queue: asyncio.Queue = asyncio.Queue()
  68. self._finished = False
  69. def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
  70. Exception]) -> None:
  71. if not self._finished:
  72. self._queue.put_nowait(item)
  73. def finish(
  74. self,
  75. exception: Optional[Union[BaseException, Type[BaseException]]] = None,
  76. ) -> None:
  77. if not self._finished:
  78. self._finished = True
  79. self._queue.put_nowait(
  80. exception if self._is_raisable(exception) else STOP_ITERATION)
  81. @property
  82. def finished(self) -> bool:
  83. return self._finished
  84. async def generator(
  85. self
  86. ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
  87. try:
  88. while True:
  89. result = await self._queue.get()
  90. if self._is_raisable(result):
  91. if result == STOP_ITERATION:
  92. return
  93. raise result
  94. yield result
  95. except GeneratorExit:
  96. self._cancel(self.request_id)
  97. raise asyncio.CancelledError from None
  98. @staticmethod
  99. def _is_raisable(value: Any):
  100. return isinstance(value, BaseException) or \
  101. (isinstance(value, type) and \
  102. issubclass(value, BaseException))
  103. class RequestTracker:
  104. """Synchronous abstraction for tracking requests."""
  105. def __init__(self) -> None:
  106. self._request_streams: Dict[str, AsyncStream] = {}
  107. self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
  108. self._new_requests: asyncio.Queue[Tuple[AsyncStream,
  109. dict]] = asyncio.Queue()
  110. self.new_requests_event = asyncio.Event()
  111. def __contains__(self, item):
  112. return item in self._request_streams
  113. def __len__(self) -> int:
  114. return len(self._request_streams)
  115. def propagate_exception(self,
  116. exc: Exception,
  117. request_id: Optional[str] = None) -> None:
  118. """Propagate an exception to request streams
  119. (all if request_id is None)."""
  120. if request_id is not None:
  121. self.abort_request(request_id, exception=exc)
  122. else:
  123. # NB: tuple() used here because self.abort_request pops the stream
  124. # out of self._request_streams, so we can't iterate on it directly
  125. for rid in tuple(self._request_streams.keys()):
  126. self.abort_request(rid, exception=exc)
  127. def process_request_output(self,
  128. request_output: Union[RequestOutput,
  129. EmbeddingRequestOutput],
  130. *,
  131. verbose: bool = False) -> None:
  132. """Process a request output from the engine."""
  133. request_id = request_output.request_id
  134. finished = request_output.finished
  135. if finished:
  136. stream = self._request_streams.pop(request_id, None)
  137. else:
  138. stream = self._request_streams.get(request_id)
  139. # Guard against a KeyError which can occur if the request was aborted
  140. # while the output was generated
  141. if stream is not None:
  142. stream.put(request_output)
  143. if finished:
  144. stream.finish()
  145. if verbose and finished:
  146. logger.info(f"Finished request {request_id}.")
  147. def process_exception(self,
  148. request_id: str,
  149. exception: BaseException,
  150. *,
  151. verbose: bool = False) -> None:
  152. """Propagate an exception from the engine."""
  153. if verbose:
  154. logger.info(f"Finished request {request_id}.")
  155. self.abort_request(request_id, exception=exception)
  156. def add_request(self,
  157. request_id: str,
  158. *,
  159. verbose: bool = False,
  160. **engine_add_request_kwargs) -> AsyncStream:
  161. """Add a request to be sent to the engine on the next background
  162. loop iteration."""
  163. if request_id in self._request_streams:
  164. raise KeyError(f"Request {request_id} already exists.")
  165. abort_request = partial(self.abort_request, verbose=verbose)
  166. stream = AsyncStream(request_id, abort_request)
  167. self._new_requests.put_nowait((stream, {
  168. "request_id": request_id,
  169. **engine_add_request_kwargs
  170. }))
  171. self.new_requests_event.set()
  172. if verbose:
  173. logger.info(f"Added request {request_id}.")
  174. return stream
  175. def abort_request(self,
  176. request_id: str,
  177. *,
  178. exception: Optional[Union[BaseException,
  179. Type[BaseException]]] = None,
  180. verbose: bool = False) -> None:
  181. """Abort a request during next background loop iteration."""
  182. if verbose:
  183. logger.info(f"Aborted request {request_id}.")
  184. self._aborted_requests.put_nowait(request_id)
  185. stream = self._request_streams.pop(request_id, None)
  186. if stream is not None:
  187. stream.finish(exception=exception)
  188. def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
  189. """Get the new requests and finished requests to be
  190. sent to the engine."""
  191. new_requests: List[Dict] = []
  192. finished_requests: Set[str] = set()
  193. while not self._aborted_requests.empty():
  194. request_id = self._aborted_requests.get_nowait()
  195. finished_requests.add(request_id)
  196. while not self._new_requests.empty():
  197. stream, new_request = self._new_requests.get_nowait()
  198. request_id = stream.request_id
  199. if request_id in finished_requests:
  200. # The request has already been aborted.
  201. stream.finish(asyncio.CancelledError)
  202. finished_requests.discard(request_id)
  203. else:
  204. self._request_streams[request_id] = stream
  205. new_requests.append(new_request)
  206. return new_requests, finished_requests
  207. async def wait_for_new_requests(self):
  208. if not self.has_new_requests():
  209. await self.new_requests_event.wait()
  210. self.new_requests_event.clear()
  211. def has_new_requests(self):
  212. return not self._new_requests.empty()
class _AsyncAphrodite(AphroditeEngine):
    """Extension of AphroditeEngine to add async methods.

    Provides coroutine counterparts of the stepping, prompt-processing and
    request-adding paths of :class:`AphroditeEngine`; the model executor and
    the tokenizer group are awaited instead of being called synchronously.
    """

    def __init__(self, *args, **kwargs):
        # All construction is delegated to AphroditeEngine.
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting the scheduler
                (``self.scheduler[virtual_engine]``) and the cached scheduler
                state used by this step.

        Returns:
            ``self.request_outputs``. It is reset to ``[]`` while the current
            batch still has remaining steps; otherwise it is presumably
            populated by ``_process_model_outputs`` (defined in the base
            engine, not visible here — confirm).
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            # If current scheduler iteration has no async postprocessor,
            # then we need first to drain the pending async postprocessor
            # before moving forward
            if not allow_async_output_proc and len(self.output_queue) > 0:
                self._process_model_outputs(is_async=True)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None
        # Multi-step scheduling and async output processing must never be
        # active in the same step.
        assert not (self.scheduler_config.is_multi_step and \
            allow_async_output_proc)

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                # Hand the executor a callback for processing queued outputs.
                execute_model_req.output_proc_callback_fn = \
                    self._process_model_outputs

            # Execute the model.
            output = await self.model_executor.execute_model_async(
                execute_model_req)
            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, output)
        else:
            # The scheduler produced an empty batch: only flush outputs that
            # are still pending in the output queue.
            if len(self.output_queue) > 0:
                assert not self.scheduler_config.is_multi_step
                self._process_model_outputs(is_async=True)
            output = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # Cache results in engine
            self.output_queue.append(
                (output, seq_group_metadata_list, scheduler_outputs))

            if output and allow_async_output_proc:
                assert len(
                    output
                ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
                self._advance_to_next_step(
                    output[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(is_async=False)

                # Log stats.
                self.do_log_stats(scheduler_outputs, output)
        else:
            # Steps remain for this batch, so no outputs are returned yet.
            self.request_outputs = []

        return self.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`.

        Encodes ``prompt`` with the engine's tokenizer group, forwarding
        ``request_id`` and ``lora_request``, and returns the token ids.
        ``get_tokenizer_group`` receives a ``missing_msg`` for the case where
        tokenizer initialization was skipped (presumably it raises with that
        message — confirm in AphroditeEngine).
        """
        tokenizer = self.get_tokenizer_group(
            missing_msg="prompts must be None if skip_tokenizer_init is True")

        return await tokenizer.encode_async(request_id=request_id,
                                            prompt=prompt,
                                            lora_request=lora_request)

    async def _extract_prompt_components_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Async version of :meth:`_extract_prompt_components`.

        Returns a ``(prompt, prompt_token_ids, multi_modal_data)`` tuple:

        * ``str`` input: the text is tokenized and there is no multimodal data;
        * ``dict`` with ``"prompt_token_ids"``: those ids are used as-is and
          ``prompt`` is ``None``;
        * other ``dict``: ``inputs["prompt"]`` is tokenized.

        For dict inputs, ``multi_modal_data`` is ``inputs.get("multi_modal_data")``.
        """
        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif isinstance(inputs, dict):
            if "prompt_token_ids" in inputs:
                prompt = None
                prompt_token_ids = inputs["prompt_token_ids"]
            else:
                # NOTE: This extra assignment is required to pass mypy
                prompt = parsed_prompt = inputs["prompt"]
                prompt_token_ids = await self._tokenize_prompt_async(
                    parsed_prompt,
                    request_id=request_id,
                    lora_request=lora_request,
                )

            multi_modal_data = inputs.get("multi_modal_data")
        else:
            assert_never(inputs)

        return prompt, prompt_token_ids, multi_modal_data

    async def _process_encoder_decoder_prompt_async(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`.

        For an explicit encoder/decoder prompt, the encoder and decoder parts
        are extracted concurrently (``asyncio.gather``); a ``None`` decoder
        prompt yields ``(None, None, None)`` decoder components. Any other
        prompt is used as the encoder input with empty decoder components.
        """
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(inputs):
            encoder_task = self._extract_prompt_components_async(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := inputs["decoder_prompt"]) is None:
                encoder_comps = await encoder_task
                decoder_comps = None, None, None
            else:
                decoder_task = self._extract_prompt_components_async(
                    decoder_input,
                    request_id=request_id,
                )

                encoder_comps, decoder_comps = await asyncio.gather(
                    encoder_task, decoder_task)
        else:
            encoder_comps = await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
            )

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    async def _process_decoder_only_prompt_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Async version of :meth:`_process_decoder_only_prompt`."""
        prompt_comps = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def process_model_inputs_async(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Async version of :meth:`process_model_inputs`.

        Dispatches on whether the model is encoder-decoder and passes the
        result through ``self.input_processor``.

        Raises:
            ValueError: if an explicit encoder-decoder prompt is given to a
                decoder-only model.
        """
        if self.is_encoder_decoder_model():
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder
            model_inputs = await self._process_encoder_decoder_prompt_async(
                inputs,
                request_id=request_id,
            )
        else:
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation
            model_inputs = await self._process_decoder_only_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

        return self.input_processor(model_inputs)

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Async version of :meth:`add_request`.

        ``arrival_time`` defaults to ``time.time()``. The inputs are processed
        asynchronously, then handed to ``_add_processed_request``.

        Raises:
            ValueError: if ``lora_request`` is given while LoRA is disabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = await self.process_model_inputs_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer (if any) and model-executor health checks.

        Both checks are called synchronously; nothing is awaited here.
        """
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
class AsyncAphrodite:
    """An asynchronous wrapper for :class:`AphroditeEngine`.

    This class is used to wrap the :class:`AphroditeEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`AphroditeEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`AphroditeEngine` to the caller.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to create the AphroditeEngine as a Ray actor
            instead of in the current process. Deprecated: construction
            raises ``ValueError`` unless ``APHRODITE_ALLOW_ENGINE_USE_RAY``
            is set.
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for :class:`AphroditeEngine`.
        **kwargs: Arguments for :class:`AphroditeEngine`.
    """

    # Engine implementation instantiated by `_init_engine` (optionally wrapped
    # as a Ray actor); read via `self`, so subclasses can override it.
    _engine_class: Type[_AsyncAphrodite] = _AsyncAphrodite
  493. def __init__(self,
  494. worker_use_ray: bool,
  495. engine_use_ray: bool,
  496. *args,
  497. log_requests: bool = True,
  498. start_engine_loop: bool = True,
  499. **kwargs) -> None:
  500. self.worker_use_ray = worker_use_ray
  501. self.engine_use_ray = engine_use_ray
  502. self.log_requests = log_requests
  503. self.engine = self._init_engine(*args, **kwargs)
  504. if self.engine_use_ray:
  505. print_warning_once(
  506. "DEPRECATED. `--engine-use-ray` is deprecated and will "
  507. "be removed in a future update.")
  508. if envs.APHRODITE_ALLOW_ENGINE_USE_RAY:
  509. print_warning_once(
  510. "APHRODITE_ALLOW_ENGINE_USE_RAY is set, "
  511. "force engine use Ray")
  512. else:
  513. raise ValueError("`--engine-use-ray` is deprecated. "
  514. "Set `APHRODITE_ALLOW_ENGINE_USE_RAY=1` "
  515. "to force use it")
  516. self.background_loop: Optional[asyncio.Future] = None
  517. # We need to keep a reference to unshielded
  518. # task as well to prevent it from being garbage
  519. # collected
  520. self._background_loop_unshielded: Optional[asyncio.Task] = None
  521. self.start_engine_loop = start_engine_loop
  522. self._errored_with: Optional[BaseException] = None
  523. # Lazy initialized fields
  524. self._request_tracker: RequestTracker
  525. @classmethod
  526. def _get_executor_cls(
  527. cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
  528. distributed_executor_backend = (
  529. engine_config.parallel_config.distributed_executor_backend)
  530. if isinstance(distributed_executor_backend, type):
  531. if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
  532. raise TypeError(
  533. "distributed_executor_backend must be a subclass of "
  534. f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
  535. if distributed_executor_backend.uses_ray: # type: ignore
  536. initialize_ray_cluster(engine_config.parallel_config)
  537. executor_class = distributed_executor_backend
  538. elif engine_config.device_config.device_type == "neuron":
  539. from aphrodite.executor.neuron_executor import NeuronExecutorAsync
  540. executor_class = NeuronExecutorAsync
  541. elif engine_config.device_config.device_type == "tpu":
  542. if distributed_executor_backend == "ray":
  543. initialize_ray_cluster(engine_config.parallel_config)
  544. from aphrodite.executor.ray_tpu_executor import (
  545. RayTPUExecutorAsync)
  546. executor_class = RayTPUExecutorAsync
  547. else:
  548. assert distributed_executor_backend is None
  549. from aphrodite.executor.tpu_executor import TPUExecutorAsync
  550. executor_class = TPUExecutorAsync
  551. elif engine_config.device_config.device_type == "cpu":
  552. from aphrodite.executor.cpu_executor import CPUExecutorAsync
  553. executor_class = CPUExecutorAsync
  554. elif engine_config.device_config.device_type == "openvino":
  555. assert distributed_executor_backend is None, (
  556. "Distributed execution is not supported with "
  557. "the OpenVINO backend.")
  558. from aphrodite.executor.openvino_executor import (
  559. OpenVINOExecutorAsync)
  560. executor_class = OpenVINOExecutorAsync
  561. elif engine_config.device_config.device_type == "xpu":
  562. if distributed_executor_backend is None:
  563. from aphrodite.executor.xpu_executor import XPUExecutorAsync
  564. executor_class = XPUExecutorAsync
  565. elif distributed_executor_backend == "ray":
  566. initialize_ray_cluster(engine_config.parallel_config)
  567. from aphrodite.executor.ray_xpu_executor import (
  568. RayXPUExecutorAsync)
  569. executor_class = RayXPUExecutorAsync
  570. else:
  571. raise RuntimeError(
  572. "Not supported distributed execution model on XPU device.")
  573. elif distributed_executor_backend == "ray":
  574. initialize_ray_cluster(engine_config.parallel_config)
  575. from aphrodite.executor.ray_gpu_executor import RayGPUExecutorAsync
  576. executor_class = RayGPUExecutorAsync
  577. elif distributed_executor_backend == "mp":
  578. from aphrodite.executor.multiproc_gpu_executor import (
  579. MultiprocessingGPUExecutorAsync)
  580. executor_class = MultiprocessingGPUExecutorAsync
  581. else:
  582. from aphrodite.executor.gpu_executor import GPUExecutorAsync
  583. executor_class = GPUExecutorAsync
  584. return executor_class
  585. @classmethod
  586. def from_engine_args(
  587. cls,
  588. engine_args: AsyncEngineArgs,
  589. start_engine_loop: bool = True,
  590. stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
  591. ) -> "AsyncAphrodite":
  592. """Creates an async LLM engine from the engine arguments."""
  593. # Create the engine configs.
  594. engine_config = engine_args.create_engine_config()
  595. if engine_args.engine_use_ray:
  596. from aphrodite.executor import ray_utils
  597. ray_utils.assert_ray_available()
  598. executor_class = cls._get_executor_cls(engine_config)
  599. # Create the async LLM engine.
  600. engine = cls(
  601. executor_class.uses_ray,
  602. engine_args.engine_use_ray,
  603. **engine_config.to_dict(),
  604. executor_class=executor_class,
  605. log_requests=not engine_args.disable_log_requests,
  606. log_stats=not engine_args.disable_log_stats,
  607. start_engine_loop=start_engine_loop,
  608. stat_loggers=stat_loggers,
  609. )
  610. return engine
  611. @property
  612. def is_running(self) -> bool:
  613. return (self.background_loop is not None
  614. and self._background_loop_unshielded is not None
  615. and not self._background_loop_unshielded.done())
  616. @property
  617. def is_stopped(self) -> bool:
  618. return self.errored or (self.background_loop is not None and
  619. self._background_loop_unshielded is not None
  620. and self._background_loop_unshielded.done())
  621. @property
  622. def errored(self) -> bool:
  623. return self._errored_with is not None
  624. @property
  625. def limit_concurrency(self) -> Optional[int]:
  626. """Maximum number of concurrently running requests."""
  627. return None
  628. def set_errored(self, exc: Exception) -> None:
  629. self._errored_with = exc
  630. def _error_callback(self, exc: Exception) -> None:
  631. self.set_errored(exc)
  632. self._request_tracker.propagate_exception(exc)
  633. async def get_tokenizer(
  634. self,
  635. lora_request: Optional[LoRARequest] = None,
  636. ) -> AnyTokenizer:
  637. if self.engine_use_ray:
  638. return await self.engine.get_tokenizer.remote( # type: ignore
  639. lora_request)
  640. return await (self.engine.get_tokenizer_group().
  641. get_lora_tokenizer_async(lora_request))
  642. def start_background_loop(self) -> None:
  643. """Start the background loop."""
  644. if self.errored:
  645. raise AsyncEngineDeadError(
  646. "Background loop has errored already.") from self._errored_with
  647. if self.is_running:
  648. raise RuntimeError("Background loop is already running.")
  649. # Initialize the RequestTracker here so it uses the right event loop.
  650. self._request_tracker = RequestTracker()
  651. self._background_loop_unshielded = asyncio.get_event_loop(
  652. ).create_task(self.run_engine_loop())
  653. self._background_loop_unshielded.add_done_callback(
  654. partial(_log_task_completion, error_callback=self._error_callback))
  655. self.background_loop = asyncio.shield(self._background_loop_unshielded)
  656. def shutdown_background_loop(self) -> None:
  657. """
  658. Shut down the background loop.
  659. This method needs to be called during cleanup to remove
  660. references to `self` and properly GC the resources held
  661. by the async LLM engine (e.g., the executors as well as
  662. their resources).
  663. """
  664. if self._background_loop_unshielded is not None:
  665. self._background_loop_unshielded.cancel()
  666. self._background_loop_unshielded = None
  667. self.background_loop = None
  668. def _init_engine(self, *args,
  669. **kwargs) -> Union[_AsyncAphrodite, "ray.ObjectRef"]:
  670. if not self.engine_use_ray:
  671. engine_class = self._engine_class
  672. elif self.worker_use_ray:
  673. engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
  674. else:
  675. # FIXME: This is a bit hacky. Be careful when changing the
  676. # order of the arguments.
  677. cache_config = kwargs["cache_config"]
  678. parallel_config = kwargs["parallel_config"]
  679. if (parallel_config.tensor_parallel_size == 1
  680. and parallel_config.pipeline_parallel_size == 1):
  681. num_gpus = cache_config.gpu_memory_utilization
  682. else:
  683. num_gpus = 1
  684. engine_class = ray.remote(num_gpus=num_gpus)(
  685. self._engine_class).remote
  686. return engine_class(*args, **kwargs)
  687. async def engine_step(self, virtual_engine: int) -> bool:
  688. """Kick the engine to process the waiting requests.
  689. Returns True if there are in-progress requests."""
  690. new_requests, aborted_requests = (
  691. self._request_tracker.get_new_and_aborted_requests())
  692. for new_request in new_requests:
  693. # Add the request into the Aphrodite engine's waiting queue.
  694. # TODO: Maybe add add_request_batch to reduce Ray overhead
  695. try:
  696. if self.engine_use_ray:
  697. await self.engine.add_request.remote( # type: ignore
  698. **new_request)
  699. else:
  700. await self.engine.add_request_async(**new_request)
  701. except ValueError as e:
  702. # TODO: use an Aphrodite specific error for failed validation
  703. self._request_tracker.process_exception(
  704. new_request["request_id"],
  705. e,
  706. verbose=self.log_requests,
  707. )
  708. if aborted_requests:
  709. await self._engine_abort(aborted_requests)
  710. if self.engine_use_ray:
  711. request_outputs = await self.engine.step.remote() # type: ignore
  712. else:
  713. request_outputs = await self.engine.step_async(virtual_engine)
  714. # Put the outputs into the corresponding streams.
  715. finished = True
  716. for request_output in request_outputs:
  717. self._request_tracker.process_request_output(
  718. request_output, verbose=self.log_requests)
  719. finished = finished and request_output.finished
  720. return not finished
  721. async def _engine_abort(self, request_ids: Iterable[str]):
  722. if self.engine_use_ray:
  723. await self.engine.abort_request.remote(request_ids) # type: ignore
  724. else:
  725. self.engine.abort_request(request_ids)
async def run_engine_loop(self):
    """Drive ``engine_step`` for every virtual engine until cancelled or failed.

    One ``engine_step`` task is kept per virtual engine (pipeline-parallel
    stage); with ``engine_use_ray`` there is exactly one. When no virtual
    engine has work, remote worker execution is stopped and the loop blocks
    until the request tracker reports new requests.

    Raises:
        asyncio.TimeoutError: If waiting for a step exceeds
            ``ENGINE_ITERATION_TIMEOUT_S``. The engine is marked errored via
            ``set_errored`` before re-raising.
        Any exception raised by an ``engine_step`` task, re-raised by
        ``task.result()``.
    """
    if self.engine_use_ray:
        pipeline_parallel_size = 1  # type: ignore
    else:
        pipeline_parallel_size = \
            self.engine.parallel_config.pipeline_parallel_size
    # One flag per virtual engine: True while its engine_step task should be
    # (re)scheduled.
    has_requests_in_progress = [False] * pipeline_parallel_size
    while True:
        if not any(has_requests_in_progress):
            logger.debug("Waiting for new requests...")
            # Stop the execute model loop in parallel workers until there
            # are more requests to process. This avoids waiting
            # indefinitely in torch.distributed ops which may otherwise
            # timeout, and unblocks the RPC thread in the workers so that
            # they can process any other queued control plane messages,
            # such as add/remove lora adapters.
            if self.engine_use_ray:
                await (self.engine.stop_remote_worker_execution_loop.
                       remote()  # type: ignore
                       )
            else:
                await self.engine.stop_remote_worker_execution_loop_async()
            await self._request_tracker.wait_for_new_requests()
            logger.debug("Got new requests!")
            # Index i of this list always holds the latest task for
            # virtual engine i; the index lookup below relies on that.
            requests_in_progress = [
                asyncio.create_task(self.engine_step(ve))
                for ve in range(pipeline_parallel_size)
            ]
            has_requests_in_progress = [True] * pipeline_parallel_size

        # Abort if iteration takes too long due to unrecoverable errors
        # (eg. NCCL timeouts).
        try:
            async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                # NOTE: a virtual engine marked idle keeps its already-finished
                # task in requests_in_progress, so this wait returns at once
                # while other engines are still busy. That re-polls the idle
                # engine's has_unfinished_requests below (restarting it if work
                # was assigned to it). The loop then spins, yielding only via
                # asyncio.sleep(0).
                done, _ = await asyncio.wait(
                    requests_in_progress,
                    return_when=asyncio.FIRST_COMPLETED)
                # Yield to the event loop once per virtual engine.
                for _ in range(pipeline_parallel_size):
                    await asyncio.sleep(0)
            for task in done:
                # Re-raises any exception from engine_step, ending this loop.
                result = task.result()
                virtual_engine = requests_in_progress.index(task)
                if self.engine_use_ray:
                    has_unfinished_requests = (
                        await (self.engine.
                               has_unfinished_requests_for_virtual_engine.
                               remote(  # type: ignore
                                   virtual_engine)))
                else:
                    has_unfinished_requests = (
                        self.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                # Reschedule the step if its last outputs were unfinished or
                # the engine still reports pending work for this stage.
                if result or has_unfinished_requests:
                    requests_in_progress[virtual_engine] = (
                        asyncio.create_task(
                            self.engine_step(virtual_engine)))
                    has_requests_in_progress[virtual_engine] = True
                else:
                    has_requests_in_progress[virtual_engine] = False
        except asyncio.TimeoutError as exc:
            logger.error(
                "Engine iteration timed out. This should never happen!")
            self.set_errored(exc)
            raise
        await asyncio.sleep(0)
  791. # This method does not need to be async, but kept that way
  792. # for backwards compatibility.
  793. async def add_request(
  794. self,
  795. request_id: str,
  796. inputs: PromptInputs,
  797. params: Union[SamplingParams, PoolingParams],
  798. arrival_time: Optional[float] = None,
  799. lora_request: Optional[LoRARequest] = None,
  800. prompt_adapter_request: Optional[PromptAdapterRequest] = None
  801. ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
  802. if not self.is_running:
  803. if self.start_engine_loop:
  804. self.start_background_loop()
  805. else:
  806. raise AsyncEngineDeadError(
  807. "Background loop is not running. If it was running, "
  808. "inspect the output to find the stacktrace of the "
  809. "error that caused the background loop to stop "
  810. "(AsyncEngineDeadError).")
  811. stream = self._request_tracker.add_request(
  812. request_id,
  813. verbose=self.log_requests,
  814. inputs=inputs,
  815. params=params,
  816. arrival_time=arrival_time or time.time(),
  817. lora_request=lora_request,
  818. prompt_adapter_request=prompt_adapter_request)
  819. return stream.generator()
  820. async def generate(
  821. self,
  822. inputs: PromptInputs,
  823. sampling_params: SamplingParams,
  824. request_id: str,
  825. lora_request: Optional[LoRARequest] = None,
  826. prompt_adapter_request: Optional[PromptAdapterRequest] = None
  827. ) -> AsyncGenerator[RequestOutput, None]:
  828. """Generate outputs for a request.
  829. Generate outputs for a request. This method is a coroutine. It adds the
  830. request into the waiting queue of the AphroditeEngine and streams the
  831. outputs from the AphroditeEngine to the caller.
  832. Args:
  833. inputs: The inputs to the LLM. See
  834. :class:`~aphrodite.inputs.PromptInputs`
  835. for more details about the format of each input.
  836. sampling_params: The sampling parameters of the request.
  837. request_id: The unique id of the request.
  838. lora_request: LoRA request to use for generation, if any.
  839. prompt_adapter_request: Prompt Adapter request to use
  840. for generation, if any.
  841. Yields:
  842. The output `RequestOutput` objects from the AphroditeEngine
  843. for the request.
  844. Details:
  845. - If the engine is not running, start the background loop,
  846. which iteratively invokes
  847. # pylint: disable=line-too-long
  848. :meth:`~aphrodite.engine.async_aphrodite.AsyncAphrodite.engine_step`
  849. to process the waiting requests.
  850. - Add the request to the engine's `RequestTracker`.
  851. On the next background loop, this request will be sent to
  852. the underlying engine.
  853. Also, a corresponding `AsyncStream` will be created.
  854. - Wait for the request outputs from `AsyncStream` and yield them.
  855. Example:
  856. >>> # Please refer to entrypoints/api_server.py for
  857. >>> # the complete example.
  858. >>>
  859. >>> # initialize the engine and the example input
  860. >>> engine = AsyncAphrodite.from_engine_args(engine_args)
  861. >>> example_input = {
  862. >>> "prompt": "What is LLM?",
  863. >>> "stream": False, # assume the non-streaming case
  864. >>> "temperature": 0.0,
  865. >>> "request_id": 0,
  866. >>> }
  867. >>>
  868. >>> # start the generation
  869. >>> results_generator = engine.generate(
  870. >>> example_input["prompt"],
  871. >>> SamplingParams(temperature=example_input["temperature"]),
  872. >>> example_input["request_id"])
  873. >>>
  874. >>> # get the results
  875. >>> final_output = None
  876. >>> async for request_output in results_generator:
  877. >>> if await request.is_disconnected():
  878. >>> # Abort the request if the client disconnects.
  879. >>> await engine.abort(request_id)
  880. >>> # Return or raise an error
  881. >>> ...
  882. >>> final_output = request_output
  883. >>>
  884. >>> # Process and return the final output
  885. >>> ...
  886. """
  887. async for output in await self.add_request(
  888. request_id,
  889. inputs,
  890. sampling_params,
  891. lora_request=lora_request,
  892. prompt_adapter_request=prompt_adapter_request,
  893. ):
  894. yield AphroditeEngine.validate_output(output, RequestOutput)
  895. async def encode(
  896. self,
  897. inputs: PromptInputs,
  898. pooling_params: PoolingParams,
  899. request_id: str,
  900. lora_request: Optional[LoRARequest] = None,
  901. ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
  902. """Generate outputs for a request from an embedding model.
  903. Generate outputs for a request. This method is a coroutine. It adds the
  904. request into the waiting queue of the AphroditeEngine and streams the
  905. outputs from the AphroditeEngine to the caller.
  906. Args:
  907. inputs: The inputs to the LLM. See
  908. :class:`~aphrodite.inputs.PromptInputs`
  909. for more details about the format of each input.
  910. pooling_params: The pooling parameters of the request.
  911. request_id: The unique id of the request.
  912. lora_request: LoRA request to use for generation, if any.
  913. Yields:
  914. The output `EmbeddingRequestOutput` objects from the AphroditeEngine
  915. for the request.
  916. Details:
  917. - If the engine is not running, start the background loop,
  918. which iteratively invokes
  919. :meth:`~aphrodite.engine.async_aphrodite.AsyncAphrodite.engine_step`
  920. to process the waiting requests.
  921. - Add the request to the engine's `RequestTracker`.
  922. On the next background loop, this request will be sent to
  923. the underlying engine.
  924. Also, a corresponding `AsyncStream` will be created.
  925. - Wait for the request outputs from `AsyncStream` and yield them.
  926. Example:
  927. >>> # Please refer to endpoints/api_server.py for
  928. >>> # the complete example.
  929. >>>
  930. >>> # initialize the engine and the example input
  931. >>> engine = AsyncAphrodite.from_engine_args(engine_args)
  932. >>> example_input = {
  933. >>> "input": "What is LLM?",
  934. >>> "request_id": 0,
  935. >>> }
  936. >>>
  937. >>> # start the generation
  938. >>> results_generator = engine.encode(
  939. >>> example_input["input"],
  940. >>> PoolingParams(),
  941. >>> example_input["request_id"])
  942. >>>
  943. >>> # get the results
  944. >>> final_output = None
  945. >>> async for request_output in results_generator:
  946. >>> if await request.is_disconnected():
  947. >>> # Abort the request if the client disconnects.
  948. >>> await engine.abort(request_id)
  949. >>> # Return or raise an error
  950. >>> ...
  951. >>> final_output = request_output
  952. >>>
  953. >>> # Process and return the final output
  954. >>> ...
  955. """
  956. async for output in await self.add_request(
  957. request_id,
  958. inputs,
  959. pooling_params,
  960. lora_request=lora_request,
  961. ):
  962. yield AphroditeEngine.validate_output(output,
  963. EmbeddingRequestOutput)
  964. async def abort(self, request_id: str) -> None:
  965. """Abort a request.
  966. Abort a submitted request. If the request is finished or not found,
  967. this method will be a no-op.
  968. Args:
  969. request_id: The unique id of the request.
  970. """
  971. if not self.is_running:
  972. raise AsyncEngineDeadError(
  973. "Background loop is not running. If it was running, "
  974. "inspect the output to find the stacktrace of the "
  975. "error that caused the background loop to stop "
  976. "(AsyncEngineDeadError).")
  977. return self._abort(request_id)
  978. def _abort(self, request_id: str) -> None:
  979. """Abort a request.
  980. Abort a submitted request. If the request is finished or not found,
  981. this method will be a no-op.
  982. Args:
  983. request_id: The unique id of the request.
  984. """
  985. self._request_tracker.abort_request(request_id,
  986. exception=asyncio.CancelledError,
  987. verbose=self.log_requests)
  988. async def get_model_config(self) -> ModelConfig:
  989. """Get the model configuration of the Aphrodite engine."""
  990. if self.engine_use_ray:
  991. return await self.engine.get_model_config.remote() # type: ignore
  992. else:
  993. return self.engine.get_model_config()
  994. async def get_parallel_config(self) -> ParallelConfig:
  995. """Get the parallel configuration of the Aphrodite engine."""
  996. if self.engine_use_ray:
  997. return await self.engine.get_parallel_config.remote( # type: ignore
  998. )
  999. else:
  1000. return self.engine.get_parallel_config()
  1001. async def get_decoding_config(self) -> DecodingConfig:
  1002. """Get the decoding configuration of the Aphrodite engine."""
  1003. if self.engine_use_ray:
  1004. return await self.engine.get_decoding_config.remote( # type: ignore
  1005. )
  1006. else:
  1007. return self.engine.get_decoding_config()
  1008. async def get_scheduler_config(self) -> SchedulerConfig:
  1009. """Get the scheduling configuration of the Aphrodite engine."""
  1010. if self.engine_use_ray:
  1011. return await self.engine.get_scheduler_config.remote( # type: ignore
  1012. )
  1013. else:
  1014. return self.engine.get_scheduler_config()
  1015. async def get_lora_config(self) -> LoRAConfig:
  1016. """Get the lora configuration of the Aphrodite engine."""
  1017. if self.engine_use_ray:
  1018. return await self.engine.get_lora_config.remote( # type: ignore
  1019. )
  1020. else:
  1021. return self.engine.get_lora_config()
  1022. async def do_log_stats(
  1023. self,
  1024. scheduler_outputs: Optional[SchedulerOutputs] = None,
  1025. model_output: Optional[List[SamplerOutput]] = None) -> None:
  1026. if self.engine_use_ray:
  1027. await self.engine.do_log_stats.remote( # type: ignore
  1028. scheduler_outputs, model_output)
  1029. else:
  1030. self.engine.do_log_stats()
  1031. async def check_health(self) -> None:
  1032. """Raises an error if engine is unhealthy."""
  1033. t = time.perf_counter()
  1034. logger.debug("Starting health check...")
  1035. if self.is_stopped:
  1036. raise AsyncEngineDeadError("Background loop is stopped.")
  1037. if self.engine_use_ray:
  1038. try:
  1039. await self.engine.check_health.remote() # type: ignore
  1040. except ray.exceptions.RayActorError as e:
  1041. raise RuntimeError("Engine is dead.") from e
  1042. else:
  1043. await self.engine.check_health_async()
  1044. logger.debug(f"Health check took {time.perf_counter() - t}s")
  1045. def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
  1046. if self.engine_use_ray:
  1047. ray.get(
  1048. self.engine.add_logger.remote( # type: ignore
  1049. logger_name=logger_name, logger=logger))
  1050. else:
  1051. self.engine.add_logger(logger_name=logger_name, logger=logger)
  1052. def remove_logger(self, logger_name: str) -> None:
  1053. if self.engine_use_ray:
  1054. ray.get(
  1055. self.engine.remove_logger.remote( # type: ignore
  1056. logger_name=logger_name))
  1057. else:
  1058. self.engine.remove_logger(logger_name=logger_name)