@@ -1,8 +1,9 @@
 import asyncio
+import os
 import time
 from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator)
+                    Union, AsyncIterator, Callable)
 
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.config import ModelConfig
@@ -14,28 +15,31 @@ from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 
 logger = init_logger(__name__)
+ENGINE_ITERATION_TIMEOUT_S = int(
+    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", 60))
 
 
 class AsyncEngineDeadError(RuntimeError):
     pass
 
 
-def _raise_exception_on_finish(task: asyncio.Task,
-                               request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(
+        task: asyncio.Task, error_callback: Callable[[Exception],
+                                                     None]) -> None:
     msg = ("Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github.")
+
+    exception = None
     try:
-        try:
-            task.result()
-        except asyncio.CancelledError:
-            return
-        except Exception as exc:
-            raise AsyncEngineDeadError(
-                msg + " See stack trace above for the actual cause.") from exc
+        task.result()
+        # NOTE: This will be thrown if task exits normally (which it should not)
         raise AsyncEngineDeadError(msg)
-    except Exception as exc:
-        request_tracker.propagate_exception(exc)
-        raise exc
+    except Exception as e:
+        exception = e
+        logger.error("Engine background task failed", exc_info=e)
+        error_callback(exception)
+        raise AsyncEngineDeadError(
+            msg + " See stack trace above for the actual cause.") from e
 
 
 class AsyncStream:
@@ -78,13 +82,13 @@ class RequestTracker:
         self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
-        self.new_requests_event = None
+        self.new_requests_event = asyncio.Event()
 
     def __contains__(self, item):
         return item in self._request_streams
 
-    def init_event(self):
-        self.new_requests_event = asyncio.Event()
+    def __len__(self) -> int:
+        return len(self._request_streams)
 
     def propagate_exception(self,
                             exc: Exception,
@@ -93,9 +97,11 @@ class RequestTracker:
         (all if request_id is None)."""
         if request_id is not None:
             self._request_streams[request_id].put(exc)
+            self.abort_request(request_id)
         else:
-            for stream in self._request_streams.values():
+            for rid, stream in self._request_streams.items():
                 stream.put(exc)
+                self.abort_request(rid)
 
     def process_request_output(self,
                                request_output: RequestOutput,
@@ -172,12 +178,15 @@ class RequestTracker:
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
 
-        self.new_requests_event.clear()
-
         return new_requests, finished_requests
 
     async def wait_for_new_requests(self):
-        await self.new_requests_event.wait()
+        if not self.has_new_requests():
+            await self.new_requests_event.wait()
+        self.new_requests_event.clear()
+
+    def has_new_requests(self):
+        return not self._new_requests.empty()
 
 
 class _AsyncAphrodite(AphroditeEngine):
@@ -285,6 +294,10 @@ class _AsyncAphrodite(AphroditeEngine):
         all_outputs = await asyncio.gather(*coros)
         return all_outputs
 
+    async def check_health_async(self):
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+
 
 class AsyncAphrodite:
     """An asynchronous wrapper for AphroditeEngine.
@@ -334,27 +347,48 @@ class AsyncAphrodite:
         # collected
         self._background_loop_unshielded = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker = RequestTracker()
+        self._request_tracker: Optional[RequestTracker] = None
+        self._errored_with: Optional[BaseException] = None
 
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
-                and not self.background_loop.done())
+                and not self._background_loop_unshielded.done())
+
+    @property
+    def is_stopped(self) -> bool:
+        return self.errored or (self.background_loop is not None
+                                and self._background_loop_unshielded.done())
+
+    @property
+    def errored(self) -> bool:
+        return self._errored_with is not None
+
+    def set_errored(self, exc: Exception) -> None:
+        self._errored_with = exc
+
+    def _error_callback(self, exc: Exception) -> None:
+        self.set_errored(exc)
+        self._request_tracker.propagate_exception(exc)
 
     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
+        if self.errored:
+            raise AsyncEngineDeadError(
+                "Background loop has errored already.") from self._errored_with
         if self.is_running:
             raise RuntimeError("Background loop is already running.")
-        self._request_tracker.init_event()
+        # Initialize the RequestTracker here so it uses the right event loop.
+        self._request_tracker = RequestTracker()
 
         self._background_loop_unshielded = asyncio.get_event_loop(
         ).create_task(self.run_engine_loop())
         self._background_loop_unshielded.add_done_callback(
             partial(_raise_exception_on_finish,
-                    request_tracker=self._request_tracker))
+                    error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
 
     def _init_engine(self, *args,
@@ -422,12 +456,23 @@ class AsyncAphrodite:
         self.engine.abort_request(request_ids)
 
     async def run_engine_loop(self):
-        # Initialize the RequestTracker here so it uses the right event loop.
        has_requests_in_progress = False
        while True:
            if not has_requests_in_progress:
+                logger.debug("Waiting for new requests...")
                await self._request_tracker.wait_for_new_requests()
-            has_requests_in_progress = await self.engine_step()
+                logger.debug("Got new requests!")
+
+            # Abort if iteration takes too long due to unrecoverable errors
+            # (eg. NCCL timeouts).
+            try:
+                has_requests_in_progress = await asyncio.wait_for(
+                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
+            except asyncio.TimeoutError as exc:
+                logger.error(
+                    "Engine iteration timed out. This should never happen!")
+                self.set_errored(exc)
+                raise
            await asyncio.sleep(0)
 
     async def add_request(
@@ -644,3 +689,19 @@ class AsyncAphrodite:
             await self.engine.do_log_stats.remote()
         else:
             self.engine.do_log_stats()
+
+    async def check_health(self):
+        """Raises an error if engine is unhealthy."""
+        t = time.perf_counter()
+        logger.debug("Starting health check...")
+        if self.is_stopped:
+            raise AsyncEngineDeadError("Background loop is stopped.")
+
+        if self.engine_use_ray:
+            try:
+                await self.engine.check_health.remote()
+            except ray.exceptions.RayActorError as e:
+                raise RuntimeError("Engine is dead.") from e
+        else:
+            await self.engine.check_health_async()
+        logger.debug(f"Health check took {time.perf_counter()-t}s")
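
The hunks above route background-loop failures through an error callback, bound each engine iteration with ENGINE_ITERATION_TIMEOUT_S (60 s by default, overridable via the APHRODITE_ENGINE_ITERATION_TIMEOUT_S environment variable), and add a check_health() probe. Below is a minimal sketch of how a caller might poll that probe; the import path and the watchdog helper are assumptions for illustration and are not part of this diff.

import asyncio

# NOTE: illustrative only, not part of the diff. The module path below is an
# assumption; point it at wherever AsyncAphrodite actually lives in the repo.
from aphrodite.engine.async_aphrodite import (AsyncAphrodite,
                                              AsyncEngineDeadError)


async def watchdog(engine: AsyncAphrodite, interval_s: float = 10.0) -> None:
    """Hypothetical helper: periodically probe the async engine's health."""
    while True:
        # check_health() raises AsyncEngineDeadError if the background loop
        # has stopped or errored, and RuntimeError if the Ray actor is dead.
        await engine.check_health()
        await asyncio.sleep(interval_s)

A serving frontend might run such a task alongside the engine's background loop and translate a raised AsyncEngineDeadError or RuntimeError into a 503 response or a process restart, rather than letting requests hang on a dead engine.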