@@ -1,20 +1,26 @@
 import asyncio
+from http import HTTPStatus
 from io import StringIO
-from typing import Awaitable, Callable, List
+from typing import Awaitable, Callable, List, Optional

 import aiohttp
-from loguru import logger
+import torch
+from prometheus_client import start_http_server
+from tqdm import tqdm

 from aphrodite.common.utils import FlexibleArgumentParser, random_uuid
-from aphrodite.endpoints.logger import RequestLogger
+from aphrodite.endpoints.logger import RequestLogger, logger
+# yapf: disable
 from aphrodite.endpoints.openai.protocol import (BatchRequestInput,
                                                  BatchRequestOutput,
                                                  BatchResponseData,
                                                  ChatCompletionResponse,
                                                  EmbeddingResponse,
                                                  ErrorResponse)
+# yapf: enable
 from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
 from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
+from aphrodite.endpoints.openai.serving_engine import BaseModelPath
 from aphrodite.engine.args_tools import AsyncEngineArgs
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
 from aphrodite.version import __version__ as APHRODITE_VERSION
@@ -44,18 +50,70 @@ def parse_args():
                         type=str,
                         default="assistant",
                         help="The role name to return if "
-                        "`request.add_generation_prompt=true`.")
-    parser.add_argument("--max-log-len",
-                        type=int,
-                        default=0,
-                        help="Max number of prompt characters or prompt "
-                        "ID numbers being printed in log."
-                        "\n\nDefault: 0")
+                        "`request.add_generation_prompt=True`.")

     parser = AsyncEngineArgs.add_cli_args(parser)
+
+    parser.add_argument('--max-log-len',
+                        type=int,
+                        default=None,
+                        help='Max number of prompt characters or prompt '
+                        'token IDs printed in the log.'
+                        '\n\nDefault: Unlimited')
+
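+    # Prometheus metrics flags; the metrics server itself is started in
+    # the __main__ block below.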
+    parser.add_argument("--enable-metrics",
+                        action="store_true",
+                        help="Enable Prometheus metrics")
+    parser.add_argument(
+        "--url",
+        type=str,
+        default="0.0.0.0",
+        help="Address to bind the Prometheus metrics server to "
+        "(only needed if --enable-metrics is set).",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8000,
+        help="Port number for the Prometheus metrics server "
+        "(only needed if --enable-metrics is set).",
+    )
+
     return parser.parse_args()


+# Explicitly use a pure-text format with a newline at the end. This disables
+# the progress bar animation, but avoids clashing with Ray and
+# multiprocessing, which wrap each line of output with their own prefix.
+_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
+
+
+class BatchProgressTracker:
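+    """Track how many requests have been submitted and completed, and
+    render the counts as a tqdm progress bar."""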
+
+    def __init__(self):
+        self._total = 0
+        self._pbar: Optional[tqdm] = None
+
+    def submitted(self):
+        self._total += 1
+
+    def completed(self):
+        if self._pbar:
+            self._pbar.update()
+
+    def pbar(self) -> tqdm:
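+        # Only rank 0 draws the bar when torch.distributed is initialized;
+        # otherwise every worker would print its own copy.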
+        enable_tqdm = not torch.distributed.is_initialized(
+        ) or torch.distributed.get_rank() == 0
+        self._pbar = tqdm(total=self._total,
+                          unit="req",
+                          desc="Running batch",
+                          mininterval=5,
+                          disable=not enable_tqdm,
+                          bar_format=_BAR_FORMAT)
+        return self._pbar
+
+
 async def read_file(path_or_url: str) -> str:
     if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
         async with aiohttp.ClientSession() as session, \
@@ -79,8 +137,28 @@ async def write_file(path_or_url: str, data: str) -> None:
             f.write(data)


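+# Build a BAD_REQUEST batch output for a request that cannot be served.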
+def make_error_request_output(request: BatchRequestInput,
+                              error_msg: str) -> BatchRequestOutput:
+    batch_output = BatchRequestOutput(
+        id=f"aphrodite-{random_uuid()}",
+        custom_id=request.custom_id,
+        response=BatchResponseData(
+            status_code=HTTPStatus.BAD_REQUEST,
+            request_id=f"aphrodite-batch-{random_uuid()}",
+        ),
+        error=error_msg,
+    )
+    return batch_output
+
+
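+# Async wrapper so error outputs can be awaited by the same asyncio.gather
+# call that collects the run_request coroutines.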
+async def make_async_error_request_output(
+        request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
+    return make_error_request_output(request, error_msg)
+
+
 async def run_request(serving_engine_func: Callable,
-                      request: BatchRequestInput) -> BatchRequestOutput:
+                      request: BatchRequestInput,
+                      tracker: BatchProgressTracker) -> BatchRequestOutput:
     response = await serving_engine_func(request.body)

     if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
@@ -101,8 +179,10 @@ async def run_request(serving_engine_func: Callable,
             error=response,
         )
     else:
-        raise ValueError("Request must not be sent in stream mode")
+        batch_output = make_error_request_output(
+            request, error_msg="Request must not be sent in stream mode")

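+    # Count completion for success and error paths alike so the progress
+    # bar always reaches its total.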
+    tracker.completed()
     return batch_output


@@ -115,19 +195,22 @@ async def main(args):
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncAphrodite.from_engine_args(engine_args)

-    # When using single Aphrodite without engine_use_ray
     model_config = await engine.get_model_config()
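+    # Every served model name maps onto the same underlying model path.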
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]

     if args.disable_log_requests:
         request_logger = None
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

     # Create the OpenAI serving objects.
     openai_serving_chat = OpenAIServingChat(
         engine,
         model_config,
-        served_model_names,
+        base_model_paths,
         args.response_role,
         lora_modules=None,
         prompt_adapters=None,
@@ -137,10 +220,13 @@ async def main(args):
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
         model_config,
-        served_model_names,
+        base_model_paths,
         request_logger=request_logger,
     )

+    tracker = BatchProgressTracker()
+    logger.info(f"Reading batch from {args.input_file}...")
+
     # Submit all requests in the file to the engine "concurrently".
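+    # Each non-empty line below is one batch request in OpenAI-style JSONL
+    # batch format; illustrative shape only (the values are made up):
+    #   {"custom_id": "req-1", "method": "POST",
+    #    "url": "/v1/chat/completions", "body": {...}}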
     response_futures: List[Awaitable[BatchRequestOutput]] = []
     for request_json in (await read_file(args.input_file)).strip().split("\n"):
@@ -148,22 +234,30 @@ async def main(args):
         request_json = request_json.strip()
         if not request_json:
             continue
+
         request = BatchRequestInput.model_validate_json(request_json)

         # Determine the type of request and run it.
         if request.url == "/v1/chat/completions":
             response_futures.append(
                 run_request(openai_serving_chat.create_chat_completion,
-                            request))
+                            request, tracker))
+            tracker.submitted()
         elif request.url == "/v1/embeddings":
             response_futures.append(
-                run_request(openai_serving_embedding.create_embedding,
-                            request))
+                run_request(openai_serving_embedding.create_embedding, request,
+                            tracker))
+            tracker.submitted()
         else:
-            raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
-                             "supported in the batch endpoint.")
+            response_futures.append(
+                make_async_error_request_output(
+                    request,
+                    error_msg="Only /v1/chat/completions and "
+                    "/v1/embeddings are supported in the batch endpoint.",
+                ))

-    responses = await asyncio.gather(*response_futures)
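+    # tqdm objects are context managers; leaving the block closes the bar.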
+    with tracker.pbar():
+        responses = await asyncio.gather(*response_futures)

     output_buffer = StringIO()
     for response in responses:
@@ -176,7 +270,15 @@
 if __name__ == "__main__":
     args = parse_args()

-    logger.info(f"Aphrodite API server version {APHRODITE_VERSION}")
-    logger.info(f"args: {args}")
+    logger.info(f"Aphrodite batch processing API version {APHRODITE_VERSION}")
+    logger.debug(f"args: {args}")
+
+    # Start the Prometheus metrics server. The engine uses the Prometheus
+    # client to publish metrics at the /metrics endpoint.
+    if args.enable_metrics:
+        logger.info("Prometheus metrics enabled")
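+        # prometheus_client.start_http_server serves /metrics from a
+        # daemon thread, so it will not block the asyncio loop below.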
+        start_http_server(port=args.port, addr=args.url)
+    else:
+        logger.info("Prometheus metrics disabled")

     asyncio.run(main(args))