api_server.py

import argparse
import json
from typing import AsyncGenerator

from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn

from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.utils import random_uuid
from aphrodite.common.logits_processor import BanEOSUntil
from aphrodite.common.logger import init_logger

TIMEOUT_KEEP_ALIVE = 5  # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.

logger = init_logger(__name__)
app = FastAPI()
engine = None

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=2242)
parser.add_argument("--served-model-name", type=str, default=None)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)

if args.served_model_name is not None:
    served_model = args.served_model_name
else:
    served_model = engine_args.model


@app.post("/api/v1/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (see `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)

    # Translate legacy oobabooga-style field names to SamplingParams names.
    if "stopping_strings" in request_dict:
        request_dict["stop"] = request_dict.pop("stopping_strings")
    if "max_new_tokens" in request_dict:
        request_dict["max_tokens"] = request_dict.pop("max_new_tokens")
    if "min_length" in request_dict:
        request_dict["min_tokens"] = request_dict.pop("min_length")
    if "ban_eos_token" in request_dict:
        request_dict["ignore_eos"] = request_dict.pop("ban_eos_token")
    if "top_k" in request_dict and request_dict["top_k"] == 0:
        request_dict["top_k"] = -1

    request_dict["logits_processors"] = []
    min_length = request_dict.pop("min_tokens", 0)
    if request_dict.get("ignore_eos", False):
        # ignore_eos/ban_eos_token is functionally equivalent to
        # `min_tokens = max_tokens`.
        min_length = request_dict.get("max_tokens", 16)
    if min_length:
        request_dict["logits_processors"].append(
            BanEOSUntil(min_length, engine.engine.tokenizer.eos_token_id))

    sampling_params = SamplingParams()
    for key, value in request_dict.items():
        if hasattr(sampling_params, key):
            setattr(sampling_params, key, value)
    try:
        sampling_params.verify()
    except Exception as err:
        raise HTTPException(status_code=422, detail=str(err)) from err

    request_id = random_uuid()
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            text_outputs = [{
                "text": output.text
            } for output in request_output.outputs]
            ret = {"results": text_outputs}
            yield (json.dumps(ret) + "\n\n").encode("utf-8")

    async def abort_request() -> None:
        await engine.abort(request_id)

    if stream:
        background_tasks = BackgroundTasks()
        # Abort the engine request once the stream finishes or the client
        # disconnects mid-stream.
        background_tasks.add_task(abort_request)
        return StreamingResponse(stream_results(), background=background_tasks)

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    text_outputs = [{"text": output.text} for output in final_output.outputs]
    response_data = {"results": text_outputs}
    return JSONResponse(response_data)


@app.get("/api/v1/model")
async def get_model_name() -> JSONResponse:
    """Return the model name based on the EngineArgs configuration."""
    if engine is not None:
        result = {"result": f"aphrodite/{served_model}"}
        return JSONResponse(content=result)
    else:
        return JSONResponse(content={"result": "Read Only"}, status_code=500)


@app.get("/health")
async def health() -> Response:
    """Health check route for K8s."""
    return Response(status_code=200)


if __name__ == "__main__":
    engine = AsyncAphrodite.from_engine_args(engine_args)
    logger.warning("Deprecation warning: The legacy oobabooga API"
                   " is deprecated and will be removed in a future release.")
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="debug",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
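
# Minimal launch and request sketch (model name and sampling values are
# illustrative, not defaults):
#
#   python api_server.py --model mistralai/Mistral-7B-v0.1 --port 2242
#
#   curl http://localhost:2242/api/v1/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Once upon a time", "max_new_tokens": 32, "temperature": 0.7}'
#
# The response body has the shape {"results": [{"text": "..."}]}.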