fix: better async request cancellation (#641)

* wip

* wip

* rpc client async generator

* serving chat

* serving completions

* serving embeddings
AlpinDale 6 months ago
parent
commit
77c4fbd5c9

+ 61 - 50
aphrodite/common/utils.py

@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+import contextlib
 import datetime
 import enum
 import gc
@@ -11,10 +12,11 @@ import tempfile
 import threading
 import uuid
 import warnings
+from asyncio import FIRST_COMPLETED, ensure_future
 from collections import defaultdict
 from functools import lru_cache, partial, wraps
 from platform import uname
-from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
+from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
                     Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
                     Union, overload)

@@ -372,63 +374,72 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
     return _async_wrapper


-class ProducerFinished:
-    pass
+async def iterate_with_cancellation(
+    iterator: AsyncGenerator[T, None],
+    is_cancelled: Callable[[], Awaitable[bool]],
+) -> AsyncGenerator[T, None]:
+    """Convert async iterator into one that polls the provided function
+    at least once per second to check for client cancellation.
+    """
+
+    # Can use anext() in python >= 3.10
+    awaits = [ensure_future(iterator.__anext__())]
+    while True:
+        done, pending = await asyncio.wait(awaits, timeout=1)
+        if await is_cancelled():
+            with contextlib.suppress(BaseException):
+                awaits[0].cancel()
+                await iterator.aclose()
+            raise asyncio.CancelledError("client cancelled")
+        if done:
+            try:
+                item = await awaits[0]
+                awaits[0] = ensure_future(iterator.__anext__())
+                yield item
+            except StopAsyncIteration:
+                # we are done
+                return


-def merge_async_iterators(
-        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
-    """Merge multiple asynchronous iterators into a single iterator.
-
+async def merge_async_iterators(
+    *iterators: AsyncGenerator[T, None],
+    is_cancelled: Callable[[], Awaitable[bool]],
+) -> AsyncGenerator[Tuple[int, T], None]:
+    """Merge multiple asynchronous iterators into a single iterator.
     This method handles the case where some iterators finish before others.
     When it yields, it yields a tuple (i, item) where i is the index of the
     iterator that yields the item.
+    It also polls the provided function at least once per second to check
+    for client cancellation.
     """
-    queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
-                               Exception]] = asyncio.Queue()
-
-    producers = len(iterators)
-
-    async def producer(i: int, iterator: AsyncIterator[T]):
-        try:
-            async for item in iterator:
-                await queue.put((i, item))
-        except Exception as e:
-            await queue.put(e)
-        # Signal to the consumer that we've finished
-        await queue.put(ProducerFinished())
-
-    _tasks = [
-        asyncio.create_task(producer(i, iterator))
-        for i, iterator in enumerate(iterators)
-    ]
-
-    async def consumer():
-        remaining = producers
-        try:
-            while remaining or not queue.empty():
-                # we think there is a race condition here
-                item = await queue.get()
-
-                if isinstance(item, ProducerFinished):
-                    # Signal that a producer finished- not a real item
-                    remaining -= 1
-                    continue
-
-                if isinstance(item, Exception):
-                    raise item
-                yield item
-        except (Exception, asyncio.CancelledError) as e:
-            for task in _tasks:
-                if sys.version_info >= (3, 9):
-                    # msg parameter only supported in Python 3.9+
-                    task.cancel(e)
-                else:
-                    task.cancel()
-            raise e
-        await asyncio.gather(*_tasks)
-
-    return consumer()
+    # Can use anext() in python >= 3.10
+    awaits = {
+        ensure_future(pair[1].__anext__()): pair
+        for pair in enumerate(iterators)
+    }
+    try:
+        while awaits:
+            done, pending = await asyncio.wait(awaits.keys(),
+                                               return_when=FIRST_COMPLETED,
+                                               timeout=1)
+            if await is_cancelled():
+                raise asyncio.CancelledError("client cancelled")
+            for d in done:
+                pair = awaits.pop(d)
+                try:
+                    item = await d
+                    i, it = pair
+                    awaits[ensure_future(it.__anext__())] = pair
+                    yield i, item
+                except StopAsyncIteration:
+                    pass
+    finally:
+        # Cancel any remaining iterators
+        for f, (_, it) in awaits.items():
+            with contextlib.suppress(BaseException):
+                f.cancel()
+                await it.aclose()


 def get_ip() -> str:
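
Below is a quick usage sketch of the two new helpers. Only iterate_with_cancellation and merge_async_iterators come from this commit; the toy stream and the is_cancelled stub are illustrative stand-ins:

import asyncio

from aphrodite.common.utils import (iterate_with_cancellation,
                                    merge_async_iterators)

async def fake_stream(n: int, delay: float):
    # Stand-in for an engine output stream.
    for i in range(n):
        await asyncio.sleep(delay)
        yield i

async def main():
    async def is_cancelled() -> bool:
        # Stand-in for raw_request.is_disconnected; never cancels here.
        return False

    # Single stream: items pass through while the helper polls
    # is_cancelled at least once per second.
    async for item in iterate_with_cancellation(fake_stream(3, 0.1),
                                                is_cancelled):
        print("item:", item)

    # Merged streams: yields (index, item) pairs in completion order.
    merged = merge_async_iterators(fake_stream(2, 0.10),
                                   fake_stream(2, 0.15),
                                   is_cancelled=is_cancelled)
    async for i, item in merged:
        print(f"iterator {i}:", item)

asyncio.run(main())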

+ 31 - 29
aphrodite/endpoints/openai/rpc/client.py

@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Any, AsyncIterator, Optional
+from typing import Any, AsyncGenerator, Optional

 import cloudpickle
 import zmq
@@ -179,35 +179,37 @@ class AsyncEngineRPCClient:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""

-        with self.socket() as socket:
-
-            # Send RPCGenerateRequest to the RPCServer.
-            await socket.send_multipart([
-                cloudpickle.dumps(
-                    RPCGenerateRequest(
-                        inputs=inputs,
-                        sampling_params=sampling_params,
-                        request_id=request_id,
-                        lora_request=lora_request,
-                        prompt_adapter_request=prompt_adapter_request))
-            ])
-
-            # Stream back the results from the RPC Server.
-            while True:
-                message = await socket.recv()
-                request_output = cloudpickle.loads(message)
-
-                if isinstance(request_output, Exception):
-                    raise request_output
-
-                if request_output.finished:
-                    break
-                yield request_output
-
-            yield request_output
+        finished = False
+        try:
+            with self.socket() as socket:
+
+                # Send RPCGenerateRequest to the RPCServer.
+                await socket.send_multipart([
+                    cloudpickle.dumps(
+                        RPCGenerateRequest(
+                            inputs=inputs,
+                            sampling_params=sampling_params,
+                            request_id=request_id,
+                            lora_request=lora_request,
+                            prompt_adapter_request=prompt_adapter_request))
+                ])
+
+                # Stream back the results from the RPC Server.
+                while not finished:
+                    message = await socket.recv()
+                    request_output = cloudpickle.loads(message)
+
+                    if isinstance(request_output, Exception):
+                        raise request_output
+
+                    finished = request_output.finished
+                    yield request_output
+        finally:
+            if not finished:
+                await self.abort(request_id)

     async def check_health(self) -> None:
         """Raise if unhealthy"""
@@ -231,6 +233,6 @@ class AsyncEngineRPCClient:
                              f"{health_message}")
                              f"{health_message}")
 
 
     async def encode(self, *args,
     async def encode(self, *args,
-                     **kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
+                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         raise NotImplementedError(
         raise NotImplementedError(
             "Embeddings not supported with multiprocessing backend")
             "Embeddings not supported with multiprocessing backend")

+ 18 - 16
aphrodite/endpoints/openai/serving_chat.py

@@ -1,3 +1,4 @@
+import asyncio
 import time
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
 from typing import Sequence as GenericSequence
@@ -10,7 +11,7 @@ from transformers import PreTrainedTokenizer
 from aphrodite.common.config import ModelConfig
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sequence import Logprob
-from aphrodite.common.utils import random_uuid
+from aphrodite.common.utils import iterate_with_cancellation, random_uuid
 from aphrodite.endpoints.chat_utils import (ConversationMessage,
                                             load_chat_template,
                                             parse_chat_messages)
@@ -160,18 +161,20 @@ class OpenAIServingChat(OpenAIServing):
             # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))

+        if raw_request:
+            result_generator = iterate_with_cancellation(
+                result_generator, raw_request.is_disconnected)
+
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
                 request, result_generator, request_id, conversation, tokenizer)
-        else:
-            try:
-                return await self.chat_completion_full_generator(
-                    request, raw_request, result_generator, request_id,
-                    conversation, tokenizer)
-            except ValueError as e:
-                # TODO: Use an aphrodite-specific Validation Error
-                return self.create_error_response(str(e))
+        try:
+            return await self.chat_completion_full_generator(
+                request, result_generator, request_id, conversation, tokenizer)
+        except ValueError as e:
+            # TODO: Use an aphrodite-specific Validation Error
+            return self.create_error_response(str(e))

     def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
         if request.add_generation_prompt:
@@ -402,7 +405,6 @@ class OpenAIServingChat(OpenAIServing):
     async def chat_completion_full_generator(
         self,
         request: ChatCompletionRequest,
-        raw_request: Optional[Request],
         result_generator: AsyncIterator[RequestOutput],
         request_id: str,
         conversation: List[ConversationMessage],
@@ -413,12 +415,12 @@ class OpenAIServingChat(OpenAIServing):
         created_time = int(time.time())
         final_res: Optional[RequestOutput] = None

-        async for res in result_generator:
-            if raw_request is not None and await raw_request.is_disconnected():
-                # Abort the request if the client disconnects.
-                await self.async_engine_client.abort(request_id)
-                return self.create_error_response("Client disconnected")
-            final_res = res
+        try:
+            async for res in result_generator:
+                final_res = res
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
+
         assert final_res is not None

         choices: List[ChatCompletionResponseChoice] = []
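
With this change the non-streaming chat path no longer polls is_disconnected on every iteration; a disconnect surfaces as asyncio.CancelledError raised out of the wrapped generator. A reduced sketch of the new consume pattern (collect_final and its arguments are illustrative):

import asyncio

from aphrodite.common.utils import iterate_with_cancellation

async def collect_final(result_generator, is_disconnected):
    # Same shape as chat_completion_full_generator after this commit.
    result_generator = iterate_with_cancellation(result_generator,
                                                 is_disconnected)
    final_res = None
    try:
        async for res in result_generator:
            final_res = res  # keep only the last (cumulative) output
    except asyncio.CancelledError:
        return "Client disconnected"
    return final_res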

+ 6 - 14
aphrodite/endpoints/openai/serving_completions.py

@@ -1,3 +1,4 @@
+import asyncio
 import time
 from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                     Optional)
@@ -76,7 +77,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time = int(time.time())

         # Schedule the request and get the result generator.
-        generators: List[AsyncIterator[RequestOutput]] = []
+        generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
             (
                 lora_request,
@@ -126,7 +127,8 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(str(e))

         result_generator: AsyncIterator[Tuple[
-            int, RequestOutput]] = merge_async_iterators(*generators)
+            int, RequestOutput]] = merge_async_iterators(
+                *generators, is_cancelled=raw_request.is_disconnected)

         # Similar to the OpenAI API, when n != best_of, we do not stream the
         # results. In addition, we do not stream the results when use
@@ -138,7 +140,6 @@ class OpenAIServingCompletion(OpenAIServing):
         # Streaming response
         if stream:
             return self.completion_stream_generator(request,
-                                                    raw_request,
                                                     result_generator,
                                                     request_id,
                                                     created_time,
@@ -150,10 +151,6 @@ class OpenAIServingCompletion(OpenAIServing):
         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
         try:
             async for i, res in result_generator:
-                if await raw_request.is_disconnected():
-                    # Abort the request if the client disconnects.
-                    await self.async_engine_client.abort(f"{request_id}-{i}")
-                    return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res

             for i, final_res in enumerate(final_res_batch):
@@ -175,6 +172,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 model_name,
                 tokenizer,
             )
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))
@@ -195,7 +194,6 @@ class OpenAIServingCompletion(OpenAIServing):
     async def completion_stream_generator(
         self,
         request: CompletionRequest,
-        raw_request: Request,
         result_generator: AsyncIterator[Tuple[int, RequestOutput]],
         request_id: str,
         created_time: int,
@@ -211,12 +209,6 @@ class OpenAIServingCompletion(OpenAIServing):
         try:
             async for prompt_idx, res in result_generator:

-                # Abort the request if the client disconnects.
-                if await raw_request.is_disconnected():
-                    await self.async_engine_client.abort(
-                        f"{request_id}-{prompt_idx}")
-                    raise StopAsyncIteration()
-
                 for output in res.outputs:
                     i = output.index + prompt_idx * num_choices
                     # TODO: optimize the performance by avoiding full
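
One detail worth a worked check: the merged stream interleaves outputs from all prompts, and i = output.index + prompt_idx * num_choices flattens them prompt-major into the response's choice list. For example:

# num_choices = 2 with two prompts -> four flattened choice slots:
#   prompt_idx=0: output.index 0, 1 -> i = 0, 1
#   prompt_idx=1: output.index 0, 1 -> i = 2, 3
num_choices = 2
for prompt_idx in range(2):
    for output_index in range(num_choices):
        print(prompt_idx, output_index, "->",
              output_index + prompt_idx * num_choices)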

+ 7 - 7
aphrodite/endpoints/openai/serving_embedding.py

@@ -1,6 +1,7 @@
+import asyncio
 import base64
 import time
-from typing import AsyncIterator, List, Optional, Tuple, cast
+from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple, cast

 import numpy as np
 from fastapi import Request
@@ -91,7 +92,7 @@ class OpenAIServingEmbedding(OpenAIServing):
         created_time = int(time.monotonic())

         # Schedule the request and get the result generator.
-        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
+        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
         try:
             (
                 lora_request,
@@ -136,17 +137,14 @@ class OpenAIServingEmbedding(OpenAIServing):
             return self.create_error_response(str(e))

         result_generator: AsyncIterator[Tuple[
-            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
+            int, EmbeddingRequestOutput]] = merge_async_iterators(
+                *generators, is_cancelled=raw_request.is_disconnected)

         # Non-streaming response
         final_res_batch: List[Optional[EmbeddingRequestOutput]]
         final_res_batch = [None] * len(prompts)
         try:
             async for i, res in result_generator:
-                if await raw_request.is_disconnected():
-                    # Abort the request if the client disconnects.
-                    await self.async_engine_client.abort(f"{request_id}-{i}")
-                    return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res

             for final_res in final_res_batch:
@@ -157,6 +155,8 @@ class OpenAIServingEmbedding(OpenAIServing):
             response = request_output_to_embedding_response(
                 final_res_batch_checked, request_id, created_time, model_name,
                 encoding_format)
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))

+ 67 - 76
aphrodite/engine/async_aphrodite.py

@@ -2,7 +2,7 @@ import asyncio
 import os
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
+from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Optional,
                     Set, Tuple, Type, Union)

 from loguru import logger
@@ -61,12 +61,16 @@ def _log_task_completion(task: asyncio.Task,
             "actual cause.") from e
             "actual cause.") from e
 
 
 
 
+STOP_ITERATION = Exception() # Sentinel
+
+
 class AsyncStream:
 class AsyncStream:
     """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
     """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
-    that can be iterated over asynchronously."""
+    that can be iterated over asynchronously via an async generator."""
 
 
-    def __init__(self, request_id: str) -> None:
+    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
         self.request_id = request_id
         self.request_id = request_id
+        self._cancel = cancel
         self._queue: asyncio.Queue = asyncio.Queue()
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
         self._finished = False
 
 
@@ -76,22 +80,30 @@ class AsyncStream:
             return
         self._queue.put_nowait(item)

-    def finish(self) -> None:
-        self._queue.put_nowait(StopAsyncIteration())
-        self._finished = True
+    def finish(self, cancelled: bool = False) -> None:
+        if not self._finished:
+            self._finished = True
+            self._queue.put_nowait(
+                asyncio.CancelledError if cancelled else STOP_ITERATION)

     @property
     def finished(self) -> bool:
         return self._finished

-    def __aiter__(self):
-        return self
-
-    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
-        result = await self._queue.get()
-        if isinstance(result, Exception):
-            raise result
-        return result
+    async def generator(
+        self
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+        try:
+            while not self._finished:
+                result = await self._queue.get()
+                if isinstance(result, Exception):
+                    if result == STOP_ITERATION:
+                        return
+                    raise result
+                yield result
+        except GeneratorExit:
+            self._cancel(self.request_id)
+            raise asyncio.CancelledError from None


 class RequestTracker:
@@ -99,7 +111,7 @@ class RequestTracker:
 
 
     def __init__(self) -> None:
         self._request_streams: Dict[str, AsyncStream] = {}
-        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
         self.new_requests_event = asyncio.Event()
@@ -130,15 +142,21 @@ class RequestTracker:
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
         request_id = request_output.request_id
+        finished = request_output.finished

+        if finished:
+            stream = self._request_streams.pop(request_id, None)
+        else:
+            stream = self._request_streams.get(request_id)
         # Guard against a KeyError which can occur if the request was aborted
         # while the output was generated
-        if (stream := self._request_streams.get(request_id)) is not None:
+        if stream is not None:
             stream.put(request_output)
-        if request_output.finished:
-            if verbose:
-                logger.info(f"Finished request {request_id}.")
-            self.abort_request(request_id)
+            if finished:
+                stream.finish()
+
+        if verbose and finished:
+            logger.info(f"Finished request {request_id}.")

     def process_exception(self,
                           request_id: str,
@@ -161,7 +179,8 @@ class RequestTracker:
         if request_id in self._request_streams:
             raise KeyError(f"Request {request_id} already exists.")

-        stream = AsyncStream(request_id)
+        abort_request = partial(self.abort_request, verbose=verbose)
+        stream = AsyncStream(request_id, abort_request)
         self._new_requests.put_nowait((stream, {
             "request_id": request_id,
             **engine_add_request_kwargs
@@ -174,36 +193,36 @@ class RequestTracker:
 
 
         return stream

-    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
+    def abort_request(self,
+                      request_id: str,
+                      *,
+                      cancelled: bool = False,
+                      verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
             logger.info(f"Aborted request {request_id}.")

-        self._finished_requests.put_nowait(request_id)
-
-        if request_id not in self._request_streams or self._request_streams[
-                request_id].finished:
-            # The request has already finished or been aborted.
-            return
+        self._aborted_requests.put_nowait(request_id)

-        self._request_streams[request_id].finish()
+        stream = self._request_streams.pop(request_id, None)
+        if stream is not None:
+            stream.finish(cancelled=cancelled)

-    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
+    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
         sent to the engine."""
         new_requests: List[Dict] = []
         finished_requests: Set[str] = set()

-        while not self._finished_requests.empty():
-            request_id = self._finished_requests.get_nowait()
+        while not self._aborted_requests.empty():
+            request_id = self._aborted_requests.get_nowait()
             finished_requests.add(request_id)
-            self._request_streams.pop(request_id, None)

         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
             if stream.request_id in finished_requests:
                 # The request has already been aborted.
-                stream.finish()
+                stream.finish(cancelled=True)
                 continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
@@ -554,8 +573,8 @@ class AsyncAphrodite:
 
 
         Returns True if there are in-progress requests."""

-        new_requests, finished_requests = (
-            self._request_tracker.get_new_and_finished_requests())
+        new_requests, aborted_requests = (
+            self._request_tracker.get_new_and_aborted_requests())

         for new_request in new_requests:
             # Add the request into the Aphrodite engine's waiting queue.
@@ -574,8 +593,8 @@ class AsyncAphrodite:
                     verbose=self.log_requests,
                 )

-        if finished_requests:
-            await self._engine_abort(finished_requests)
+        if aborted_requests:
+            await self._engine_abort(aborted_requests)

         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()  # type: ignore
@@ -664,6 +683,8 @@ class AsyncAphrodite:
                 raise
             await asyncio.sleep(0)

+    # This method does not need to be async, but kept that way
+    # for backwards compatibility.
     async def add_request(
         self,
         request_id: str,
         request_id: str,
@@ -672,7 +693,7 @@ class AsyncAphrodite:
         arrival_time: Optional[float] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> AsyncStream:
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
 
 
         if not self.is_running:
         if not self.is_running:
             if self.start_engine_loop:
             if self.start_engine_loop:
@@ -684,19 +705,16 @@ class AsyncAphrodite:
                     "error that caused the background loop to stop "
                     "error that caused the background loop to stop "
                     "(AsyncEngineDeadError).")
                     "(AsyncEngineDeadError).")
 
 
-        if arrival_time is None:
-            arrival_time = time.time()
-
         stream = self._request_tracker.add_request(
         stream = self._request_tracker.add_request(
             request_id,
             request_id,
             verbose=self.log_requests,
             verbose=self.log_requests,
             inputs=inputs,
             inputs=inputs,
             params=params,
             params=params,
-            arrival_time=arrival_time,
+            arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request)
             prompt_adapter_request=prompt_adapter_request)
 
 
-        return stream
+        return stream.generator()
 
 
     async def generate(
     async def generate(
         self,
         self,
@@ -705,7 +723,7 @@ class AsyncAphrodite:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.

         Generate outputs for a request. This method is a coroutine. It adds the
@@ -771,7 +789,7 @@ class AsyncAphrodite:
             >>> # Process and return the final output
             >>> ...
         """
-        async for output in self._process_request(
+        async for output in await self.add_request(
                 request_id,
                 inputs,
                 sampling_params,
@@ -786,7 +804,7 @@ class AsyncAphrodite:
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-    ) -> AsyncIterator[EmbeddingRequestOutput]:
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
         Generate outputs for a request. This method is a coroutine. It adds the
         request into the waiting queue of the AphroditeEngine and streams the
@@ -840,7 +858,7 @@ class AsyncAphrodite:
             >>> # Process and return the final output
             >>> ...
         """
-        async for output in self._process_request(
+        async for output in await self.add_request(
                 request_id,
                 inputs,
                 pooling_params,
@@ -849,34 +867,6 @@ class AsyncAphrodite:
             yield AphroditeEngine.validate_output(output,
                                                   EmbeddingRequestOutput)

-    async def _process_request(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        params: Union[SamplingParams, PoolingParams],
-        *,
-        lora_request: Optional[LoRARequest] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
-        """Common logic to process requests with SamplingParams or
-        PoolingParams."""
-        arrival_time = time.time()
-
-        stream = await self.add_request(
-            request_id,
-            inputs,
-            params,
-            arrival_time=arrival_time,
-            lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
-
-        try:
-            async for request_output in stream:
-                yield request_output
-        except (Exception, asyncio.CancelledError) as e:
-            self._abort(request_id)
-            raise e
-
     async def abort(self, request_id: str) -> None:
         """Abort a request.

@@ -905,6 +895,7 @@ class AsyncAphrodite:
             request_id: The unique id of the request.
         """
         self._request_tracker.abort_request(request_id,
+                                            cancelled=True,
                                             verbose=self.log_requests)

     async def get_model_config(self) -> ModelConfig:
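
The heart of the new design is AsyncStream.generator(): when a consumer drops or closes the generator, GeneratorExit is raised inside it, the cancel callback fires, and the request is aborted engine-side. A reduced sketch of the mechanism (MiniStream is illustrative; the real class also routes exceptions and a cancelled flag through the queue):

import asyncio

STOP = Exception()  # sentinel for normal completion, as in the diff

class MiniStream:
    def __init__(self, request_id, cancel):
        self.request_id = request_id
        self._cancel = cancel
        self._queue = asyncio.Queue()
        self._finished = False

    def put(self, item):
        self._queue.put_nowait(item)

    def finish(self):
        self._finished = True
        self._queue.put_nowait(STOP)

    async def generator(self):
        try:
            while not self._finished:
                result = await self._queue.get()
                if result is STOP:
                    return
                yield result
        except GeneratorExit:
            # Consumer went away: report it as a cancellation.
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

async def main():
    stream = MiniStream("req-1", lambda rid: print(f"cancel({rid})"))
    stream.put("token-a")
    gen = stream.generator()
    print(await gen.__anext__())  # token-a
    try:
        await gen.aclose()  # fires cancel(req-1) ...
    except asyncio.CancelledError:
        print("stream cancelled")  # ... then the CancelledError surfaces

asyncio.run(main())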

+ 5 - 5
aphrodite/engine/protocol.py

@@ -1,4 +1,4 @@
-from typing import AsyncIterator, List, Optional, Protocol, runtime_checkable
+from typing import AsyncGenerator, List, Optional, Protocol, runtime_checkable

 from transformers import PreTrainedTokenizer

@@ -29,24 +29,24 @@ class AsyncEngineClient(Protocol):
     def errored(self) -> bool:
         ...

-    async def generate(
+    def generate(
         self,
         inputs: PromptInputs,
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Generates outputs for a request"""
         ...

-    async def encode(
+    def encode(
         self,
         inputs: PromptInputs,
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-    ) -> AsyncIterator[EmbeddingRequestOutput]:
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model."""
         ...
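
The signature change from async def to plain def in this protocol follows from how async generators work: calling an async generator function returns the generator object synchronously, so there is nothing to await until iteration begins. A hedged sketch under that assumption (EchoClient is illustrative, not part of the commit):

import asyncio
from typing import AsyncGenerator

class EchoClient:
    # Same shape as the updated AsyncEngineClient protocol: a plain def
    # that returns an async generator rather than a coroutine.
    def generate(self, prompt: str) -> AsyncGenerator[str, None]:
        async def _gen():
            for token in prompt.split():
                await asyncio.sleep(0)
                yield token
        return _gen()

async def main():
    client = EchoClient()
    # No await on the call itself; async for drives the generator.
    async for token in client.generate("hello async world"):
        print(token)

asyncio.run(main())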