123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- import asyncio
- from io import StringIO
- from typing import Awaitable, Callable, List
- import aiohttp
- from loguru import logger
- from aphrodite.common.utils import FlexibleArgumentParser, random_uuid
- from aphrodite.endpoints.logger import RequestLogger
- from aphrodite.endpoints.openai.protocol import (BatchRequestInput,
- BatchRequestOutput,
- BatchResponseData,
- ChatCompletionResponse,
- EmbeddingResponse,
- ErrorResponse)
- from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
- from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
- from aphrodite.engine.args_tools import AsyncEngineArgs
- from aphrodite.engine.async_aphrodite import AsyncAphrodite
- from aphrodite.version import __version__ as APHRODITE_VERSION
def parse_args():
    """Build and parse the CLI arguments for the batch runner.

    Combines the batch-runner-specific flags (input/output file, response
    role, log truncation) with the standard engine flags contributed by
    ``AsyncEngineArgs.add_cli_args``.
    """
    arg_parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible batch runner.")
    arg_parser.add_argument(
        "-i",
        "--input-file",
        required=True,
        type=str,
        help="The path or url to a single input file. Currently supports "
        "local file paths, or the http protocol (http or https). If a URL "
        "is specified, the file should be available via HTTP GET.")
    arg_parser.add_argument(
        "-o",
        "--output-file",
        required=True,
        type=str,
        help="The path or url to a single output file. Currently supports "
        "local file paths, or web (http or https) urls. If a URL is "
        "specified, the file should be available via HTTP PUT.")
    arg_parser.add_argument(
        "--response-role",
        type=str,
        default="assistant",
        help="The role name to return if "
        "`request.add_generation_prompt=true`.")
    arg_parser.add_argument(
        "--max-log-len",
        type=int,
        default=0,
        help="Max number of prompt characters or prompt "
        "ID numbers being printed in log."
        "\n\nDefault: 0")
    # Fold in the shared engine CLI flags before parsing.
    arg_parser = AsyncEngineArgs.add_cli_args(arg_parser)
    return arg_parser.parse_args()
async def read_file(path_or_url: str) -> str:
    """Return the text contents of a local file or an HTTP(S) resource.

    Anything not starting with ``http://``/``https://`` is treated as a
    local filesystem path.
    """
    if not path_or_url.startswith(("http://", "https://")):
        # Local path: a plain blocking read is fine for a one-shot script.
        with open(path_or_url, "r", encoding="utf-8") as f:
            return f.read()
    async with aiohttp.ClientSession() as session:
        async with session.get(path_or_url) as resp:
            return await resp.text()
async def write_file(path_or_url: str, data: str) -> None:
    """Write *data* to a local file, or HTTP PUT it to a web url.

    Anything not starting with ``http://``/``https://`` is treated as a
    local filesystem path.
    """
    if not path_or_url.startswith(("http://", "https://")):
        # We should make this async, but as long as this is always run as a
        # standalone program, blocking the event loop won't effect performance
        # in this particular case.
        with open(path_or_url, "w", encoding="utf-8") as f:
            f.write(data)
        return
    async with aiohttp.ClientSession() as session:
        async with session.put(path_or_url, data=data.encode("utf-8")):
            pass
async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput) -> BatchRequestOutput:
    """Execute one batch request and wrap the outcome in a BatchRequestOutput.

    *serving_engine_func* is an async callable (e.g. a serving object's
    create_* method) invoked with the request body; successful and error
    responses are both mapped onto the batch output schema.
    """
    response = await serving_engine_func(request.body)

    if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
        # Successful completion/embedding: attach the body directly.
        return BatchRequestOutput(
            id=f"aphrodite-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                body=response,
                request_id=f"aphrodite-batch-{random_uuid()}"),
            error=None,
        )

    if isinstance(response, ErrorResponse):
        # Engine-level failure: propagate its status code and the error body.
        return BatchRequestOutput(
            id=f"aphrodite-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                status_code=response.code,
                request_id=f"aphrodite-batch-{random_uuid()}"),
            error=response,
        )

    # Any other response type means the caller requested streaming, which
    # the batch runner cannot represent.
    raise ValueError("Request must not be sent in stream mode")
async def main(args):
    """Run every request in the input file through the engine concurrently
    and write the JSONL results to the output file.

    Args:
        args: Parsed CLI namespace from ``parse_args`` (input/output file,
            response role, engine arguments, logging flags).

    Raises:
        ValueError: If a request targets an unsupported endpoint URL.
    """
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)

    # When using single Aphrodite without engine_use_ray
    model_config = await engine.get_model_config()

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    # Create the OpenAI serving objects.
    openai_serving_chat = OpenAIServingChat(
        engine,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=None,
        prompt_adapters=None,
        request_logger=request_logger,
        chat_template=None,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )

    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
        # Skip empty lines.
        request_json = request_json.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)

        # Determine the type of request and run it.
        if request.url == "/v1/chat/completions":
            response_futures.append(
                run_request(openai_serving_chat.create_chat_completion,
                            request))
        elif request.url == "/v1/embeddings":
            response_futures.append(
                run_request(openai_serving_embedding.create_embedding,
                            request))
        else:
            # BUGFIX: the two fragments previously concatenated to
            # "aresupported" (missing space between them).
            raise ValueError("Only /v1/chat/completions and /v1/embeddings "
                             "are supported in the batch endpoint.")

    responses = await asyncio.gather(*response_futures)

    # Serialize each result as one JSON line, then write the whole batch
    # output in a single shot (local file or HTTP PUT).
    output_buffer = StringIO()
    for response in responses:
        print(response.model_dump_json(), file=output_buffer)

    output_buffer.seek(0)
    await write_file(args.output_file, output_buffer.read().strip())
- if __name__ == "__main__":
- args = parse_args()
- logger.info(f"Aphrodite API server version {APHRODITE_VERSION}")
- logger.info(f"args: {args}")
- asyncio.run(main(args))
|