Преглед изворни кода

fix: better async request cancellation (#641)

* wip

* wip

* rpc client async generator

* serving chat

* serving completions

* serving embeddings
AlpinDale пре 6 месеци
родитељ
комит
77c4fbd5c9

+ 61 - 50
aphrodite/common/utils.py

@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+import contextlib
 import datetime
 import enum
 import gc
@@ -11,10 +12,11 @@ import tempfile
 import threading
 import uuid
 import warnings
+from asyncio import FIRST_COMPLETED, ensure_future
 from collections import defaultdict
 from functools import lru_cache, partial, wraps
 from platform import uname
-from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
+from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
                     Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
                     Union, overload)
 
@@ -372,63 +374,72 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
     return _async_wrapper
 
 
-class ProducerFinished:
-    pass
+async def iterate_with_cancellation(
+    iterator: AsyncGenerator[T, None],
+    is_cancelled: Callable[[], Awaitable[bool]],
+) -> AsyncGenerator[T, None]:
+    """Convert async iterator into one that polls the provided function
+    at least once per second to check for client cancellation.
+    """
 
+    # Can use anext() in python >= 3.10
+    awaits = [ensure_future(iterator.__anext__())]
+    while True:
+        done, pending = await asyncio.wait(awaits, timeout=1)
+        if await is_cancelled():
+            with contextlib.suppress(BaseException):
+                awaits[0].cancel()
+                await iterator.aclose()
+            raise asyncio.CancelledError("client cancelled")
+        if done:
+            try:
+                item = await awaits[0]
+                awaits[0] = ensure_future(iterator.__anext__())
+                yield item
+            except StopAsyncIteration:
+                # we are done
+                return
 
-def merge_async_iterators(
-        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
-    """Merge multiple asynchronous iterators into a single iterator.
 
+async def merge_async_iterators(
+    *iterators: AsyncGenerator[T, None],
+    is_cancelled: Callable[[], Awaitable[bool]],
+) -> AsyncGenerator[Tuple[int, T], None]:
+    """Merge multiple asynchronous iterators into a single iterator.
     This method handle the case where some iterators finish before others.
     When it yields, it yields a tuple (i, item) where i is the index of the
     iterator that yields the item.
+    It also polls the provided function at least once per second to check
+    for client cancellation.
     """
-    queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
-                               Exception]] = asyncio.Queue()
-
-    producers = len(iterators)
-
-    async def producer(i: int, iterator: AsyncIterator[T]):
-        try:
-            async for item in iterator:
-                await queue.put((i, item))
-        except Exception as e:
-            await queue.put(e)
-        # Signal to the consumer that we've finished
-        await queue.put(ProducerFinished())
-
-    _tasks = [
-        asyncio.create_task(producer(i, iterator))
-        for i, iterator in enumerate(iterators)
-    ]
 
-    async def consumer():
-        remaining = producers
-        try:
-            while remaining or not queue.empty():
-                # we think there is a race condition here
-                item = await queue.get()
-
-                if isinstance(item, ProducerFinished):
-                    # Signal that a producer finished- not a real item
-                    remaining -= 1
-                    continue
-
-                if isinstance(item, Exception):
-                    raise item
-                yield item
-        except (Exception, asyncio.CancelledError) as e:
-            for task in _tasks:
-                if sys.version_info >= (3, 9):
-                    # msg parameter only supported in Python 3.9+
-                    task.cancel(e)
-                else:
-                    task.cancel()
-            raise e
-        await asyncio.gather(*_tasks)
-
-    return consumer()
+    # Can use anext() in python >= 3.10
+    awaits = {
+        ensure_future(pair[1].__anext__()): pair
+        for pair in enumerate(iterators)
+    }
+    try:
+        while awaits:
+            done, pending = await asyncio.wait(awaits.keys(),
+                                               return_when=FIRST_COMPLETED,
+                                               timeout=1)
+            if await is_cancelled():
+                raise asyncio.CancelledError("client cancelled")
+            for d in done:
+                pair = awaits.pop(d)
+                try:
+                    item = await d
+                    i, it = pair
+                    awaits[ensure_future(it.__anext__())] = pair
+                    yield i, item
+                except StopAsyncIteration:
+                    pass
+    finally:
+        # Cancel any remaining iterators
+        for f, (_, it) in awaits.items():
+            with contextlib.suppress(BaseException):
+                f.cancel()
+                await it.aclose()
 
 
 def get_ip() -> str:

+ 31 - 29
aphrodite/endpoints/openai/rpc/client.py

@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Any, AsyncIterator, Optional
+from typing import Any, AsyncGenerator, Optional
 
 import cloudpickle
 import zmq
@@ -179,35 +179,37 @@ class AsyncEngineRPCClient:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""
 
-        with self.socket() as socket:
-
-            # Send RPCGenerateRequest to the RPCServer.
-            await socket.send_multipart([
-                cloudpickle.dumps(
-                    RPCGenerateRequest(
-                        inputs=inputs,
-                        sampling_params=sampling_params,
-                        request_id=request_id,
-                        lora_request=lora_request,
-                        prompt_adapter_request=prompt_adapter_request))
-            ])
-
-            # Stream back the results from the RPC Server.
-            while True:
-                message = await socket.recv()
-                request_output = cloudpickle.loads(message)
-
-                if isinstance(request_output, Exception):
-                    raise request_output
-
-                if request_output.finished:
-                    break
-                yield request_output
-
-            yield request_output
+        finished = False
+        try:
+            with self.socket() as socket:
+
+                # Send RPCGenerateRequest to the RPCServer.
+                await socket.send_multipart([
+                    cloudpickle.dumps(
+                        RPCGenerateRequest(
+                            inputs=inputs,
+                            sampling_params=sampling_params,
+                            request_id=request_id,
+                            lora_request=lora_request,
+                            prompt_adapter_request=prompt_adapter_request))
+                ])
+
+                # Stream back the results from the RPC Server.
+                while not finished:
+                    message = await socket.recv()
+                    request_output = cloudpickle.loads(message)
+
+                    if isinstance(request_output, Exception):
+                        raise request_output
+
+                    finished = request_output.finished
+                    yield request_output
+        finally:
+            if not finished:
+                await self.abort(request_id)
 
     async def check_health(self) -> None:
         """Raise if unhealthy"""
@@ -231,6 +233,6 @@ class AsyncEngineRPCClient:
                              f"{health_message}")
 
     async def encode(self, *args,
-                     **kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
+                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         raise NotImplementedError(
             "Embeddings not supported with multiprocessing backend")

+ 18 - 16
aphrodite/endpoints/openai/serving_chat.py

@@ -1,3 +1,4 @@
+import asyncio
 import time
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
 from typing import Sequence as GenericSequence
@@ -10,7 +11,7 @@ from transformers import PreTrainedTokenizer
 from aphrodite.common.config import ModelConfig
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sequence import Logprob
-from aphrodite.common.utils import random_uuid
+from aphrodite.common.utils import iterate_with_cancellation, random_uuid
 from aphrodite.endpoints.chat_utils import (ConversationMessage,
                                             load_chat_template,
                                             parse_chat_messages)
@@ -160,18 +161,20 @@ class OpenAIServingChat(OpenAIServing):
             # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))
 
+        if raw_request:
+            result_generator = iterate_with_cancellation(
+                result_generator, raw_request.is_disconnected)
+
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
                 request, result_generator, request_id, conversation, tokenizer)
-        else:
-            try:
-                return await self.chat_completion_full_generator(
-                    request, raw_request, result_generator, request_id,
-                    conversation, tokenizer)
-            except ValueError as e:
-                # TODO: Use an aphrodite-specific Validation Error
-                return self.create_error_response(str(e))
+        try:
+            return await self.chat_completion_full_generator(
+                request, result_generator, request_id, conversation, tokenizer)
+        except ValueError as e:
+            # TODO: Use an aphrodite-specific Validation Error
+            return self.create_error_response(str(e))
 
     def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
         if request.add_generation_prompt:
@@ -402,7 +405,6 @@ class OpenAIServingChat(OpenAIServing):
     async def chat_completion_full_generator(
         self,
         request: ChatCompletionRequest,
-        raw_request: Optional[Request],
         result_generator: AsyncIterator[RequestOutput],
         request_id: str,
         conversation: List[ConversationMessage],
@@ -413,12 +415,12 @@ class OpenAIServingChat(OpenAIServing):
         created_time = int(time.time())
         final_res: Optional[RequestOutput] = None
 
-        async for res in result_generator:
-            if raw_request is not None and await raw_request.is_disconnected():
-                # Abort the request if the client disconnects.
-                await self.async_engine_client.abort(request_id)
-                return self.create_error_response("Client disconnected")
-            final_res = res
+        try:
+            async for res in result_generator:
+                final_res = res
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
+
         assert final_res is not None
 
         choices: List[ChatCompletionResponseChoice] = []

+ 6 - 14
aphrodite/endpoints/openai/serving_completions.py

@@ -1,3 +1,4 @@
+import asyncio
 import time
 from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                     Optional)
@@ -76,7 +77,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time = int(time.time())
 
         # Schedule the request and get the result generator.
-        generators: List[AsyncIterator[RequestOutput]] = []
+        generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
             (
                 lora_request,
@@ -126,7 +127,8 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(str(e))
 
         result_generator: AsyncIterator[Tuple[
-            int, RequestOutput]] = merge_async_iterators(*generators)
+            int, RequestOutput]] = merge_async_iterators(
+                *generators, is_cancelled=raw_request.is_disconnected)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the
         # results. In addition, we do not stream the results when use
@@ -138,7 +140,6 @@ class OpenAIServingCompletion(OpenAIServing):
         # Streaming response
         if stream:
             return self.completion_stream_generator(request,
-                                                    raw_request,
                                                     result_generator,
                                                     request_id,
                                                     created_time,
@@ -150,10 +151,6 @@ class OpenAIServingCompletion(OpenAIServing):
         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
         try:
             async for i, res in result_generator:
-                if await raw_request.is_disconnected():
-                    # Abort the request if the client disconnects.
-                    await self.async_engine_client.abort(f"{request_id}-{i}")
-                    return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
 
             for i, final_res in enumerate(final_res_batch):
@@ -175,6 +172,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 model_name,
                 tokenizer,
             )
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))
@@ -195,7 +194,6 @@ class OpenAIServingCompletion(OpenAIServing):
     async def completion_stream_generator(
         self,
         request: CompletionRequest,
-        raw_request: Request,
         result_generator: AsyncIterator[Tuple[int, RequestOutput]],
         request_id: str,
         created_time: int,
@@ -211,12 +209,6 @@ class OpenAIServingCompletion(OpenAIServing):
         try:
             async for prompt_idx, res in result_generator:
 
-                # Abort the request if the client disconnects.
-                if await raw_request.is_disconnected():
-                    await self.async_engine_client.abort(
-                        f"{request_id}-{prompt_idx}")
-                    raise StopAsyncIteration()
-
                 for output in res.outputs:
                     i = output.index + prompt_idx * num_choices
                     # TODO: optimize the performance by avoiding full

+ 7 - 7
aphrodite/endpoints/openai/serving_embedding.py

@@ -1,6 +1,7 @@
+import asyncio
 import base64
 import time
-from typing import AsyncIterator, List, Optional, Tuple, cast
+from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple, cast
 
 import numpy as np
 from fastapi import Request
@@ -91,7 +92,7 @@ class OpenAIServingEmbedding(OpenAIServing):
         created_time = int(time.monotonic())
 
         # Schedule the request and get the result generator.
-        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
+        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
         try:
             (
                 lora_request,
@@ -136,17 +137,14 @@ class OpenAIServingEmbedding(OpenAIServing):
             return self.create_error_response(str(e))
 
         result_generator: AsyncIterator[Tuple[
-            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
+            int, EmbeddingRequestOutput]] = merge_async_iterators(
+                *generators, is_cancelled=raw_request.is_disconnected)
 
         # Non-streaming response
         final_res_batch: List[Optional[EmbeddingRequestOutput]]
         final_res_batch = [None] * len(prompts)
         try:
             async for i, res in result_generator:
-                if await raw_request.is_disconnected():
-                    # Abort the request if the client disconnects.
-                    await self.async_engine_client.abort(f"{request_id}-{i}")
-                    return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
 
             for final_res in final_res_batch:
@@ -157,6 +155,8 @@ class OpenAIServingEmbedding(OpenAIServing):
             response = request_output_to_embedding_response(
                 final_res_batch_checked, request_id, created_time, model_name,
                 encoding_format)
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))

+ 67 - 76
aphrodite/engine/async_aphrodite.py

@@ -2,7 +2,7 @@ import asyncio
 import os
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
+from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Optional,
                     Set, Tuple, Type, Union)
 
 from loguru import logger
@@ -61,12 +61,16 @@ def _log_task_completion(task: asyncio.Task,
             "actual cause.") from e
 
 
+STOP_ITERATION = Exception() # Sentinel
+
+
 class AsyncStream:
     """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
-    that can be iterated over asynchronously."""
+    that can be iterated over asynchronously via an async generator."""
 
-    def __init__(self, request_id: str) -> None:
+    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
         self.request_id = request_id
+        self._cancel = cancel
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
 
@@ -76,22 +80,30 @@ class AsyncStream:
             return
         self._queue.put_nowait(item)
 
-    def finish(self) -> None:
-        self._queue.put_nowait(StopAsyncIteration())
-        self._finished = True
+    def finish(self, cancelled: bool = False) -> None:
+        if not self._finished:
+            self._finished = True
+            self._queue.put_nowait(
+                asyncio.CancelledError if cancelled else STOP_ITERATION)
 
     @property
     def finished(self) -> bool:
         return self._finished
 
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
-        result = await self._queue.get()
-        if isinstance(result, Exception):
-            raise result
-        return result
+    async def generator(
+        self
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+        try:
+            while not self._finished:
+                result = await self._queue.get()
+                if isinstance(result, Exception):
+                    if result == STOP_ITERATION:
+                        return
+                    raise result
+                yield result
+        except GeneratorExit:
+            self._cancel(self.request_id)
+            raise asyncio.CancelledError from None
 
 
 class RequestTracker:
@@ -99,7 +111,7 @@ class RequestTracker:
 
     def __init__(self) -> None:
         self._request_streams: Dict[str, AsyncStream] = {}
-        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
         self.new_requests_event = asyncio.Event()
@@ -130,15 +142,21 @@ class RequestTracker:
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
         request_id = request_output.request_id
+        finished = request_output.finished
 
+        if finished:
+            stream = self._request_streams.pop(request_id, None)
+        else:
+            stream = self._request_streams.get(request_id)
         # Guard against a KeyError which can occur if the request was aborted
         # while the output was generated
-        if (stream := self._request_streams.get(request_id)) is not None:
+        if stream is not None:
             stream.put(request_output)
-        if request_output.finished:
-            if verbose:
-                logger.info(f"Finished request {request_id}.")
-            self.abort_request(request_id)
+            if finished:
+                stream.finish()
+
+        if verbose and finished:
+            logger.info(f"Finished request {request_id}.")
 
     def process_exception(self,
                           request_id: str,
@@ -161,7 +179,8 @@ class RequestTracker:
         if request_id in self._request_streams:
             raise KeyError(f"Request {request_id} already exists.")
 
-        stream = AsyncStream(request_id)
+        abort_request = partial(self.abort_request, verbose=verbose)
+        stream = AsyncStream(request_id, abort_request)
         self._new_requests.put_nowait((stream, {
             "request_id": request_id,
             **engine_add_request_kwargs
@@ -174,36 +193,36 @@ class RequestTracker:
 
         return stream
 
-    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
+    def abort_request(self,
+                      request_id: str,
+                      *,
+                      cancelled: bool = False,
+                      verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
             logger.info(f"Aborted request {request_id}.")
 
-        self._finished_requests.put_nowait(request_id)
-
-        if request_id not in self._request_streams or self._request_streams[
-                request_id].finished:
-            # The request has already finished or been aborted.
-            return
+        self._aborted_requests.put_nowait(request_id)
 
-        self._request_streams[request_id].finish()
+        stream = self._request_streams.pop(request_id, None)
+        if stream is not None:
+            stream.finish(cancelled=cancelled)
 
-    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
+    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
         sent to the engine."""
         new_requests: List[Dict] = []
         finished_requests: Set[str] = set()
 
-        while not self._finished_requests.empty():
-            request_id = self._finished_requests.get_nowait()
+        while not self._aborted_requests.empty():
+            request_id = self._aborted_requests.get_nowait()
             finished_requests.add(request_id)
-            self._request_streams.pop(request_id, None)
 
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
             if stream.request_id in finished_requests:
                 # The request has already been aborted.
-                stream.finish()
+                stream.finish(cancelled=True)
                 continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
@@ -554,8 +573,8 @@ class AsyncAphrodite:
 
         Returns True if there are in-progress requests."""
 
-        new_requests, finished_requests = (
-            self._request_tracker.get_new_and_finished_requests())
+        new_requests, aborted_requests = (
+            self._request_tracker.get_new_and_aborted_requests())
 
         for new_request in new_requests:
             # Add the request into the Aphrodite engine's waiting queue.
@@ -574,8 +593,8 @@ class AsyncAphrodite:
                     verbose=self.log_requests,
                 )
 
-        if finished_requests:
-            await self._engine_abort(finished_requests)
+        if aborted_requests:
+            await self._engine_abort(aborted_requests)
 
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()  # type: ignore
@@ -664,6 +683,8 @@ class AsyncAphrodite:
                 raise
             await asyncio.sleep(0)
 
+    # This method does not need to be async, but kept that way
+    # for backwards compatibility.
     async def add_request(
         self,
         request_id: str,
@@ -672,7 +693,7 @@ class AsyncAphrodite:
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> AsyncStream:
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
 
         if not self.is_running:
             if self.start_engine_loop:
@@ -684,19 +705,16 @@ class AsyncAphrodite:
                     "error that caused the background loop to stop "
                     "(AsyncEngineDeadError).")
 
-        if arrival_time is None:
-            arrival_time = time.time()
-
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
             inputs=inputs,
             params=params,
-            arrival_time=arrival_time,
+            arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request)
 
-        return stream
+        return stream.generator()
 
     async def generate(
         self,
@@ -705,7 +723,7 @@ class AsyncAphrodite:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
 
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -771,7 +789,7 @@ class AsyncAphrodite:
             >>> # Process and return the final output
             >>> ...
         """
-        async for output in self._process_request(
+        async for output in await self.add_request(
                 request_id,
                 inputs,
                 sampling_params,
@@ -786,7 +804,7 @@ class AsyncAphrodite:
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-    ) -> AsyncIterator[EmbeddingRequestOutput]:
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
         Generate outputs for a request. This method is a coroutine. It adds the
         request into the waiting queue of the AphroditeEngine and streams the
@@ -840,7 +858,7 @@ class AsyncAphrodite:
             >>> # Process and return the final output
             >>> ...
         """
-        async for output in self._process_request(
+        async for output in await self.add_request(
                 request_id,
                 inputs,
                 pooling_params,
@@ -849,34 +867,6 @@ class AsyncAphrodite:
             yield AphroditeEngine.validate_output(output,
                                                   EmbeddingRequestOutput)
 
-    async def _process_request(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        params: Union[SamplingParams, PoolingParams],
-        *,
-        lora_request: Optional[LoRARequest] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
-        """Common logic to process requests with SamplingParams or
-        PoolingParams."""
-        arrival_time = time.time()
-
-        stream = await self.add_request(
-            request_id,
-            inputs,
-            params,
-            arrival_time=arrival_time,
-            lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
-
-        try:
-            async for request_output in stream:
-                yield request_output
-        except (Exception, asyncio.CancelledError) as e:
-            self._abort(request_id)
-            raise e
-
     async def abort(self, request_id: str) -> None:
         """Abort a request.
 
@@ -905,6 +895,7 @@ class AsyncAphrodite:
             request_id: The unique id of the request.
         """
         self._request_tracker.abort_request(request_id,
+                                            cancelled=True,
                                             verbose=self.log_requests)
 
     async def get_model_config(self) -> ModelConfig:

+ 5 - 5
aphrodite/engine/protocol.py

@@ -1,4 +1,4 @@
-from typing import AsyncIterator, List, Optional, Protocol, runtime_checkable
+from typing import AsyncGenerator, List, Optional, Protocol, runtime_checkable
 
 from transformers import PreTrainedTokenizer
 
@@ -29,24 +29,24 @@ class AsyncEngineClient(Protocol):
     def errored(self) -> bool:
         ...
 
-    async def generate(
+    def generate(
         self,
         inputs: PromptInputs,
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Generates outputs for a request"""
         ...
 
-    async def encode(
+    def encode(
         self,
         inputs: PromptInputs,
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-    ) -> AsyncIterator[EmbeddingRequestOutput]:
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model."""
         ...