# api_server_async_aphrodite.py (1.6 KB)

  1. """aphrodite.endpoints.ooba.api_server with some extra logging for testing."""
  2. import argparse
  3. from typing import Any, Dict
  4. import uvicorn
  5. from fastapi.responses import JSONResponse, Response
  6. import aphrodite.endpoints.ooba.api_server
  7. from aphrodite.engine.args_tools import AsyncEngineArgs
  8. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  9. app = aphrodite.endpoints.ooba.api_server.app
  10. class AsyncAphroditeWithStats(AsyncAphrodite):
  11. def __init__(self, *args, **kwargs):
  12. super().__init__(*args, **kwargs)
  13. self._num_aborts = 0
  14. async def abort(self, request_id: str) -> None:
  15. await super().abort(request_id)
  16. self._num_aborts += 1
  17. def testing_stats(self) -> Dict[str, Any]:
  18. return {"num_aborted_requests": self._num_aborts}
  19. @app.get("/stats")
  20. def stats() -> Response:
  21. """Get the statistics of the engine."""
  22. return JSONResponse(engine.testing_stats())
  23. if __name__ == "__main__":
  24. parser = argparse.ArgumentParser()
  25. parser.add_argument("--host", type=str, default="localhost")
  26. parser.add_argument("--port", type=int, default=2242)
  27. parser = AsyncEngineArgs.add_cli_args(parser)
  28. args = parser.parse_args()
  29. engine_args = AsyncEngineArgs.from_cli_args(args)
  30. engine = AsyncAphroditeWithStats.from_engine_args(engine_args)
  31. aphrodite.endpoints.ooba.api_server.engine = engine
  32. uvicorn.run(app,
  33. host=args.host,
  34. port=args.port,
  35. log_level="debug",
  36. timeout_keep_alive=aphrodite.endpoints.ooba.api_server.
  37. TIMEOUT_KEEP_ALIVE)