api_server_ooba.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import argparse
  2. import json
  3. import os
  4. from typing import AsyncGenerator, Dict
  5. from fastapi import BackgroundTasks, Depends, Header, FastAPI, HTTPException, Request, APIRouter
  6. from fastapi.middleware.cors import CORSMiddleware
  7. from fastapi.responses import JSONResponse, Response, StreamingResponse
  8. import uvicorn
  9. from pydantic import parse_obj_as
  10. from aphrodite.engine.args_tools import AsyncEngineArgs
  11. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  12. from aphrodite.common.sampling_params import SamplingParams
  13. from aphrodite.common.utils import random_uuid
  # Timeouts, in seconds. TIMEOUT_KEEP_ALIVE is handed to uvicorn when the
  # server is started as a script.
  TIMEOUT_KEEP_ALIVE = 5
  TIMEOUT_TO_PREVENT_DEADLOCK = 1

  # Module-level state shared by the route handlers. `engine` stays None until
  # the script entry point builds it; `valid_api_key` is the value the
  # `x_api_key` header parameter of the generate endpoint is compared against.
  engine = None
  valid_api_key = "EMPTY"

  # The application accepts cross-origin requests from any origin, with any
  # method and any header, credentials included.
  app = FastAPI()
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
      allow_methods=["*"],
      allow_headers=["*"],
      allow_credentials=True,
  )
  # Request keys accepted as aliases, mapped to the attribute names that are
  # set on SamplingParams.
  _OOBA_PARAM_ALIASES = {
      "stopping_strings": "stop",
      "max_new_tokens": "max_tokens",
      "ban_eos_token": "ignore_eos",
  }


  def _sampling_params_from_request(request_dict: Dict) -> SamplingParams:
      """Build a `SamplingParams` from the non-prompt fields of a request.

      Alias keys are renamed in place (an alias overrides the native key when
      both are present) and a `top_k` of 0 is rewritten to -1. Keys that do not
      name a public, non-callable attribute of `SamplingParams` are ignored.
      """
      for alias, native in _OOBA_PARAM_ALIASES.items():
          if alias in request_dict:
              request_dict[native] = request_dict.pop(alias)
      if request_dict.get("top_k") == 0:
          request_dict["top_k"] = -1

      sampling_params = SamplingParams()
      for key, value in request_dict.items():
          # The body is untrusted: never let it overwrite private/dunder
          # attributes or replace a method on the params object.
          if key.startswith("_") or not hasattr(sampling_params, key):
              continue
          if callable(getattr(sampling_params, key)):
              continue
          setattr(sampling_params, key, value)
      return sampling_params


  @app.post("/api/v1/generate")
  async def generate(request: Request, x_api_key: str = Header(None)) -> Response:
      """Generate completion for the request.

      The request should be a JSON object with the following fields:
      - prompt: the prompt to use for the generation.
      - stream: whether to stream the results or not.
      - other fields: the sampling parameters (See `SamplingParams` for details).

      Returns a `StreamingResponse` of newline-delimited JSON chunks when
      `stream` is truthy, otherwise a `JSONResponse` of the form
      `{"results": [{"text": ...}, ...]}`. A bare 499 response is returned if
      the client disconnects during a non-streaming generation.

      Raises:
          HTTPException: 401 when the API key is missing or wrong, 400 when the
              body is not a JSON object containing `prompt`, 500 when the
              engine yields no output.
      """
      if x_api_key is None or x_api_key != valid_api_key:
          raise HTTPException(status_code=401, detail="Unauthorized. Please acquire an API key.")

      # Reject malformed bodies with a client error instead of letting the
      # exception escape as an unhandled server error.
      try:
          request_dict = await request.json()
      except (json.JSONDecodeError, UnicodeDecodeError) as err:
          raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from err
      if not isinstance(request_dict, dict) or "prompt" not in request_dict:
          raise HTTPException(
              status_code=400,
              detail="Request body must be a JSON object with a 'prompt' field.",
          )

      prompt = request_dict.pop("prompt")
      stream = request_dict.pop("stream", False)
      sampling_params = _sampling_params_from_request(request_dict)

      request_id = random_uuid()
      results_generator = engine.generate(prompt, sampling_params, request_id)

      # Streaming case
      async def stream_results() -> AsyncGenerator[bytes, None]:
          async for request_output in results_generator:
              text_outputs = [
                  {"text": output.text} for output in request_output.outputs
              ]
              yield (json.dumps({"results": text_outputs}) + "\n\n").encode("utf-8")

      async def abort_request() -> None:
          await engine.abort(request_id)

      if stream:
          # The abort is attached as the response's background task.
          background_tasks = BackgroundTasks()
          background_tasks.add_task(abort_request)
          return StreamingResponse(stream_results(), background=background_tasks)

      # Non-streaming case
      final_output = None
      async for request_output in results_generator:
          if await request.is_disconnected():
              # Abort the request if the client disconnects.
              await engine.abort(request_id)
              return Response(status_code=499)
          final_output = request_output

      # Explicit check rather than `assert`, which is stripped under `python -O`.
      if final_output is None:
          raise HTTPException(status_code=500, detail="The engine produced no output.")

      text_outputs = [{"text": output.text} for output in final_output.outputs]
      return JSONResponse({"results": text_outputs})
  @app.get("/api/v1/model")
  async def get_model_name() -> JSONResponse:
      """Report the model name taken from the engine arguments.

      Responds with `{"result": <model>}` once the engine has been created.
      While no engine exists, responds with `{"result": "Read Only"}` and
      status 500.
      """
      if engine is None:
          return JSONResponse(content={"result": "Read Only"}, status_code=500)

      # `engine_args` is the module global assigned by the script entry point.
      payload = {"result": engine_args.model}
      return JSONResponse(content=payload)
  if __name__ == "__main__":
      # Server options are declared first; AsyncEngineArgs then registers its
      # own flags on the same parser.
      parser = argparse.ArgumentParser()
      parser.add_argument("--host", type=str, default="localhost")
      parser.add_argument("--port", type=int, default=2242)
      parser = AsyncEngineArgs.add_cli_args(parser)
      args = parser.parse_args()

      # Both names are module globals read by the route handlers.
      engine_args = AsyncEngineArgs.from_cli_args(args)
      engine = AsyncAphrodite.from_engine_args(engine_args)

      uvicorn.run(
          app,
          host=args.host,
          port=args.port,
          log_level="debug",
          timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
      )