import asyncio
from http import HTTPStatus
from io import StringIO
from typing import Awaitable, Callable, List, Optional

import aiohttp
import torch
from prometheus_client import start_http_server
from tqdm import tqdm

from aphrodite.common.utils import FlexibleArgumentParser, random_uuid
from aphrodite.endpoints.logger import RequestLogger, logger
# yapf: disable
from aphrodite.endpoints.openai.protocol import (BatchRequestInput,
                                                 BatchRequestOutput,
                                                 BatchResponseData,
                                                 ChatCompletionResponse,
                                                 EmbeddingResponse,
                                                 ErrorResponse)
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import BaseModelPath
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.version import __version__ as APHRODITE_VERSION


def parse_args():
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-compatible batch runner.")
    parser.add_argument(
        "-i",
        "--input-file",
        required=True,
        type=str,
        help="The path or URL to a single input file. Currently supports "
        "local file paths and the HTTP protocol (http or https). If a URL is "
        "specified, the file should be available via HTTP GET.")
    parser.add_argument(
        "-o",
        "--output-file",
        required=True,
        type=str,
        help="The path or URL to a single output file. Currently supports "
        "local file paths and the HTTP protocol (http or https). If a URL is "
        "specified, the file should be available via HTTP PUT.")
    parser.add_argument("--response-role",
                        type=str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=True`.")
    parser = AsyncEngineArgs.add_cli_args(parser)
    parser.add_argument("--max-log-len",
                        type=int,
                        default=None,
                        help="Maximum number of prompt characters or prompt "
                        "token IDs printed in the log."
                        "\n\nDefault: unlimited")
    parser.add_argument("--enable-metrics",
                        action="store_true",
                        help="Enable Prometheus metrics.")
    parser.add_argument(
        "--url",
        type=str,
        default="0.0.0.0",
        help="Address to bind the Prometheus metrics server to "
        "(only needed if --enable-metrics is set).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port number for the Prometheus metrics server "
        "(only needed if --enable-metrics is set).",
    )
    return parser.parse_args()
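

# Example invocation (file names and the model are illustrative; any engine
# argument accepted by AsyncEngineArgs can be appended):
#
#   python run_batch.py \
#       -i batch_input.jsonl \
#       -o batch_output.jsonl \
#       --model <model_name>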


# Explicitly use a pure-text bar format with a newline at the end. This makes
# it impossible to see the progress-bar animation, but avoids interfering
# with Ray or multiprocessing, which wrap each line of output with a prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501


class BatchProgressTracker:

    def __init__(self):
        self._total = 0
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        self._total += 1

    def completed(self):
        if self._pbar:
            self._pbar.update()

    def pbar(self) -> tqdm:
        # Only rank 0 renders the bar when torch.distributed is initialized.
        enable_tqdm = (not torch.distributed.is_initialized()
                       or torch.distributed.get_rank() == 0)
        self._pbar = tqdm(total=self._total,
                          unit="req",
                          desc="Running batch",
                          mininterval=5,
                          disable=not enable_tqdm,
                          bar_format=_BAR_FORMAT)
        return self._pbar


async def read_file(path_or_url: str) -> str:
    if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
        async with aiohttp.ClientSession() as session, \
                session.get(path_or_url) as resp:
            return await resp.text()
    else:
        with open(path_or_url, "r", encoding="utf-8") as f:
            return f.read()


async def write_file(path_or_url: str, data: str) -> None:
    if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
        async with aiohttp.ClientSession() as session, \
                session.put(path_or_url, data=data.encode("utf-8")):
            pass
    else:
        # We should make this async, but as long as this is always run as a
        # standalone program, blocking the event loop won't affect performance
        # in this particular case.
        with open(path_or_url, "w", encoding="utf-8") as f:
            f.write(data)


def make_error_request_output(request: BatchRequestInput,
                              error_msg: str) -> BatchRequestOutput:
    batch_output = BatchRequestOutput(
        id=f"aphrodite-{random_uuid()}",
        custom_id=request.custom_id,
        response=BatchResponseData(
            status_code=HTTPStatus.BAD_REQUEST,
            request_id=f"aphrodite-batch-{random_uuid()}",
        ),
        error=error_msg,
    )
    return batch_output


async def make_async_error_request_output(
        request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
    return make_error_request_output(request, error_msg)


async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput,
                      tracker: BatchProgressTracker) -> BatchRequestOutput:
    response = await serving_engine_func(request.body)

    if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
        batch_output = BatchRequestOutput(
            id=f"aphrodite-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                body=response,
                request_id=f"aphrodite-batch-{random_uuid()}"),
            error=None,
        )
    elif isinstance(response, ErrorResponse):
        batch_output = BatchRequestOutput(
            id=f"aphrodite-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                status_code=response.code,
                request_id=f"aphrodite-batch-{random_uuid()}"),
            error=response,
        )
    else:
        batch_output = make_error_request_output(
            request, error_msg="Request must not be sent in stream mode")

    tracker.completed()
    return batch_output
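

# For reference, each non-empty line of the input file is one request in the
# OpenAI batch JSONL format; the authoritative schema is BatchRequestInput in
# aphrodite.endpoints.openai.protocol. An illustrative chat-completion line
# (the custom_id and model name are placeholders):
#
#   {"custom_id": "request-1", "method": "POST",
#    "url": "/v1/chat/completions",
#    "body": {"model": "<model_name>",
#             "messages": [{"role": "user", "content": "Hello!"}]}}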


async def main(args):
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)

    model_config = await engine.get_model_config()
    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    # Create the OpenAI serving objects.
    openai_serving_chat = OpenAIServingChat(
        engine,
        model_config,
        base_model_paths,
        args.response_role,
        lora_modules=None,
        prompt_adapters=None,
        request_logger=request_logger,
        chat_template=None,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        model_config,
        base_model_paths,
        request_logger=request_logger,
    )

    tracker = BatchProgressTracker()
    logger.info(f"Reading batch from {args.input_file}...")

    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
        # Skip empty lines.
        request_json = request_json.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)

        # Determine the type of request and run it.
        if request.url == "/v1/chat/completions":
            response_futures.append(
                run_request(openai_serving_chat.create_chat_completion,
                            request, tracker))
            tracker.submitted()
        elif request.url == "/v1/embeddings":
            response_futures.append(
                run_request(openai_serving_embedding.create_embedding,
                            request, tracker))
            tracker.submitted()
        else:
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg="Only /v1/chat/completions and "
                    "/v1/embeddings are supported in the batch endpoint.",
                ))

    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    output_buffer = StringIO()
    for response in responses:
        print(response.model_dump_json(), file=output_buffer)

    output_buffer.seek(0)
    await write_file(args.output_file, output_buffer.read().strip())


if __name__ == "__main__":
    args = parse_args()

    logger.info(f"Aphrodite batch processing API version {APHRODITE_VERSION}")
    logger.debug(f"args: {args}")

    # Start the Prometheus metrics server. LLMEngine uses the Prometheus
    # client to publish metrics at the /metrics endpoint.
    if args.enable_metrics:
        logger.info("Prometheus metrics enabled")
        start_http_server(port=args.port, addr=args.url)
    else:
        logger.info("Prometheus metrics disabled")

    asyncio.run(main(args))
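
# Note: the local-file branch of write_file blocks the event loop. A minimal
# non-blocking sketch, using only the standard library (asyncio.to_thread,
# Python 3.9+); this helper is illustrative and not used by this module:
#
#   async def write_local_file_nonblocking(path: str, data: str) -> None:
#       def _write() -> None:
#           with open(path, "w", encoding="utf-8") as f:
#               f.write(data)
#       # Run the blocking write in a worker thread, keeping the loop free.
#       await asyncio.to_thread(_write)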