Browse Source

chore: add health check for ray workers (#290)

* add health check for ray workers

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>

* add tests

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>

* formatting

---------

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
AlpinDale 1 year ago
parent
commit
2d3d44b3e9

+ 20 - 0
aphrodite/engine/aphrodite_engine.py

@@ -1026,3 +1026,23 @@ class AphroditeEngine:
             ray_worker_outputs = ray.get(ray_worker_outputs)
 
         return [driver_worker_output] + ray_worker_outputs
+
+    def check_health(self) -> None:
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+
+    def _check_if_any_actor_is_dead(self):
+        if not self.parallel_config.worker_use_ray:
+            return
+
+        if not self.workers:
+            return
+
+        dead_actors = []
+        for actor in self.workers:
+            actor_state = ray.state.actors(actor._ray_actor_id.hex())
+            if actor_state["State"] == "DEAD":
+                dead_actors.append(actor)
+        if dead_actors:
+            raise RuntimeError("At least one Worker is dead. "
+                               f"Dead workers: {dead_actors}")

+ 87 - 26
aphrodite/engine/async_aphrodite.py

@@ -1,8 +1,9 @@
 import asyncio
+import os
 import time
 from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator)
+                    Union, AsyncIterator, Callable)
 
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.config import ModelConfig
@@ -14,28 +15,31 @@ from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 
 logger = init_logger(__name__)
+ENGINE_ITERATION_TIMEOUT_S = int(
+    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", 60))
 
 
 class AsyncEngineDeadError(RuntimeError):
     pass
 
 
-def _raise_exception_on_finish(task: asyncio.Task,
-                               request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(
+        task: asyncio.Task, error_callback: Callable[[Exception],
+                                                     None]) -> None:
     msg = ("Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github.")
+
+    exception = None
     try:
-        try:
-            task.result()
-        except asyncio.CancelledError:
-            return
-        except Exception as exc:
-            raise AsyncEngineDeadError(
-                msg + " See stack trace above for the actual cause.") from exc
+        task.result()
+        # NOTE: This will be thrown if task exits normally (which it should not)
         raise AsyncEngineDeadError(msg)
-    except Exception as exc:
-        request_tracker.propagate_exception(exc)
-        raise exc
+    except Exception as e:
+        exception = e
+        logger.error("Engine background task failed", exc_info=e)
+        error_callback(exception)
+        raise AsyncEngineDeadError(
+            msg + " See stack trace above for the actual cause.") from e
 
 
 class AsyncStream:
@@ -78,13 +82,13 @@ class RequestTracker:
         self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
-        self.new_requests_event = None
+        self.new_requests_event = asyncio.Event()
 
     def __contains__(self, item):
         return item in self._request_streams
 
-    def init_event(self):
-        self.new_requests_event = asyncio.Event()
+    def __len__(self) -> int:
+        return len(self._request_streams)
 
     def propagate_exception(self,
                             exc: Exception,
@@ -93,9 +97,11 @@ class RequestTracker:
         (all if request_id is None)."""
         if request_id is not None:
             self._request_streams[request_id].put(exc)
+            self.abort_request(request_id)
         else:
-            for stream in self._request_streams.values():
+            for rid, stream in self._request_streams.items():
                 stream.put(exc)
+                self.abort_request(rid)
 
     def process_request_output(self,
                                request_output: RequestOutput,
@@ -172,12 +178,15 @@ class RequestTracker:
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
 
-        self.new_requests_event.clear()
-
         return new_requests, finished_requests
 
     async def wait_for_new_requests(self):
-        await self.new_requests_event.wait()
+        if not self.has_new_requests():
+            await self.new_requests_event.wait()
+        self.new_requests_event.clear()
+
+    def has_new_requests(self):
+        return not self._new_requests.empty()
 
 
 class _AsyncAphrodite(AphroditeEngine):
@@ -285,6 +294,10 @@ class _AsyncAphrodite(AphroditeEngine):
         all_outputs = await asyncio.gather(*coros)
         return all_outputs
 
+    async def check_health_async(self):
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+
 
 class AsyncAphrodite:
     """An asynchronous wrapper for AphroditeEngine.
@@ -334,27 +347,48 @@ class AsyncAphrodite:
         # collected
         self._background_loop_unshielded = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker = RequestTracker()
+        self._request_tracker: Optional[RequestTracker] = None
+        self._errored_with: Optional[BaseException] = None
 
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
-                and not self.background_loop.done())
+                and not self._background_loop_unshielded.done())
+
+    @property
+    def is_stopped(self) -> bool:
+        return self.errored or (self.background_loop is not None
+                                and self._background_loop_unshielded.done())
+
+    @property
+    def errored(self) -> bool:
+        return self._errored_with is not None
+
+    def set_errored(self, exc: Exception) -> None:
+        self._errored_with = exc
+
+    def _error_callback(self, exc: Exception) -> None:
+        self.set_errored(exc)
+        self._request_tracker.propagate_exception(exc)
 
     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
+        if self.errored:
+            raise AsyncEngineDeadError(
+                "Background loop has errored already.") from self._errored_with
         if self.is_running:
             raise RuntimeError("Background loop is already running.")
-        self._request_tracker.init_event()
+        # Initialize the RequestTracker here so it uses the right event loop.
+        self._request_tracker = RequestTracker()
 
         self._background_loop_unshielded = asyncio.get_event_loop(
         ).create_task(self.run_engine_loop())
         self._background_loop_unshielded.add_done_callback(
             partial(_raise_exception_on_finish,
-                    request_tracker=self._request_tracker))
+                    error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
 
     def _init_engine(self, *args,
@@ -422,12 +456,23 @@ class AsyncAphrodite:
             self.engine.abort_request(request_ids)
 
     async def run_engine_loop(self):
-        # Initialize the RequestTracker here so it uses the right event loop.
         has_requests_in_progress = False
         while True:
             if not has_requests_in_progress:
+                logger.debug("Waiting for new requests...")
                 await self._request_tracker.wait_for_new_requests()
-            has_requests_in_progress = await self.engine_step()
+                logger.debug("Got new requests!")
+
+            # Abort if iteration takes too long due to unrecoverable errors
+            # (eg. NCCL timeouts).
+            try:
+                has_requests_in_progress = await asyncio.wait_for(
+                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
+            except asyncio.TimeoutError as exc:
+                logger.error(
+                    "Engine iteration timed out. This should never happen!")
+                self.set_errored(exc)
+                raise
             await asyncio.sleep(0)
 
     async def add_request(
@@ -644,3 +689,19 @@ class AsyncAphrodite:
             await self.engine.do_log_stats.remote()
         else:
             self.engine.do_log_stats()
+
+    async def check_health(self):
+        """Raises an error if engine is unhealthy."""
+        t = time.perf_counter()
+        logger.debug("Starting health check...")
+        if self.is_stopped:
+            raise AsyncEngineDeadError("Background loop is stopped.")
+
+        if self.engine_use_ray:
+            try:
+                await self.engine.check_health.remote()
+            except ray.exceptions.RayActorError as e:
+                raise RuntimeError("Engine is dead.") from e
+        else:
+            await self.engine.check_health_async()
+        logger.debug(f"Health check took {time.perf_counter()-t}s")

+ 20 - 9
tests/async_engine/test_async_aphrodite.py

@@ -25,6 +25,9 @@ class MockEngine:
         return [RequestOutput(
             request_id=self.request_id)] if self.request_id else []
 
+    async def encode_request_async(self, *args, **kwargs):
+        pass
+
     def generate(self, request_id):
         self.request_id = request_id
 
@@ -35,10 +38,17 @@ class MockEngine:
         del kwargs  # Unused
         self.add_request_calls += 1
 
+    async def add_request_async(self, **kwargs):
+        self.add_request_calls += 1
+        return
+
     def abort_request(self, request_id):
         del request_id  # Unused
         self.abort_request_calls += 1
 
+    def has_unfinished_requests(self):
+        return self.request_id is not None
+
 
 class MockAsyncAphrodite(AsyncAphrodite):
 
@@ -61,20 +71,21 @@ async def test_new_requests_event():
     await engine.add_request("2", "", None)
     engine.engine.generate("2")
     await asyncio.sleep(0)
-    assert engine.engine.add_request_calls == 2
-    assert engine.engine.step_calls == 2
     await asyncio.sleep(0)
-    assert engine.engine.step_calls == 3
+    assert engine.engine.add_request_calls == 2
+    assert engine.engine.step_calls >= 2
+    await asyncio.sleep(0.001)
+    assert engine.engine.step_calls >= 3
     engine.engine.stop_generating()
-    await asyncio.sleep(0)
-    assert engine.engine.step_calls == 4
-    await asyncio.sleep(0)
-    assert engine.engine.step_calls == 4
+    await asyncio.sleep(0.001)
+    old_step_calls = engine.engine.step_calls
+    await asyncio.sleep(0.001)
+    assert engine.engine.step_calls == old_step_calls
 
     await engine.add_request("3", "", None)
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
-    assert engine.engine.step_calls == 5
+    assert engine.engine.step_calls == old_step_calls + 1
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
-    assert engine.engine.step_calls == 5
+    assert engine.engine.step_calls == old_step_calls + 1

+ 12 - 22
tests/async_engine/test_request_tracker.py

@@ -4,25 +4,12 @@ from aphrodite.engine.async_aphrodite import RequestTracker
 from aphrodite.common.outputs import RequestOutput
 
 
-class DummyEvent:
-
-    def __init__(self):
-        self.flag = False
-
-    def set(self):
-        self.flag = True
-
-    def clear(self):
-        self.flag = False
-
-
-def test_request_tracker():
+@pytest.mark.asyncio
+async def test_request_tracker():
     tracker = RequestTracker()
-    tracker.new_requests_event = DummyEvent()
     stream_1 = tracker.add_request("1")
-    assert tracker.new_requests_event.flag
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
     assert len(new) == 1
     assert new[0]["request_id"] == "1"
     assert not finished
@@ -30,9 +17,10 @@ def test_request_tracker():
 
     stream_2 = tracker.add_request("2")
     stream_3 = tracker.add_request("3")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
     assert len(new) == 2
     assert new[0]["request_id"] == "2"
     assert new[1]["request_id"] == "3"
@@ -43,7 +31,7 @@ def test_request_tracker():
     # request_ids must be unique
     with pytest.raises(KeyError):
         tracker.add_request("1")
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
 
     tracker.abort_request("1")
     new, finished = tracker.get_new_and_finished_requests()
@@ -54,7 +42,8 @@ def test_request_tracker():
 
     stream_4 = tracker.add_request("4")
     tracker.abort_request("4")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
     assert len(finished) == 1
     assert "4" in finished
@@ -62,11 +51,12 @@ def test_request_tracker():
     assert stream_4.finished
 
     stream_5 = tracker.add_request("5")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
     tracker.process_request_output(
         RequestOutput("2", "output", [], [], [], finished=True))
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
     assert len(finished) == 1
     assert "2" in finished
     assert len(new) == 1