Browse Source

chore: better stream termination in async engine (#672)

AlpinDale 6 months ago
parent
commit
19ad952dd4
1 changed files with 22 additions and 19 deletions
  1. 22 19
      aphrodite/engine/async_aphrodite.py

+ 22 - 19
aphrodite/engine/async_aphrodite.py

@@ -85,11 +85,14 @@ class AsyncStream:
             return
         self._queue.put_nowait(item)
 
-    def finish(self, cancelled: bool = False) -> None:
+    def finish(
+        self,
+        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
+    ) -> None:
         if not self._finished:
             self._finished = True
             self._queue.put_nowait(
-                asyncio.CancelledError if cancelled else STOP_ITERATION)
+                exception if exception is not None else STOP_ITERATION)
 
     @property
     def finished(self) -> bool:
@@ -133,14 +136,12 @@ class RequestTracker:
         """Propagate an exception to request streams
         (all if request_id is None)."""
         if request_id is not None:
-            self._request_streams[request_id].put(exc)
-            self.abort_request(request_id)
+            self.abort_request(request_id, exception=exc)
         else:
-            # NB: list() used here because self.abort_request pops the stream
+            # NB: tuple() used here because self.abort_request pops the stream
             # out of self._request_streams, so we can't iterate on it directly
-            for rid, stream in list(self._request_streams.items()):
-                stream.put(exc)
-                self.abort_request(rid)
+            for rid in tuple(self._request_streams.keys()):
+                self.abort_request(rid, exception=exc)
 
     def process_request_output(self,
                                request_output: Union[RequestOutput,
@@ -167,14 +168,13 @@ class RequestTracker:
 
     def process_exception(self,
                           request_id: str,
-                          exception: Exception,
+                          exception: BaseException,
                           *,
                           verbose: bool = False) -> None:
         """Propagate an exception from the engine."""
-        self._request_streams[request_id].put(exception)
         if verbose:
             logger.info(f"Finished request {request_id}.")
-        self.abort_request(request_id)
+        self.abort_request(request_id, exception=exception)
 
     def add_request(self,
                     request_id: str,
@@ -203,7 +203,8 @@ class RequestTracker:
     def abort_request(self,
                       request_id: str,
                       *,
-                      cancelled: bool = False,
+                      exception: Optional[Union[BaseException,
+                                                Type[BaseException]]] = None,
                       verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
@@ -213,7 +214,7 @@ class RequestTracker:
 
         stream = self._request_streams.pop(request_id, None)
         if stream is not None:
-            stream.finish(cancelled=cancelled)
+            stream.finish(exception=exception)
 
     def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
@@ -227,12 +228,14 @@ class RequestTracker:
 
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
-            if stream.request_id in finished_requests:
+            request_id = stream.request_id
+            if request_id in finished_requests:
                 # The request has already been aborted.
-                stream.finish(cancelled=True)
-                continue
-            self._request_streams[stream.request_id] = stream
-            new_requests.append(new_request)
+                stream.finish(asyncio.CancelledError)
+                finished_requests.discard(request_id)
+            else:
+                self._request_streams[request_id] = stream
+                new_requests.append(new_request)
 
         return new_requests, finished_requests
 
@@ -1004,7 +1007,7 @@ class AsyncAphrodite:
             request_id: The unique id of the request.
         """
         self._request_tracker.abort_request(request_id,
-                                            cancelled=True,
+                                            exception=asyncio.CancelledError,
                                             verbose=self.log_requests)
 
     async def get_model_config(self) -> ModelConfig: