import asyncio
import importlib
import inspect
import json
import os
import re
from argparse import Namespace
from contextlib import asynccontextmanager
from http import HTTPStatus
from multiprocessing import Process
from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple

import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from prometheus_client import make_asgi_app
from starlette.routing import Mount

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser, get_open_port,
                                    random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION

TIMEOUT_KEEP_ALIVE = 5  # seconds

APHRODITE_RPC_PORT = int(os.getenv("APHRODITE_RPC_PORT", '5570'))

async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
_running_tasks: Set[asyncio.Task] = set()


def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
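    """Return True if ``model_name`` refers to an embedding model, judged
    by building a minimal ModelConfig and reading its ``embedding_mode``."""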
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=trust_remote_code,
                       seed=0,
                       dtype="auto").embedding_mode


@asynccontextmanager
async def lifespan(app: FastAPI):
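    """App lifespan: start a background task that logs engine stats every
    10 seconds, unless stats logging is disabled."""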

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    yield


@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncAphrodite in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        # Start RPCServer in separate process (holds the AsyncAphrodite).
        port = get_open_port(APHRODITE_RPC_PORT)
        rpc_server_process = Process(target=run_rpc_server,
                                     args=(engine_args, port))
        rpc_server_process.start()

        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        async_engine_client = AsyncEngineRPCClient(port)
        await async_engine_client.setup()

        try:
            yield async_engine_client
        finally:
            # Ensure rpc server process was terminated
            rpc_server_process.terminate()

            # Close all open connections to the backend
            async_engine_client.close()

            # Wait for server process to join
            rpc_server_process.join()


def mount_metrics(app: FastAPI):
    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount("/metrics", make_asgi_app())
    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)


@router.get("/health")
async def health() -> Response:
    """Health check."""
    await async_engine_client.check_health()
    return Response(status_code=200)


@router.post("/v1/tokenize")
async def tokenize(request: TokenizeRequest):
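    """Tokenize the given input and return the tokens."""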
    generator = await openai_serving_tokenization.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, TokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/detokenize")
async def detokenize(request: DetokenizeRequest):
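    """Convert token IDs back into text."""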
    generator = await openai_serving_tokenization.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, DetokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.get("/v1/models")
async def show_available_models():
    models = await openai_serving_completion.show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": APHRODITE_VERSION}
    return JSONResponse(content=ver)


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
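    """OpenAI-compatible chat completion endpoint. Streams SSE chunks when
    ``request.stream`` is set; otherwise returns a single JSON response."""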
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    generator = await openai_serving_embedding.create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())


# ============ KoboldAI API ============ #


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
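    """Populate the global ``badwordsids`` list of banned token IDs: taken
    from the model config when available, otherwise all vocab entries whose
    string contains a square bracket, minus the pad token, plus EOS."""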
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""

    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    # if kai_payload.max_context_length > engine_args.max_model_len:
    #     raise ValueError(
    #         f"max_context_length ({kai_payload.max_context_length}) "
    #         "must be less than or equal to "
    #         f"max_model_len ({engine_args.max_model_len})")

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens


@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
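    """KoboldAI generate endpoint: run the request to completion, caching
    partial output in ``gen_cache`` so /generate/check can poll progress."""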
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    final_res: Optional[RequestOutput] = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
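    """KoboldCpp-style SSE streaming endpoint: emits each newly generated
    text chunk as an ``event: message`` server-sent event."""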
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")


@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
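    """Poll the partial output cached in ``gen_cache`` for a given genkey."""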
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
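    """Abort an in-flight generation identified by its genkey."""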
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await async_engine_client.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: TokenizeRequest):
    """Tokenize string and return token count"""
    generator = await openai_serving_tokenization.create_tokenize(request)
    return JSONResponse({"value": generator.model_dump()["tokens"]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})


@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed."""
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)


# ============ KoboldAI API ============ #


def build_app(args: Namespace) -> FastAPI:
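    """Construct the FastAPI app: routers, /metrics, CORS, optional API-key
    authentication, and any user-supplied middleware from the CLI."""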
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    if args.launch_kobold_api:
        logger.warning("Launching Kobold API server in addition to OpenAI. "
                       "Keep in mind that the Kobold API routes are NOT "
                       "protected via the API key.")
        app.include_router(kai_api, prefix="/api/v1")
        app.include_router(kai_api,
                           prefix="/api/latest",
                           include_in_schema=False)
        app.include_router(extra_api, prefix="/api/extra")

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_completion.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            excluded_paths = ["/api"]
            if any(
                    request.url.path.startswith(path)
                    for path in excluded_paths):
                return await call_next(request)
            if not request.url.path.startswith("/v1"):
                return await call_next(request)
            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)
            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")
            if auth_header != "Bearer " + token and api_key_header != token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
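    """Wire up the app: resolve served model names, build the OpenAI serving
    layers, and load the tokenizer (plus badwords when Kobold is enabled)."""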
    app = build_app(args)

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    model_config = await async_engine_client.get_model_config()

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path

    tokenizer = get_tokenizer(
        tokenizer_name=engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
        revision=engine_args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)

    return app


async def run_server(args, **uvicorn_kwargs) -> None:
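    """Entry point: build the engine client, initialize the app, and serve
    HTTP until shutdown."""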
    async with build_async_engine_client(args) as async_engine_client:
        app = await init_app(async_engine_client, args)

        shutdown_task = await serve_http(
            app,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task


if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))