api_server_ooba.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import argparse
  2. import json
  3. import os
  4. from typing import AsyncGenerator, Dict
  5. from fastapi import BackgroundTasks, Depends, Header, FastAPI, HTTPException, Request
  6. from fastapi.responses import JSONResponse, Response, StreamingResponse
  7. import uvicorn
  8. from pydantic import parse_obj_as
  9. from aphrodite.engine.args_tools import AsyncEngineArgs
  10. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  11. from aphrodite.common.sampling_params import SamplingParams
  12. from aphrodite.common.utils import random_uuid
  # Seconds an idle HTTP keep-alive connection is held open (handed to uvicorn.run).
  TIMEOUT_KEEP_ALIVE = 5
  # NOTE(review): not referenced anywhere else in this module — confirm before removing.
  TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.

  # Shared secret that callers must echo back in the X-API-Key header; read once at import.
  valid_api_key = os.getenv("X_API_KEY")

  app = FastAPI()
  # Assigned by the __main__ block; stays None when this module is only imported.
  engine = None
  @app.post("/api/v1/generate")
  async def generate(request: Request, x_api_key: str = Header(None)) -> Response:
      """Generate a completion for a text-generation-webui ("ooba") style request.

      The request body must be a JSON object with the following fields:

      - prompt: the prompt to use for the generation (required).
      - stream: whether to stream the results or not (default False).
      - other fields: sampling parameters (see `SamplingParams` for details).
        The ooba names ``stopping_strings``, ``max_new_tokens``,
        ``repetition_penalty`` and ``ban_eos_token`` are renamed to ``stop``,
        ``max_tokens``, ``frequency_penalty`` and ``ignore_eos``, and
        ``top_k == 0`` is rewritten to ``-1``. Keys that are not public,
        non-callable attributes of ``SamplingParams`` are ignored.

      Returns:
          If ``stream`` is truthy, a ``StreamingResponse`` whose chunks are
          NUL-terminated JSON objects ``{"text": [prompt + text_so_far, ...]}``.
          Otherwise a JSON response ``{"results": [{"text": text}, ...]}``, or
          an empty 499 response if the client disconnected mid-generation.

      Raises:
          HTTPException: 401 when the X-API-Key header is missing or does not
              match the X_API_KEY environment variable (an unset variable
              therefore rejects every request); 400 when the body is not a JSON
              object or lacks ``prompt``.
      """
      import hmac  # function-scope import: constant-time comparison for the API key

      # compare_digest avoids leaking how much of the key matched via response timing.
      api_key_ok = (
          x_api_key is not None
          and valid_api_key is not None
          and hmac.compare_digest(x_api_key.encode("utf-8"), valid_api_key.encode("utf-8"))
      )
      if not api_key_ok:
          raise HTTPException(status_code=401, detail="Unauthorized. Please acquire an API key.")

      # Malformed client input is a 400, not an unhandled 500.
      try:
          request_dict = await request.json()
      except ValueError as err:  # json.JSONDecodeError / UnicodeDecodeError subclass ValueError
          raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from err
      if not isinstance(request_dict, dict):
          raise HTTPException(status_code=400, detail="Request body must be a JSON object.")
      if "prompt" not in request_dict:
          raise HTTPException(status_code=400, detail="Missing required field: prompt.")

      prompt = request_dict.pop("prompt")
      stream = request_dict.pop("stream", False)

      # Translate text-generation-webui parameter names to SamplingParams names.
      ooba_to_sampling_names = {
          "stopping_strings": "stop",
          "max_new_tokens": "max_tokens",
          # NOTE(review): ooba's repetition_penalty is, as far as I know, a
          # multiplicative factor where 1.0 means "no penalty", whereas a
          # frequency_penalty is usually additive with 0.0 neutral — this 1:1
          # mapping may over-penalize; confirm against SamplingParams.
          "repetition_penalty": "frequency_penalty",
          "ban_eos_token": "ignore_eos",
      }
      for ooba_name, sampling_name in ooba_to_sampling_names.items():
          if ooba_name in request_dict:
              request_dict[sampling_name] = request_dict.pop(ooba_name)
      if request_dict.get("top_k") == 0:
          request_dict["top_k"] = -1
      # setattr below bypasses any normalization the SamplingParams constructor
      # performs, so wrap a bare stop string the way a list of stops is
      # presumably expected — TODO confirm against SamplingParams.
      if isinstance(request_dict.get("stop"), str):
          request_dict["stop"] = [request_dict["stop"]]

      sampling_params = SamplingParams()
      for key, value in request_dict.items():
          # Only overwrite public data attributes: untrusted input must not be
          # able to replace methods or dunder attributes such as "__class__".
          if key.startswith("_") or not hasattr(sampling_params, key):
              continue
          if callable(getattr(sampling_params, key)):
              continue
          setattr(sampling_params, key, value)

      request_id = random_uuid()
      results_generator = engine.generate(prompt, sampling_params, request_id)

      # Streaming case: every chunk echoes the prompt in front of the text generated so far.
      async def stream_results() -> AsyncGenerator[bytes, None]:
          async for request_output in results_generator:
              full_prompt = request_output.prompt
              text_outputs = [full_prompt + output.text for output in request_output.outputs]
              ret = {"text": text_outputs}
              yield (json.dumps(ret) + "\0").encode("utf-8")

      async def abort_request() -> None:
          await engine.abort(request_id)

      if stream:
          background_tasks = BackgroundTasks()
          # Starlette runs background tasks after the response has been sent.
          background_tasks.add_task(abort_request)
          return StreamingResponse(stream_results(), background=background_tasks)

      # Non-streaming case: drain the generator, keeping only the last (complete) output.
      final_output = None
      async for request_output in results_generator:
          if await request.is_disconnected():
              # Abort the request if the client disconnects.
              await engine.abort(request_id)
              return Response(status_code=499)
          final_output = request_output
      assert final_output is not None
      text_outputs = [{"text": output.text} for output in final_output.outputs]
      return JSONResponse({"results": text_outputs})
  @app.get("/api/v1/model")
  async def get_model_name() -> JSONResponse:
      """Report the model the running engine was launched with.

      Responds with ``{"result": <engine_args.model>}`` once the engine exists,
      otherwise ``{"result": "Read Only"}`` with HTTP status 500.
      """
      if engine is None:
          # No engine yet, e.g. the module was imported instead of run as a script.
          return JSONResponse(content={"result": "Read Only"}, status_code=500)
      return JSONResponse(content={"result": engine_args.model})
  if __name__ == "__main__":
      # Command line = this server's own --host/--port, followed by every
      # option AsyncEngineArgs.add_cli_args registers.
      cli_parser = argparse.ArgumentParser()
      cli_parser.add_argument("--host", type=str, default="localhost")
      cli_parser.add_argument("--port", type=int, default=8000)
      cli_parser = AsyncEngineArgs.add_cli_args(cli_parser)
      args = cli_parser.parse_args()

      # Module globals read by the request handlers.
      engine_args = AsyncEngineArgs.from_cli_args(args)
      engine = AsyncAphrodite.from_engine_args(engine_args)

      uvicorn.run(
          app,
          host=args.host,
          port=args.port,
          log_level="debug",
          timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
      )