# api_server.py

import asyncio
import importlib
import inspect
import json
import multiprocessing
import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from distutils.util import strtobool
from http import HTTPStatus
from typing import AsyncGenerator, AsyncIterator, List, Set, Tuple

from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from starlette.routing import Mount

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    get_open_zmq_ipc_path, in_windows,
                                    random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       PromptAdapterPath)
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION

if in_windows():
    import winloop as uvloop
else:
    import uvloop

TIMEOUT_KEEP_ALIVE = 5  # seconds

SERVE_KOBOLD_LITE_UI = strtobool(os.getenv("SERVE_KOBOLD_LITE_UI", "1"))

async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
prometheus_multiproc_dir: tempfile.TemporaryDirectory

_running_tasks: Set[asyncio.Task] = set()


def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=trust_remote_code,
                       seed=0,
                       dtype="auto").embedding_mode
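
# NOTE: `model_is_embedding` constructs a throwaway ModelConfig purely to
# read its `embedding_mode` flag before any engine exists. Only the model's
# Hugging Face config is inspected here; no weights are loaded.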


@asynccontextmanager
async def lifespan(app: FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    yield
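
# NOTE: `_running_tasks` keeps a strong reference to the background stats
# task. `asyncio.create_task` only stores a weak reference to its result,
# so without this set the task could be garbage-collected while running;
# the done-callback drops the reference once the task completes.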


@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncAphrodite in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between Aphrodite runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and Aphrodite will properly handle cleanup.")

        # Select random path for IPC.
        rpc_path = get_open_zmq_ipc_path()
        logger.info(
            f"Multiprocessing frontend to use {rpc_path} for RPC Path.")

        # Start RPCServer in separate process (holds the AsyncAphrodite).
        context = multiprocessing.get_context("spawn")
        # the current process might have CUDA context,
        # so we need to spawn a new process
        rpc_server_process = context.Process(
            target=run_rpc_server,
            args=(engine_args, rpc_path))
        rpc_server_process.start()
        logger.info(
            f"Started engine process with PID {rpc_server_process.pid}")

        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        async_engine_client = AsyncEngineRPCClient(rpc_path)

        try:
            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError as e:
                    if not rpc_server_process.is_alive():
                        raise RuntimeError(
                            "The server process died before "
                            "responding to the readiness probe") from e

            yield async_engine_client
        finally:
            # Ensure rpc server process was terminated
            rpc_server_process.terminate()

            # Close all open connections to the backend
            async_engine_client.close()

            # Wait for server process to join
            rpc_server_process.join()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(rpc_server_process.pid)
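
# NOTE: `build_async_engine_client` yields one of two backends behind the
# same `AsyncEngineClient` protocol: an in-process `AsyncAphrodite` (used
# for embedding models and when `--disable-frontend-multiprocessing` is
# passed), or an `AsyncEngineRPCClient` that talks over ZeroMQ IPC to an
# engine living in a spawned child process. Callers interact with both the
# same way, as `run_server` below does.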


def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
                    "as PROMETHEUS_MULTIPROC_DIR")

        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
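
# NOTE: Mounting an ASGI sub-app at "/metrics" would normally make
# Starlette answer a bare GET /metrics with a 307 redirect to "/metrics/".
# Overriding `path_regex` above lets the mount match "/metrics" directly,
# so scrapers can hit the endpoint without a trailing slash.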


@router.get("/health")
async def health() -> Response:
    """Health check."""
    await async_engine_client.check_health()
    return Response(status_code=200)


@router.post("/v1/tokenize")
async def tokenize(request: TokenizeRequest):
    generator = await openai_serving_tokenization.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, TokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/detokenize")
async def detokenize(request: DetokenizeRequest):
    generator = await openai_serving_tokenization.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, DetokenizeResponse)
        return JSONResponse(content=generator.model_dump())
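
# Example requests (a sketch, assuming the server's default port of 2242;
# the exact fields are defined by TokenizeRequest/DetokenizeRequest in
# aphrodite.endpoints.openai.protocol, and the token IDs are illustrative):
#
#   curl http://localhost:2242/v1/tokenize \
#        -H "Content-Type: application/json" \
#        -d '{"model": "my-model", "prompt": "Hello, world!"}'
#   curl http://localhost:2242/v1/detokenize \
#        -H "Content-Type: application/json" \
#        -d '{"model": "my-model", "tokens": [9906, 11, 1917, 0]}'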


@router.get("/v1/models")
async def show_available_models():
    models = await openai_serving_completion.show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": APHRODITE_VERSION}
    return JSONResponse(content=ver)


@router.get("/.well-known/serviceinfo")
async def serviceinfo():
    """Return service information including version, API endpoints,
    and documentation URLs."""
    return JSONResponse(content={
        "version": 0.2,
        "software": {
            "name": "Aphrodite Engine",
            "version": APHRODITE_VERSION,
            "repository": "https://github.com/PygmalionAI/aphrodite-engine",
            "homepage": "https://aphrodite.pygmalion.chat",
            "logo": "https://pygmalion.chat/icons/favicon.ico",
        },
        "api": {
            "openai": {
                "name": "OpenAI API",
                "rel_url": "/v1",
                "documentation": "/redoc",
                "version": 1,
            },
            "koboldai": {
                "name": "KoboldAI API",
                "rel_url": "/api",
                "documentation": "/redoc",
                "version": 1,
            }
        }
    })


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())
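
# Example request (a sketch, assuming the default port of 2242 and a served
# model name of "my-model"; set "stream": true to receive server-sent
# events instead of a single JSON body):
#
#   curl http://localhost:2242/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"model": "my-model",
#             "messages": [{"role": "user", "content": "Hello!"}]}'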


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    generator = await openai_serving_embedding.create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/lora/load")
async def load_lora(lora: LoRAModulePath):
    openai_serving_completion.add_lora(lora)
    if engine_args.enable_lora is False:
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/lora/unload")
async def unload_lora(lora_name: str):
    openai_serving_completion.remove_lora(lora_name)
    return JSONResponse(content={"status": "success"})


@router.post("/v1/soft_prompt/load")
async def load_soft_prompt(soft_prompt: PromptAdapterPath):
    openai_serving_completion.add_prompt_adapter(soft_prompt)
    if engine_args.enable_prompt_adapter is False:
        logger.error("Prompt Adapter is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-prompt-adapter flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/soft_prompt/unload")
async def unload_soft_prompt(soft_prompt_name: str):
    openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
    return JSONResponse(content={"status": "success"})


# ============ KoboldAI API ============ #

badwordsids: List[int] = []


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)
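
# NOTE: When the model's config ships no explicit `bad_words_ids`, the
# fallback above bans every vocabulary token whose string form contains
# "[" or "]" (KoboldAI conventionally suppresses square-bracket tokens),
# un-bans the pad token if it was caught by that filter, and bans the EOS
# token so generations run to the requested length instead of stopping
# early.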


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
        xtc_probability=kai_payload.xtc_probability,
        xtc_threshold=kai_payload.xtc_threshold,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
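
# NOTE: The prompt is truncated from the left so that
# len(input_tokens) + max_length <= max_context_length: the most recent
# part of the prompt is kept and room is reserved for the completion.
# For example, with max_context_length=4096 and max_length=512, at most
# the last 3584 prompt tokens survive truncation.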


@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    final_res: RequestOutput = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})
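
# NOTE: While a blocking /api/v1/generate request is in flight, its partial
# output is mirrored into `gen_cache` under the request's genkey. Kobold
# clients poll /api/extra/generate/check with that genkey to display
# in-progress text, and /api/extra/abort cancels it; the cache entry is
# dropped once generation finishes.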


@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:

    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")
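
# NOTE: Each engine step is framed as one server-sent event:
#
#   event: message
#   data: {"token": "<newly generated text>"}
#
# mirroring the SSE format Kobold streaming clients expect from KoboldCpp's
# /api/extra/generate/stream endpoint.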


@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await async_engine_client.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: TokenizeRequest):
    """Tokenize string and return token count"""
    generator = await openai_serving_tokenization.create_tokenize(request)
    return JSONResponse({"value": generator.model_dump()["tokens"]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    # `args` is only defined when this module runs as __main__; use the
    # module-level copy stored by init_app so CLI entrypoints work too.
    max_length = api_server_args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})


@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed. Can be disabled with SERVE_KOBOLD_LITE_UI=0."""
    if not SERVE_KOBOLD_LITE_UI:
        return JSONResponse(content={"error": "Kobold Lite UI is disabled"},
                            status_code=404)
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r", encoding="utf-8") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)


# ============ KoboldAI API ============ #


def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    if args.launch_kobold_api:
        logger.warning("Kobold API is now enabled by default. "
                       "This flag will be removed in the future.")
    app.include_router(kai_api, prefix="/api/v1")
    app.include_router(kai_api,
                       prefix="/api/latest",
                       include_in_schema=False)
    app.include_router(extra_api, prefix="/api/extra")

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_completion.create_error_response(
            message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key

        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith(("/v1", "/api")):
                return await call_next(request)

            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith(("/v1/lora", "/v1/soft_prompt")):
                if admin_key is not None and api_key_header == admin_key:
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if auth_header != "Bearer " + token and api_key_header != token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app
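
# A --middleware entry must name an importable "module.attribute" that is
# either a class (added via add_middleware) or a coroutine function
# (registered as an "http" middleware). A minimal sketch, assuming a
# hypothetical module `my_middleware` on the PYTHONPATH:
#
#   # my_middleware.py
#   async def log_requests(request, call_next):
#       print(f"-> {request.method} {request.url.path}")
#       return await call_next(request)
#
# launched with: --middleware my_middleware.log_requests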


async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    global api_server_args
    api_server_args = args

    app = build_app(args)

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    model_config = await async_engine_client.get_model_config()

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path

    tokenizer = get_tokenizer(
        tokenizer_name=engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
        revision=engine_args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)

    return app


async def run_server(args, **uvicorn_kwargs) -> None:
    async with build_async_engine_client(args) as async_engine_client:
        app = await init_app(async_engine_client, args)

        protocol = "https" if args.ssl_certfile else "http"
        root_path = args.root_path.rstrip("/") if args.root_path else ""
        host_name = args.host if args.host else "localhost"
        port_str = str(args.port)

        if SERVE_KOBOLD_LITE_UI:
            ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/"
            logger.info(f"Kobold Lite UI: {ui_url}")

        logger.info(f"Documentation: {protocol}://{host_name}:{port_str}{root_path}/redoc")  # noqa: E501
        logger.info(f"Completions API: {protocol}://{host_name}:{port_str}{root_path}/v1/completions")  # noqa: E501
        logger.info(f"Chat API: {protocol}://{host_name}:{port_str}{root_path}/v1/chat/completions")  # noqa: E501
        logger.info(f"Embeddings API: {protocol}://{host_name}:{port_str}{root_path}/v1/embeddings")  # noqa: E501
        logger.info(f"Tokenization API: {protocol}://{host_name}:{port_str}{root_path}/v1/tokenize")  # noqa: E501

        shutdown_task = await serve_http(
            app,
            engine=async_engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task


if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))