api_server.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import argparse
  2. import json
  3. import os
  4. from typing import AsyncGenerator, Dict
  5. from fastapi import BackgroundTasks, Depends, Header, FastAPI, HTTPException, Request
  6. from fastapi.responses import JSONResponse, Response, StreamingResponse
  7. import uvicorn
  8. from aphrodite.engine.args_tools import AsyncEngineArgs, EngineArgs
  9. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  10. from aphrodite.common.sampling_params import SamplingParams
  11. from aphrodite.common.utils import random_uuid
  12. TIMEOUT_KEEP_ALIVE = 5 # seconds.
  13. TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
  14. user_tokens: Dict[str, str] = {}
  15. SECRET_TOKEN = "EMPTY"
  16. app = FastAPI()
  17. engine = None
  18. def get_token(authorization: str = Header(None)):
  19. if authorization is None or not authorization.startswith("Bearer "):
  20. raise HTTPException(status_code=401, detail="Unauthorized access.")
  21. token = authorization.replace("Bearer ", "")
  22. if token != SECRET_TOKEN:
  23. raise HTTPException(status_code=401, detail="Unauthorized access.")
  24. return True
  25. @app.post("/api/v1/generate")
  26. async def generate(request: Request, token: bool = Depends(get_token)) -> Response:
  27. """Generate completion for the request.
  28. The request should be a JSON object with the following fields:
  29. - prompt: the prompt to use for the generation.
  30. - stream: whether to stream the results or not.
  31. - other fields: the sampling parameters (See `SamplingParams` for details).
  32. """
  33. request_dict = await request.json()
  34. prompt = request_dict.pop("prompt")
  35. stream = request_dict.pop("stream", False)
  36. sampling_params = SamplingParams(**request_dict)
  37. request_id = random_uuid()
  38. results_generator = engine.generate(prompt, sampling_params, request_id)
  39. # Streaming case
  40. async def stream_results() -> AsyncGenerator[bytes, None]:
  41. async for request_output in results_generator:
  42. prompt = request_output.prompt
  43. text_outputs = [
  44. prompt + output.text for output in request_output.outputs
  45. ]
  46. ret = {"text": text_outputs}
  47. yield (json.dumps(ret) + "\0").encode("utf-8")
  48. async def abort_request() -> None:
  49. await engine.abort(request_id)
  50. if stream:
  51. background_tasks = BackgroundTasks()
  52. # Abort the request if the client disconnects.
  53. background_tasks.add_task(abort_request)
  54. return StreamingResponse(stream_results(), background=background_tasks)
  55. # Non-streaming case
  56. final_output = None
  57. async for request_output in results_generator:
  58. if await request.is_disconnected():
  59. # Abort the request if the client disconnects.
  60. await engine.abort(request_id)
  61. return Response(status_code=499)
  62. final_output = request_output
  63. assert final_output is not None
  64. prompt = final_output.prompt
  65. text_outputs = [prompt + output.text for output in final_output.outputs]
  66. ret = {"text": text_outputs}
  67. return JSONResponse(ret)
  68. @app.get("/api/v1/model")
  69. async def get_model_name(token: bool = Depends(get_token)) -> JSONResponse:
  70. """Return the model name based on the EngineArgs configuration."""
  71. if engine is not None:
  72. model_name = engine_args.model
  73. result = {"result": model_name}
  74. return JSONResponse(content=result)
  75. else:
  76. return JSONResponse(content={"result": "Read Only"}, status_code=500)
  77. if __name__ == "__main__":
  78. parser = argparse.ArgumentParser()
  79. parser.add_argument("--host", type=str, default="localhost")
  80. parser.add_argument("--port", type=int, default=8000)
  81. parser = AsyncEngineArgs.add_cli_args(parser)
  82. args = parser.parse_args()
  83. engine_args = AsyncEngineArgs.from_cli_args(args)
  84. engine = AsyncAphrodite.from_engine_args(engine_args)
  85. uvicorn.run(app,
  86. host=args.host,
  87. port=args.port,
  88. log_level="debug",
  89. timeout_keep_alive=TIMEOUT_KEEP_ALIVE)