"""
NOTE: This API server is used only for demonstrating usage of AsyncAphrodite
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `aphrodite/endpoints/openai/api_server.py` instead.
"""
import asyncio
import json
import ssl
from argparse import Namespace
from typing import Any, AsyncGenerator, Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    iterate_with_cancellation, random_uuid)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.server.launch import serve_http

TIMEOUT_KEEP_ALIVE = 5  # seconds.

app = FastAPI()
# Module-level engine handle; populated by init_app() before serving.
engine: Optional[AsyncAphrodite] = None
  24. @app.get("/health")
  25. async def health() -> Response:
  26. """Health check."""
  27. return Response(status_code=200)
  28. @app.post("/generate")
  29. async def generate(request: Request) -> Response:
  30. """Generate completion for the request.
  31. The request should be a JSON object with the following fields:
  32. - prompt: the prompt to use for the generation.
  33. - stream: whether to stream the results or not.
  34. - other fields: the sampling parameters (See `SamplingParams` for details).
  35. """
  36. request_dict = await request.json()
  37. prompt = request_dict.pop("prompt")
  38. stream = request_dict.pop("stream", False)
  39. sampling_params = SamplingParams(**request_dict)
  40. request_id = random_uuid()
  41. assert engine is not None
  42. results_generator = engine.generate(prompt, sampling_params, request_id)
  43. results_generator = iterate_with_cancellation(
  44. results_generator, is_cancelled=request.is_disconnected)
  45. # Streaming case
  46. async def stream_results() -> AsyncGenerator[bytes, None]:
  47. async for request_output in results_generator:
  48. prompt = request_output.prompt
  49. text_outputs = [
  50. prompt + output.text for output in request_output.outputs
  51. ]
  52. ret = {"text": text_outputs}
  53. yield (json.dumps(ret) + "\0").encode("utf-8")
  54. if stream:
  55. return StreamingResponse(stream_results())
  56. # Non-streaming case
  57. final_output = None
  58. try:
  59. async for request_output in results_generator:
  60. final_output = request_output
  61. except asyncio.CancelledError:
  62. return Response(status_code=499)
  63. assert final_output is not None
  64. prompt = final_output.prompt
  65. text_outputs = [prompt + output.text for output in final_output.outputs]
  66. ret = {"text": text_outputs}
  67. return JSONResponse(ret)
  68. def build_app(args: Namespace) -> FastAPI:
  69. global app
  70. app.root_path = args.root_path
  71. return app
  72. async def init_app(
  73. args: Namespace,
  74. llm_engine: Optional[AsyncAphrodite] = None,
  75. ) -> FastAPI:
  76. app = build_app(args)
  77. global engine
  78. engine_args = AsyncEngineArgs.from_cli_args(args)
  79. engine = (llm_engine
  80. if llm_engine is not None else AsyncAphrodite.from_engine_args(
  81. engine_args))
  82. return app
  83. async def run_server(args: Namespace,
  84. llm_engine: Optional[AsyncAphrodite] = None,
  85. **uvicorn_kwargs: Any) -> None:
  86. app = await init_app(args, llm_engine)
  87. shutdown_task = await serve_http(
  88. app,
  89. engine=engine,
  90. host=args.host,
  91. port=args.port,
  92. log_level=args.log_level,
  93. timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
  94. ssl_keyfile=args.ssl_keyfile,
  95. ssl_certfile=args.ssl_certfile,
  96. ssl_ca_certs=args.ssl_ca_certs,
  97. ssl_cert_reqs=args.ssl_cert_reqs,
  98. **uvicorn_kwargs,
  99. )
  100. await shutdown_task
  101. if __name__ == "__main__":
  102. parser = FlexibleArgumentParser()
  103. parser.add_argument("--host", type=str, default=None)
  104. parser.add_argument("--port", type=int, default=2242)
  105. parser.add_argument("--ssl-keyfile", type=str, default=None)
  106. parser.add_argument("--ssl-certfile", type=str, default=None)
  107. parser.add_argument("--ssl-ca-certs",
  108. type=str,
  109. default=None,
  110. help="The CA certificates file")
  111. parser.add_argument(
  112. "--ssl-cert-reqs",
  113. type=int,
  114. default=int(ssl.CERT_NONE),
  115. help="Whether client certificate is required (see stdlib ssl module's)"
  116. )
  117. parser.add_argument(
  118. "--root-path",
  119. type=str,
  120. default=None,
  121. help="FastAPI root_path when app is behind a path based routing proxy")
  122. parser.add_argument("--log-level", type=str, default="debug")
  123. parser = AsyncEngineArgs.add_cli_args(parser)
  124. args = parser.parse_args()
  125. asyncio.run(run_server(args))