# api_server_async_aphrodite.py (1.6 KB)
  1. """API server with some extra logging for testing."""
  2. import argparse
  3. from typing import Any, Dict
  4. import uvicorn
  5. from fastapi.responses import JSONResponse, Response
  6. import aphrodite.endpoints.api_server_ooba
  7. from aphrodite.engine.args_tools import AsyncEngineArgs
  8. from aphrodite.engine.async_aphrodite import AsyncAphrodite
# Reuse the FastAPI application object defined by the ooba API server module,
# so the route registered in this file is added to that same app.
app = aphrodite.endpoints.api_server_ooba.app
  10. class AsyncAphroditeWithStats(AsyncAphrodite):
  11. # pylint: disable=redefined-outer-name
  12. def __init__(self, *args, **kwargs):
  13. super().__init__(*args, **kwargs)
  14. self._num_aborts = 0
  15. async def abort(self, request_id: str) -> None:
  16. await super().abort(request_id)
  17. self._num_aborts += 1
  18. def testing_stats(self) -> Dict[str, Any]:
  19. return {"num_aborted_requests": self._num_aborts}
  20. @app.get("/stats")
  21. def stats() -> Response:
  22. """Get the statistics of the engine."""
  23. return JSONResponse(engine.testing_stats())
  24. if __name__ == "__main__":
  25. parser = argparse.ArgumentParser()
  26. parser.add_argument("--host", type=str, default="localhost")
  27. parser.add_argument("--port", type=int, default=8000)
  28. parser = AsyncEngineArgs.add_cli_args(parser)
  29. args = parser.parse_args()
  30. engine_args = AsyncEngineArgs.from_cli_args(args)
  31. engine = AsyncAphroditeWithStats.from_engine_args(engine_args)
  32. aphrodite.endpoints.api_server_ooba.engine = engine
  33. uvicorn.run(app,
  34. host=args.host,
  35. port=args.port,
  36. log_level="debug",
  37. timeout_keep_alive=aphrodite.endpoints.api_server_ooba.
  38. TIMEOUT_KEEP_ALIVE)