api_server_async_aphrodite.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. """aphrodite.endpoints.api_server with some extra logging for testing."""
  2. from typing import Any, Dict, Iterable
  3. import uvicorn
  4. from fastapi.responses import JSONResponse, Response
  5. import aphrodite.endpoints.api_server
  6. from aphrodite.common.utils import FlexibleArgumentParser
  7. from aphrodite.engine.args_tools import AsyncEngineArgs
  8. from aphrodite.engine.async_aphrodite import AsyncAphrodite
# Reuse the FastAPI application object defined by the stock api_server module,
# so routes registered in this file are added to that same app.
app = aphrodite.endpoints.api_server.app
  10. class AsyncAphroditeWithStats(AsyncAphrodite):
  11. def __init__(self, *args, **kwargs):
  12. super().__init__(*args, **kwargs)
  13. self._num_aborts = 0
  14. async def _engine_abort(self, request_ids: Iterable[str]):
  15. ids = list(request_ids)
  16. self._num_aborts += len(ids)
  17. await super()._engine_abort(ids)
  18. def testing_stats(self) -> Dict[str, Any]:
  19. return {"num_aborted_requests": self._num_aborts}
  20. @app.get("/stats")
  21. def stats() -> Response:
  22. """Get the statistics of the engine."""
  23. return JSONResponse(engine.testing_stats())
  24. if __name__ == "__main__":
  25. parser = FlexibleArgumentParser()
  26. parser.add_argument("--host", type=str, default="localhost")
  27. parser.add_argument("--port", type=int, default=8000)
  28. parser = AsyncEngineArgs.add_cli_args(parser)
  29. args = parser.parse_args()
  30. engine_args = AsyncEngineArgs.from_cli_args(args)
  31. engine = AsyncAphroditeWithStats.from_engine_args(engine_args)
  32. aphrodite.endpoints.api_server.engine = engine
  33. uvicorn.run(
  34. app,
  35. host=args.host,
  36. port=args.port,
  37. log_level="debug",
  38. timeout_keep_alive=aphrodite.endpoints.api_server.TIMEOUT_KEEP_ALIVE)