|
@@ -2,7 +2,7 @@ import asyncio
|
|
import os
|
|
import os
|
|
import time
|
|
import time
|
|
from functools import partial
|
|
from functools import partial
|
|
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
|
|
|
|
|
|
+from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Optional,
|
|
Set, Tuple, Type, Union)
|
|
Set, Tuple, Type, Union)
|
|
|
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
@@ -61,12 +61,16 @@ def _log_task_completion(task: asyncio.Task,
|
|
"actual cause.") from e
|
|
"actual cause.") from e
|
|
|
|
|
|
|
|
|
|
|
|
+STOP_ITERATION = Exception() # Sentinel
|
|
|
|
+
|
|
|
|
+
|
|
class AsyncStream:
|
|
class AsyncStream:
|
|
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
|
|
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
|
|
- that can be iterated over asynchronously."""
|
|
|
|
|
|
+ that can be iterated over asynchronously via an async generator."""
|
|
|
|
|
|
- def __init__(self, request_id: str) -> None:
|
|
|
|
|
|
+ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
|
|
self.request_id = request_id
|
|
self.request_id = request_id
|
|
|
|
+ self._cancel = cancel
|
|
self._queue: asyncio.Queue = asyncio.Queue()
|
|
self._queue: asyncio.Queue = asyncio.Queue()
|
|
self._finished = False
|
|
self._finished = False
|
|
|
|
|
|
@@ -76,22 +80,30 @@ class AsyncStream:
|
|
return
|
|
return
|
|
self._queue.put_nowait(item)
|
|
self._queue.put_nowait(item)
|
|
|
|
|
|
- def finish(self) -> None:
|
|
|
|
- self._queue.put_nowait(StopAsyncIteration())
|
|
|
|
- self._finished = True
|
|
|
|
|
|
+ def finish(self, cancelled: bool = False) -> None:
|
|
|
|
+ if not self._finished:
|
|
|
|
+ self._finished = True
|
|
|
|
+ self._queue.put_nowait(
|
|
|
|
+ asyncio.CancelledError if cancelled else STOP_ITERATION)
|
|
|
|
|
|
@property
|
|
@property
|
|
def finished(self) -> bool:
|
|
def finished(self) -> bool:
|
|
return self._finished
|
|
return self._finished
|
|
|
|
|
|
- def __aiter__(self):
|
|
|
|
- return self
|
|
|
|
-
|
|
|
|
- async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
|
|
|
|
- result = await self._queue.get()
|
|
|
|
- if isinstance(result, Exception):
|
|
|
|
- raise result
|
|
|
|
- return result
|
|
|
|
|
|
+ async def generator(
|
|
|
|
+ self
|
|
|
|
+ ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
|
|
|
|
+ try:
|
|
|
|
+ while not self._finished:
|
|
|
|
+ result = await self._queue.get()
|
|
|
|
+ if isinstance(result, Exception):
|
|
|
|
+ if result == STOP_ITERATION:
|
|
|
|
+ return
|
|
|
|
+ raise result
|
|
|
|
+ yield result
|
|
|
|
+ except GeneratorExit:
|
|
|
|
+ self._cancel(self.request_id)
|
|
|
|
+ raise asyncio.CancelledError from None
|
|
|
|
|
|
|
|
|
|
class RequestTracker:
|
|
class RequestTracker:
|
|
@@ -99,7 +111,7 @@ class RequestTracker:
|
|
|
|
|
|
def __init__(self) -> None:
|
|
def __init__(self) -> None:
|
|
self._request_streams: Dict[str, AsyncStream] = {}
|
|
self._request_streams: Dict[str, AsyncStream] = {}
|
|
- self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
|
|
|
|
|
|
+ self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
|
|
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
|
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
|
dict]] = asyncio.Queue()
|
|
dict]] = asyncio.Queue()
|
|
self.new_requests_event = asyncio.Event()
|
|
self.new_requests_event = asyncio.Event()
|
|
@@ -130,15 +142,21 @@ class RequestTracker:
|
|
verbose: bool = False) -> None:
|
|
verbose: bool = False) -> None:
|
|
"""Process a request output from the engine."""
|
|
"""Process a request output from the engine."""
|
|
request_id = request_output.request_id
|
|
request_id = request_output.request_id
|
|
|
|
+ finished = request_output.finished
|
|
|
|
|
|
|
|
+ if finished:
|
|
|
|
+ stream = self._request_streams.pop(request_id, None)
|
|
|
|
+ else:
|
|
|
|
+ stream = self._request_streams.get(request_id)
|
|
# Guard against a KeyError which can occur if the request was aborted
|
|
# Guard against a KeyError which can occur if the request was aborted
|
|
# while the output was generated
|
|
# while the output was generated
|
|
- if (stream := self._request_streams.get(request_id)) is not None:
|
|
|
|
|
|
+ if stream is not None:
|
|
stream.put(request_output)
|
|
stream.put(request_output)
|
|
- if request_output.finished:
|
|
|
|
- if verbose:
|
|
|
|
- logger.info(f"Finished request {request_id}.")
|
|
|
|
- self.abort_request(request_id)
|
|
|
|
|
|
+ if finished:
|
|
|
|
+ stream.finish()
|
|
|
|
+
|
|
|
|
+ if verbose and finished:
|
|
|
|
+ logger.info(f"Finished request {request_id}.")
|
|
|
|
|
|
def process_exception(self,
|
|
def process_exception(self,
|
|
request_id: str,
|
|
request_id: str,
|
|
@@ -161,7 +179,8 @@ class RequestTracker:
|
|
if request_id in self._request_streams:
|
|
if request_id in self._request_streams:
|
|
raise KeyError(f"Request {request_id} already exists.")
|
|
raise KeyError(f"Request {request_id} already exists.")
|
|
|
|
|
|
- stream = AsyncStream(request_id)
|
|
|
|
|
|
+ abort_request = partial(self.abort_request, verbose=verbose)
|
|
|
|
+ stream = AsyncStream(request_id, abort_request)
|
|
self._new_requests.put_nowait((stream, {
|
|
self._new_requests.put_nowait((stream, {
|
|
"request_id": request_id,
|
|
"request_id": request_id,
|
|
**engine_add_request_kwargs
|
|
**engine_add_request_kwargs
|
|
@@ -174,36 +193,36 @@ class RequestTracker:
|
|
|
|
|
|
return stream
|
|
return stream
|
|
|
|
|
|
- def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
|
|
|
|
|
|
+ def abort_request(self,
|
|
|
|
+ request_id: str,
|
|
|
|
+ *,
|
|
|
|
+ cancelled: bool = False,
|
|
|
|
+ verbose: bool = False) -> None:
|
|
"""Abort a request during next background loop iteration."""
|
|
"""Abort a request during next background loop iteration."""
|
|
if verbose:
|
|
if verbose:
|
|
logger.info(f"Aborted request {request_id}.")
|
|
logger.info(f"Aborted request {request_id}.")
|
|
|
|
|
|
- self._finished_requests.put_nowait(request_id)
|
|
|
|
-
|
|
|
|
- if request_id not in self._request_streams or self._request_streams[
|
|
|
|
- request_id].finished:
|
|
|
|
- # The request has already finished or been aborted.
|
|
|
|
- return
|
|
|
|
|
|
+ self._aborted_requests.put_nowait(request_id)
|
|
|
|
|
|
- self._request_streams[request_id].finish()
|
|
|
|
|
|
+ stream = self._request_streams.pop(request_id, None)
|
|
|
|
+ if stream is not None:
|
|
|
|
+ stream.finish(cancelled=cancelled)
|
|
|
|
|
|
- def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
|
|
|
|
|
|
+ def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
|
|
"""Get the new requests and finished requests to be
|
|
"""Get the new requests and finished requests to be
|
|
sent to the engine."""
|
|
sent to the engine."""
|
|
new_requests: List[Dict] = []
|
|
new_requests: List[Dict] = []
|
|
finished_requests: Set[str] = set()
|
|
finished_requests: Set[str] = set()
|
|
|
|
|
|
- while not self._finished_requests.empty():
|
|
|
|
- request_id = self._finished_requests.get_nowait()
|
|
|
|
|
|
+ while not self._aborted_requests.empty():
|
|
|
|
+ request_id = self._aborted_requests.get_nowait()
|
|
finished_requests.add(request_id)
|
|
finished_requests.add(request_id)
|
|
- self._request_streams.pop(request_id, None)
|
|
|
|
|
|
|
|
while not self._new_requests.empty():
|
|
while not self._new_requests.empty():
|
|
stream, new_request = self._new_requests.get_nowait()
|
|
stream, new_request = self._new_requests.get_nowait()
|
|
if stream.request_id in finished_requests:
|
|
if stream.request_id in finished_requests:
|
|
# The request has already been aborted.
|
|
# The request has already been aborted.
|
|
- stream.finish()
|
|
|
|
|
|
+ stream.finish(cancelled=True)
|
|
continue
|
|
continue
|
|
self._request_streams[stream.request_id] = stream
|
|
self._request_streams[stream.request_id] = stream
|
|
new_requests.append(new_request)
|
|
new_requests.append(new_request)
|
|
@@ -554,8 +573,8 @@ class AsyncAphrodite:
|
|
|
|
|
|
Returns True if there are in-progress requests."""
|
|
Returns True if there are in-progress requests."""
|
|
|
|
|
|
- new_requests, finished_requests = (
|
|
|
|
- self._request_tracker.get_new_and_finished_requests())
|
|
|
|
|
|
+ new_requests, aborted_requests = (
|
|
|
|
+ self._request_tracker.get_new_and_aborted_requests())
|
|
|
|
|
|
for new_request in new_requests:
|
|
for new_request in new_requests:
|
|
# Add the request into the Aphrodite engine's waiting queue.
|
|
# Add the request into the Aphrodite engine's waiting queue.
|
|
@@ -574,8 +593,8 @@ class AsyncAphrodite:
|
|
verbose=self.log_requests,
|
|
verbose=self.log_requests,
|
|
)
|
|
)
|
|
|
|
|
|
- if finished_requests:
|
|
|
|
- await self._engine_abort(finished_requests)
|
|
|
|
|
|
+ if aborted_requests:
|
|
|
|
+ await self._engine_abort(aborted_requests)
|
|
|
|
|
|
if self.engine_use_ray:
|
|
if self.engine_use_ray:
|
|
request_outputs = await self.engine.step.remote() # type: ignore
|
|
request_outputs = await self.engine.step.remote() # type: ignore
|
|
@@ -664,6 +683,8 @@ class AsyncAphrodite:
|
|
raise
|
|
raise
|
|
await asyncio.sleep(0)
|
|
await asyncio.sleep(0)
|
|
|
|
|
|
|
|
+ # This method does not need to be async, but kept that way
|
|
|
|
+ # for backwards compatibility.
|
|
async def add_request(
|
|
async def add_request(
|
|
self,
|
|
self,
|
|
request_id: str,
|
|
request_id: str,
|
|
@@ -672,7 +693,7 @@ class AsyncAphrodite:
|
|
arrival_time: Optional[float] = None,
|
|
arrival_time: Optional[float] = None,
|
|
lora_request: Optional[LoRARequest] = None,
|
|
lora_request: Optional[LoRARequest] = None,
|
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
|
- ) -> AsyncStream:
|
|
|
|
|
|
+ ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
|
|
|
|
|
|
if not self.is_running:
|
|
if not self.is_running:
|
|
if self.start_engine_loop:
|
|
if self.start_engine_loop:
|
|
@@ -684,19 +705,16 @@ class AsyncAphrodite:
|
|
"error that caused the background loop to stop "
|
|
"error that caused the background loop to stop "
|
|
"(AsyncEngineDeadError).")
|
|
"(AsyncEngineDeadError).")
|
|
|
|
|
|
- if arrival_time is None:
|
|
|
|
- arrival_time = time.time()
|
|
|
|
-
|
|
|
|
stream = self._request_tracker.add_request(
|
|
stream = self._request_tracker.add_request(
|
|
request_id,
|
|
request_id,
|
|
verbose=self.log_requests,
|
|
verbose=self.log_requests,
|
|
inputs=inputs,
|
|
inputs=inputs,
|
|
params=params,
|
|
params=params,
|
|
- arrival_time=arrival_time,
|
|
|
|
|
|
+ arrival_time=arrival_time or time.time(),
|
|
lora_request=lora_request,
|
|
lora_request=lora_request,
|
|
prompt_adapter_request=prompt_adapter_request)
|
|
prompt_adapter_request=prompt_adapter_request)
|
|
|
|
|
|
- return stream
|
|
|
|
|
|
+ return stream.generator()
|
|
|
|
|
|
async def generate(
|
|
async def generate(
|
|
self,
|
|
self,
|
|
@@ -705,7 +723,7 @@ class AsyncAphrodite:
|
|
request_id: str,
|
|
request_id: str,
|
|
lora_request: Optional[LoRARequest] = None,
|
|
lora_request: Optional[LoRARequest] = None,
|
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
|
- ) -> AsyncIterator[RequestOutput]:
|
|
|
|
|
|
+ ) -> AsyncGenerator[RequestOutput, None]:
|
|
"""Generate outputs for a request.
|
|
"""Generate outputs for a request.
|
|
|
|
|
|
Generate outputs for a request. This method is a coroutine. It adds the
|
|
Generate outputs for a request. This method is a coroutine. It adds the
|
|
@@ -771,7 +789,7 @@ class AsyncAphrodite:
|
|
>>> # Process and return the final output
|
|
>>> # Process and return the final output
|
|
>>> ...
|
|
>>> ...
|
|
"""
|
|
"""
|
|
- async for output in self._process_request(
|
|
|
|
|
|
+ async for output in await self.add_request(
|
|
request_id,
|
|
request_id,
|
|
inputs,
|
|
inputs,
|
|
sampling_params,
|
|
sampling_params,
|
|
@@ -786,7 +804,7 @@ class AsyncAphrodite:
|
|
pooling_params: PoolingParams,
|
|
pooling_params: PoolingParams,
|
|
request_id: str,
|
|
request_id: str,
|
|
lora_request: Optional[LoRARequest] = None,
|
|
lora_request: Optional[LoRARequest] = None,
|
|
- ) -> AsyncIterator[EmbeddingRequestOutput]:
|
|
|
|
|
|
+ ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
|
"""Generate outputs for a request from an embedding model.
|
|
"""Generate outputs for a request from an embedding model.
|
|
Generate outputs for a request. This method is a coroutine. It adds the
|
|
Generate outputs for a request. This method is a coroutine. It adds the
|
|
request into the waiting queue of the AphroditeEngine and streams the
|
|
request into the waiting queue of the AphroditeEngine and streams the
|
|
@@ -840,7 +858,7 @@ class AsyncAphrodite:
|
|
>>> # Process and return the final output
|
|
>>> # Process and return the final output
|
|
>>> ...
|
|
>>> ...
|
|
"""
|
|
"""
|
|
- async for output in self._process_request(
|
|
|
|
|
|
+ async for output in await self.add_request(
|
|
request_id,
|
|
request_id,
|
|
inputs,
|
|
inputs,
|
|
pooling_params,
|
|
pooling_params,
|
|
@@ -849,34 +867,6 @@ class AsyncAphrodite:
|
|
yield AphroditeEngine.validate_output(output,
|
|
yield AphroditeEngine.validate_output(output,
|
|
EmbeddingRequestOutput)
|
|
EmbeddingRequestOutput)
|
|
|
|
|
|
- async def _process_request(
|
|
|
|
- self,
|
|
|
|
- request_id: str,
|
|
|
|
- inputs: PromptInputs,
|
|
|
|
- params: Union[SamplingParams, PoolingParams],
|
|
|
|
- *,
|
|
|
|
- lora_request: Optional[LoRARequest] = None,
|
|
|
|
- prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
|
|
|
- ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
|
|
|
|
- """Common logic to process requests with SamplingParams or
|
|
|
|
- PoolingParams."""
|
|
|
|
- arrival_time = time.time()
|
|
|
|
-
|
|
|
|
- stream = await self.add_request(
|
|
|
|
- request_id,
|
|
|
|
- inputs,
|
|
|
|
- params,
|
|
|
|
- arrival_time=arrival_time,
|
|
|
|
- lora_request=lora_request,
|
|
|
|
- prompt_adapter_request=prompt_adapter_request)
|
|
|
|
-
|
|
|
|
- try:
|
|
|
|
- async for request_output in stream:
|
|
|
|
- yield request_output
|
|
|
|
- except (Exception, asyncio.CancelledError) as e:
|
|
|
|
- self._abort(request_id)
|
|
|
|
- raise e
|
|
|
|
-
|
|
|
|
async def abort(self, request_id: str) -> None:
|
|
async def abort(self, request_id: str) -> None:
|
|
"""Abort a request.
|
|
"""Abort a request.
|
|
|
|
|
|
@@ -905,6 +895,7 @@ class AsyncAphrodite:
|
|
request_id: The unique id of the request.
|
|
request_id: The unique id of the request.
|
|
"""
|
|
"""
|
|
self._request_tracker.abort_request(request_id,
|
|
self._request_tracker.abort_request(request_id,
|
|
|
|
+ cancelled=True,
|
|
verbose=self.log_requests)
|
|
verbose=self.log_requests)
|
|
|
|
|
|
async def get_model_config(self) -> ModelConfig:
|
|
async def get_model_config(self) -> ModelConfig:
|