api_server_alt.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import argparse
  2. import json
  3. import os
  4. from typing import AsyncGenerator, Dict
  5. from fastapi import BackgroundTasks, Depends, Header, FastAPI, HTTPException, Request
  6. from fastapi.responses import JSONResponse, Response, StreamingResponse
  7. import uvicorn
  8. from pydantic import parse_obj_as
  9. from aphrodite.engine.args_tools import AsyncEngineArgs
  10. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  11. from aphrodite.common.sampling_params import SamplingParams
  12. from aphrodite.common.utils import random_uuid
  13. TIMEOUT_KEEP_ALIVE = 5 # seconds.
  14. TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
  15. app = FastAPI()
  16. engine = None
  17. # user_tokens: Dict[str, str] = {}
  18. # def get_token(authorization: str = Header(None)):
  19. # if authorization is None or not authorization.startswith("Bearer "):
  20. # raise HTTPException(status_code=401, detail="Unauthorized access.")
  21. # token = authorization.replace("Bearer ", "")
  22. # # Check if the token exists in the user_tokens dictionary
  23. # if token not in user_tokens:
  24. # raise HTTPException(status_code=401, detail="Unauthorized access.")
  25. # return True
  26. # def generate_user_token(user_id: str) -> str:
  27. # token = random_uuid()
  28. # user_tokens[token] = user_id
  29. # return token
  30. @app.post("/api/v1/generate")
  31. # async def generate(request: Request, token: bool = Depends(get_token), params: SamplingParams) -> Response:
  32. async def generate(request: Request) -> Response:
  33. """Generate completion for the request.
  34. The request should be a JSON object with the following fields:
  35. - prompt: the prompt to use for the generation.
  36. - stream: whether to stream the results or not.
  37. - other fields: the sampling parameters (See `SamplingParams` for details).
  38. """
  39. request_dict = await request.json()
  40. prompt = request_dict.pop("prompt")
  41. stream = request_dict.pop("stream", False)
  42. sampling_params = SamplingParams()
  43. if 'stop_sequence' in request_dict:
  44. request_dict['stop'] = request_dict.pop('stop_sequence')
  45. if 'max_length' in request_dict:
  46. request_dict['max_tokens'] = request_dict.pop('max_length')
  47. if 'rep_pen' in request_dict:
  48. request_dict['frequency_penalty'] = request_dict.pop('rep_pen')
  49. for key, value in request_dict.items():
  50. if hasattr(sampling_params, key):
  51. setattr(sampling_params, key, value)
  52. # sampling_params = SamplingParams(**sampling_params_data)
  53. # param_aliases = {
  54. # 'stop_sequence': 'stop',
  55. # 'max_length': 'max_tokens',
  56. # 'rep_pen': 'frequency_penalty',
  57. # 'use_story': None,
  58. # 'use_memory': None,
  59. # 'use_authors_note': None,
  60. # 'use_world_info': None,
  61. # 'max_context_length': None,
  62. # 'rep_pen_range': None,
  63. # 'rep_pen_slope': None,
  64. # 'tfs': None,
  65. # 'top_a': None,
  66. # 'typical': None,
  67. # 'sampler_order': None,
  68. # 'singleline': None,
  69. # 'use_default_badwordsids': None,
  70. # 'mirostat': None,
  71. # 'mirostat_eta': None,
  72. # 'mirostat_tau': None,
  73. # }
  74. # sampling_params = SamplingParams(**request_dict)
  75. request_id = random_uuid()
  76. results_generator = engine.generate(prompt, sampling_params, request_id)
  77. # Streaming case
  78. async def stream_results() -> AsyncGenerator[bytes, None]:
  79. async for request_output in results_generator:
  80. prompt = request_output.prompt
  81. text_outputs = [
  82. prompt + output.text for output in request_output.outputs
  83. ]
  84. ret = {"text": text_outputs}
  85. yield (json.dumps(ret) + "\0").encode("utf-8")
  86. async def abort_request() -> None:
  87. await engine.abort(request_id)
  88. if stream:
  89. background_tasks = BackgroundTasks()
  90. # Abort the request if the client disconnects.
  91. background_tasks.add_task(abort_request)
  92. return StreamingResponse(stream_results(), background=background_tasks)
  93. # Non-streaming case
  94. final_output = None
  95. async for request_output in results_generator:
  96. if await request.is_disconnected():
  97. # Abort the request if the client disconnects.
  98. await engine.abort(request_id)
  99. return Response(status_code=499)
  100. final_output = request_output
  101. assert final_output is not None
  102. prompt = final_output.prompt
  103. text_outputs = [prompt + output.text for output in final_output.outputs]
  104. ret = {"text": text_outputs}
  105. return JSONResponse(ret)
  106. @app.get("/api/v1/model")
  107. # async def get_model_name(token: bool = Depends(get_token)) -> JSONResponse:
  108. async def get_model_name() -> JSONResponse:
  109. """Return the model name based on the EngineArgs configuration."""
  110. if engine is not None:
  111. model_name = engine_args.model
  112. result = {"result": model_name}
  113. return JSONResponse(content=result)
  114. else:
  115. return JSONResponse(content={"result": "Read Only"}, status_code=500)
  116. # @app.post("/api/v1/get-token")
  117. # async def get_user_token(user_id: str) -> JSONResponse:
  118. # token = generate_user_token(user_id)
  119. # return JSONResponse(content={"token": token})
  120. if __name__ == "__main__":
  121. parser = argparse.ArgumentParser()
  122. parser.add_argument("--host", type=str, default="localhost")
  123. parser.add_argument("--port", type=int, default=8000)
  124. parser = AsyncEngineArgs.add_cli_args(parser)
  125. args = parser.parse_args()
  126. engine_args = AsyncEngineArgs.from_cli_args(args)
  127. engine = AsyncAphrodite.from_engine_args(engine_args)
  128. uvicorn.run(app,
  129. host=args.host,
  130. port=args.port,
  131. log_level="debug",
  132. timeout_keep_alive=TIMEOUT_KEEP_ALIVE)