""" NOTE: This API server is used only for demonstrating usage of AsyncAphrodite and simple performance benchmarks. It is not intended for production use. For production use, we recommend using our OpenAI compatible server. We are also not going to accept PRs modifying this file, please change `aphrodite/endpoints/openai/api_server.py` instead. """ import asyncio import json import ssl from argparse import Namespace from typing import Any, AsyncGenerator, Optional from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse from aphrodite.common.sampling_params import SamplingParams from aphrodite.common.utils import (FlexibleArgumentParser, iterate_with_cancellation, random_uuid) from aphrodite.engine.args_tools import AsyncEngineArgs from aphrodite.engine.async_aphrodite import AsyncAphrodite from aphrodite.server.launch import serve_http TIMEOUT_KEEP_ALIVE = 5 # seconds. app = FastAPI() engine = None @app.get("/health") async def health() -> Response: """Health check.""" return Response(status_code=200) @app.post("/generate") async def generate(request: Request) -> Response: """Generate completion for the request. The request should be a JSON object with the following fields: - prompt: the prompt to use for the generation. - stream: whether to stream the results or not. - other fields: the sampling parameters (See `SamplingParams` for details). """ request_dict = await request.json() prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) request_id = random_uuid() assert engine is not None results_generator = engine.generate(prompt, sampling_params, request_id) results_generator = iterate_with_cancellation( results_generator, is_cancelled=request.is_disconnected) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output in results_generator: prompt = request_output.prompt text_outputs = [ prompt + output.text for output in request_output.outputs ] ret = {"text": text_outputs} yield (json.dumps(ret) + "\0").encode("utf-8") if stream: return StreamingResponse(stream_results()) # Non-streaming case final_output = None try: async for request_output in results_generator: final_output = request_output except asyncio.CancelledError: return Response(status_code=499) assert final_output is not None prompt = final_output.prompt text_outputs = [prompt + output.text for output in final_output.outputs] ret = {"text": text_outputs} return JSONResponse(ret) def build_app(args: Namespace) -> FastAPI: global app app.root_path = args.root_path return app async def init_app( args: Namespace, llm_engine: Optional[AsyncAphrodite] = None, ) -> FastAPI: app = build_app(args) global engine engine_args = AsyncEngineArgs.from_cli_args(args) engine = (llm_engine if llm_engine is not None else AsyncAphrodite.from_engine_args( engine_args)) return app async def run_server(args: Namespace, llm_engine: Optional[AsyncAphrodite] = None, **uvicorn_kwargs: Any) -> None: app = await init_app(args, llm_engine) shutdown_task = await serve_http( app, engine=engine, host=args.host, port=args.port, log_level=args.log_level, timeout_keep_alive=TIMEOUT_KEEP_ALIVE, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, **uvicorn_kwargs, ) await shutdown_task if __name__ == "__main__": parser = FlexibleArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=2242) parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-certfile", type=str, default=None) parser.add_argument("--ssl-ca-certs", type=str, default=None, help="The CA certificates file") parser.add_argument( "--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE), help="Whether client certificate is required (see stdlib ssl module's)" ) parser.add_argument( "--root-path", type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy") parser.add_argument("--log-level", type=str, default="debug") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() asyncio.run(run_server(args))