api_server.py
import asyncio
import importlib
import inspect
import json
import multiprocessing
import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from distutils.util import strtobool
from http import HTTPStatus
from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple

from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from starlette.routing import Mount

import aphrodite.common.envs as envs
from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    get_open_zmq_ipc_path, in_windows,
                                    random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       PromptAdapterPath)
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION

if in_windows():
    import winloop as uvloop
else:
    import uvloop

TIMEOUT_KEEP_ALIVE = 5  # seconds

SERVE_KOBOLD_LITE_UI = strtobool(os.getenv("SERVE_KOBOLD_LITE_UI", "1"))

async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
prometheus_multiproc_dir: tempfile.TemporaryDirectory

_running_tasks: Set[asyncio.Task] = set()


def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=trust_remote_code,
                       seed=0,
                       dtype="auto").embedding_mode
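
# NOTE: model_is_embedding builds a throwaway ModelConfig only to read its
# `embedding_mode` attribute; no weights are loaded here. The result decides
# below whether the frontend may use the multiprocess RPC engine.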


@asynccontextmanager
async def lifespan(app: FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    yield
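
# NOTE: _running_tasks keeps a strong reference to the stats-logging task.
# The event loop itself holds only weak references to tasks, so without this
# set the fire-and-forget task could be garbage-collected mid-flight; the
# done-callback removes the reference once the task finishes.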


@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Create AsyncEngineClient, either:
        - in-process, using the AsyncAphrodite directly
        - multiprocess, using the AsyncAphrodite RPC

    Returns the client, or None if creation failed.
    """
    # Context manager to handle the async_engine_client lifecycle.
    # Ensures everything is shut down and cleaned up on error/exit.
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler.
    global async_engine_client

    # If manually triggered or an embedding model, use AsyncAphrodite
    # in process.
    # TODO: support embedding models via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing.
            # Note: global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between Aphrodite runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and Aphrodite will properly handle cleanup.")

        # Select random path for IPC.
        rpc_path = get_open_zmq_ipc_path()
        logger.info(f"Multiprocessing frontend to use {rpc_path} for RPC path.")

        # Build RPCClient, which conforms to the AsyncEngineClient protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above).
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client  # type: ignore

        # Start RPCServer in a separate process (holds the AsyncAphrodite).
        # The current process might have a CUDA context, so we need to
        # spawn a new process.
        context = multiprocessing.get_context("spawn")
        rpc_server_process = context.Process(
            target=run_rpc_server,
            args=(engine_args, rpc_path))
        rpc_server_process.start()
        logger.info(
            f"Started engine process with PID {rpc_server_process.pid}")

        try:
            # Poll the engine until it is ready; bail out if the engine
            # process dies before answering the readiness probe.
            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError:
                    if not rpc_server_process.is_alive():
                        logger.error(
                            "RPCServer process died before responding "
                            "to readiness probe")
                        yield None
                        return

            yield async_engine_client
        finally:
            # Ensure rpc server process was terminated
            rpc_server_process.terminate()

            # Close all open connections to the backend
            async_engine_client.close()

            # Wait for server process to join
            rpc_server_process.join()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(rpc_server_process.pid)


def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
                    "as PROMETHEUS_MULTIPROC_DIR")
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
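
# The mounted endpoint serves the standard Prometheus text exposition format,
# e.g. (host/port assumed to be the server defaults):
#   curl http://localhost:2242/metrics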
  200. @router.get("/health")
  201. async def health() -> Response:
  202. """Health check."""
  203. await async_engine_client.check_health()
  204. return Response(status_code=200)
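
# Example readiness probe (host/port assumed to be the server defaults):
#   curl -i http://localhost:2242/health
# Returns 200 with an empty body while the engine is healthy; if check_health
# raises, the request fails instead.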
  205. @router.post("/v1/tokenize")
  206. async def tokenize(request: TokenizeRequest):
  207. generator = await openai_serving_tokenization.create_tokenize(request)
  208. if isinstance(generator, ErrorResponse):
  209. return JSONResponse(content=generator.model_dump(),
  210. status_code=generator.code)
  211. else:
  212. assert isinstance(generator, TokenizeResponse)
  213. return JSONResponse(content=generator.model_dump())
  214. @router.post("/v1/detokenize")
  215. async def detokenize(request: DetokenizeRequest):
  216. generator = await openai_serving_tokenization.create_detokenize(request)
  217. if isinstance(generator, ErrorResponse):
  218. return JSONResponse(content=generator.model_dump(),
  219. status_code=generator.code)
  220. else:
  221. assert isinstance(generator, DetokenizeResponse)
  222. return JSONResponse(content=generator.model_dump())
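
# Illustrative round-trip; the exact request/response schemas live in the
# TokenizeRequest/DetokenizeRequest models in endpoints/openai/protocol:
#   POST /v1/tokenize   {"model": "<served-model>", "prompt": "Hello"}
#     -> {"count": ..., "max_model_len": ..., "tokens": [...]}
#   POST /v1/detokenize {"model": "<served-model>", "tokens": [...]}
#     -> {"prompt": "Hello"}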
  223. @router.get("/v1/models")
  224. async def show_available_models():
  225. models = await openai_serving_completion.show_available_models()
  226. return JSONResponse(content=models.model_dump())
  227. @router.get("/version")
  228. async def show_version():
  229. ver = {"version": APHRODITE_VERSION}
  230. return JSONResponse(content=ver)
  231. @router.get("/.well-known/serviceinfo")
  232. async def serviceinfo():
  233. """Return service information including version, API endpoints,
  234. and documentation URLs."""
  235. return JSONResponse(content={
  236. "version": 0.2,
  237. "software": {
  238. "name": "Aphrodite Engine",
  239. "version": APHRODITE_VERSION,
  240. "repository": "https://github.com/PygmalionAI/aphrodite-engine",
  241. "homepage": "https://aphrodite.pygmalion.chat",
  242. "logo": "https://pygmalion.chat/icons/favicon.ico",
  243. },
  244. "api": {
  245. "openai": {
  246. "name": "OpenAI API",
  247. "rel_url": "/v1",
  248. "documentation": "/redoc",
  249. "version": 1,
  250. },
  251. "koboldai": {
  252. "name": "KoboldAI API",
  253. "rel_url": "/api",
  254. "documentation": "/redoc",
  255. "version": 1,
  256. }
  257. }
  258. })
  259. @router.post("/v1/chat/completions")
  260. async def create_chat_completion(request: ChatCompletionRequest,
  261. raw_request: Request):
  262. generator = await openai_serving_chat.create_chat_completion(
  263. request, raw_request)
  264. if isinstance(generator, ErrorResponse):
  265. return JSONResponse(content=generator.model_dump(),
  266. status_code=generator.code)
  267. if request.stream:
  268. return StreamingResponse(content=generator,
  269. media_type="text/event-stream")
  270. else:
  271. assert isinstance(generator, ChatCompletionResponse)
  272. return JSONResponse(content=generator.model_dump())
  273. @router.post("/v1/completions")
  274. async def create_completion(request: CompletionRequest, raw_request: Request):
  275. generator = await openai_serving_completion.create_completion(
  276. request, raw_request)
  277. if isinstance(generator, ErrorResponse):
  278. return JSONResponse(content=generator.model_dump(),
  279. status_code=generator.code)
  280. if request.stream:
  281. return StreamingResponse(content=generator,
  282. media_type="text/event-stream")
  283. else:
  284. return JSONResponse(content=generator.model_dump())
  285. @router.post("/v1/embeddings")
  286. async def create_embedding(request: EmbeddingRequest, raw_request: Request):
  287. generator = await openai_serving_embedding.create_embedding(
  288. request, raw_request)
  289. if isinstance(generator, ErrorResponse):
  290. return JSONResponse(content=generator.model_dump(),
  291. status_code=generator.code)
  292. else:
  293. return JSONResponse(content=generator.model_dump())
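
# Example streaming chat request (default host/port assumed, no API key set):
#   curl http://localhost:2242/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "<served-model>", "stream": true,
#          "messages": [{"role": "user", "content": "Hi!"}]}'
# With "stream": true the reply arrives as text/event-stream chunks;
# otherwise a single ChatCompletionResponse JSON body is returned.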
  294. @router.post("/v1/lora/load")
  295. async def load_lora(lora: LoRAModulePath):
  296. openai_serving_completion.add_lora(lora)
  297. if engine_args.enable_lora is False:
  298. logger.error("LoRA is not enabled in the engine. "
  299. "Please start the server with the "
  300. "--enable-lora flag!")
  301. return JSONResponse(content={"status": "success"})
  302. @router.delete("/v1/lora/unload")
  303. async def unload_lora(lora_name: str):
  304. openai_serving_completion.remove_lora(lora_name)
  305. return JSONResponse(content={"status": "success"})
  306. @router.post("/v1/soft_prompt/load")
  307. async def load_soft_prompt(soft_prompt: PromptAdapterPath):
  308. openai_serving_completion.add_prompt_adapter(soft_prompt)
  309. if engine_args.enable_prompt_adapter is False:
  310. logger.error("Prompt Adapter is not enabled in the engine. "
  311. "Please start the server with the "
  312. "--enable-prompt-adapter flag!")
  313. return JSONResponse(content={"status": "success"})
  314. @router.delete("/v1/soft_prompt/unload")
  315. async def unload_soft_prompt(soft_prompt_name: str):
  316. openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
  317. return JSONResponse(content={"status": "success"})
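
# NOTE: when API keys are enabled, the /v1/lora and /v1/soft_prompt routes
# are admin-only: the authentication middleware in build_app (below) accepts
# them only with an `x-api-key` header matching the admin key (taken from
# the APHRODITE_ADMIN_KEY env var or the corresponding CLI argument).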


# ============ KoboldAI API ============ #

badwordsids: List[int] = []


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
        xtc_probability=kai_payload.xtc_probability,
        xtc_threshold=kai_payload.xtc_threshold,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
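
# Truncation arithmetic: the prompt is clipped from the left so that prompt
# plus requested generation always fits the context window. For example,
# with max_context_length=4096 and max_length=512, only the last
# 4096 - 512 = 3584 prompt tokens are kept.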
  373. @kai_api.post("/generate")
  374. async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
  375. sampling_params, input_tokens = prepare_engine_payload(kai_payload)
  376. result_generator = async_engine_client.generate(
  377. {
  378. "prompt": kai_payload.prompt,
  379. "prompt_token_ids": input_tokens,
  380. },
  381. sampling_params,
  382. kai_payload.genkey,
  383. )
  384. final_res: RequestOutput = None
  385. previous_output = ""
  386. async for res in result_generator:
  387. final_res = res
  388. new_chunk = res.outputs[0].text[len(previous_output):]
  389. previous_output += new_chunk
  390. gen_cache[kai_payload.genkey] = previous_output
  391. assert final_res is not None
  392. del gen_cache[kai_payload.genkey]
  393. return JSONResponse(
  394. {"results": [{
  395. "text": output.text
  396. } for output in final_res.outputs]})
  397. @extra_api.post("/generate/stream")
  398. async def generate_stream(
  399. kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
  400. sampling_params, input_tokens = prepare_engine_payload(kai_payload)
  401. results_generator = async_engine_client.generate(
  402. {
  403. "prompt": kai_payload.prompt,
  404. "prompt_token_ids": input_tokens,
  405. },
  406. sampling_params,
  407. kai_payload.genkey,
  408. )
  409. async def stream_kobold() -> AsyncGenerator[bytes, None]:
  410. previous_output = ""
  411. async for res in results_generator:
  412. new_chunk = res.outputs[0].text[len(previous_output):]
  413. previous_output += new_chunk
  414. yield b"event: message\n"
  415. yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()
  416. return StreamingResponse(stream_kobold(),
  417. headers={
  418. "Cache-Control": "no-cache",
  419. "Connection": "keep-alive",
  420. },
  421. media_type="text/event-stream")
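
# Wire format produced by stream_kobold, one SSE event per new text chunk:
#   event: message
#   data: {"token": "<new text since the previous event>"}
#
# (the blank line terminates each event, per the SSE spec)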
  422. @extra_api.post("/generate/check")
  423. @extra_api.get("/generate/check")
  424. async def check_generation(request: Request):
  425. text = ""
  426. try:
  427. request_dict = await request.json()
  428. if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
  429. text = gen_cache[request_dict["genkey"]]
  430. except json.JSONDecodeError:
  431. pass
  432. return JSONResponse({"results": [{"text": text}]})
  433. @extra_api.post("/abort")
  434. async def abort_generation(request: Request):
  435. try:
  436. request_dict = await request.json()
  437. if "genkey" in request_dict:
  438. await async_engine_client.abort(request_dict["genkey"])
  439. except json.JSONDecodeError:
  440. pass
  441. return JSONResponse({})
  442. @extra_api.post("/tokencount")
  443. async def count_tokens(request: TokenizeRequest):
  444. """Tokenize string and return token count"""
  445. generator = await openai_serving_tokenization.create_tokenize(request)
  446. return JSONResponse({"value": generator.model_dump()["tokens"]})
  447. @kai_api.get("/info/version")
  448. async def get_version():
  449. """Impersonate KAI"""
  450. return JSONResponse({"result": "1.2.4"})
  451. @kai_api.get("/model")
  452. async def get_model():
  453. return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})
  454. @kai_api.get("/config/soft_prompts_list")
  455. async def get_available_softprompts():
  456. """Stub for compatibility"""
  457. return JSONResponse({"values": []})
  458. @kai_api.get("/config/soft_prompt")
  459. async def get_current_softprompt():
  460. """Stub for compatibility"""
  461. return JSONResponse({"value": ""})
  462. @kai_api.put("/config/soft_prompt")
  463. async def set_current_softprompt():
  464. """Stub for compatibility"""
  465. return JSONResponse({})
  466. @kai_api.get("/config/max_length")
  467. async def get_max_length() -> JSONResponse:
  468. max_length = args.max_length
  469. return JSONResponse({"value": max_length})
  470. @kai_api.get("/config/max_context_length")
  471. @extra_api.get("/true_max_context_length")
  472. async def get_max_context_length() -> JSONResponse:
  473. max_context_length = engine_args.max_model_len
  474. return JSONResponse({"value": max_context_length})
  475. @extra_api.get("/preloadstory")
  476. async def get_preloaded_story() -> JSONResponse:
  477. """Stub for compatibility"""
  478. return JSONResponse({})
  479. @extra_api.get("/version")
  480. async def get_extra_version():
  481. """Impersonate KoboldCpp"""
  482. return JSONResponse({"result": "KoboldCpp", "version": "1.63"})
  483. @router.get("/")
  484. async def get_kobold_lite_ui():
  485. """Serves a cached copy of the Kobold Lite UI, loading it from disk
  486. on demand if needed. Can be disabled with SERVE_KOBOLD_LITE_UI=0."""
  487. if not SERVE_KOBOLD_LITE_UI:
  488. return JSONResponse(content={"error": "Kobold Lite UI is disabled"},
  489. status_code=404)
  490. global kobold_lite_ui
  491. if kobold_lite_ui == "":
  492. scriptpath = os.path.dirname(os.path.abspath(__file__))
  493. klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
  494. klitepath = os.path.normpath(klitepath) # Normalize the path
  495. if os.path.exists(klitepath):
  496. with open(klitepath, "r", encoding="utf-8") as f:
  497. kobold_lite_ui = f.read()
  498. else:
  499. logger.error("Kobold Lite UI not found at " + klitepath)
  500. return HTMLResponse(content=kobold_lite_ui)
  501. # ============ KoboldAI API ============ #


def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    if args.launch_kobold_api:
        logger.warning("Kobold API is now enabled by default. "
                       "This flag will be removed in the future.")
    app.include_router(kai_api, prefix="/api/v1")
    app.include_router(kai_api,
                       prefix="/api/latest",
                       include_in_schema=False)
    app.include_router(extra_api, prefix="/api/extra")

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_completion.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.APHRODITE_API_KEY or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key

        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith(("/v1", "/api")):
                return await call_next(request)

            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith(("/v1/lora", "/v1/soft_prompt")):
                if admin_key is not None and api_key_header == admin_key:
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if auth_header != "Bearer " + token and api_key_header != token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app
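
# Authentication cheat-sheet for the middleware above. Regular /v1 and /api
# requests pass with either header form:
#   Authorization: Bearer <api-key>      or      x-api-key: <api-key>
# Admin routes (/v1/lora*, /v1/soft_prompt*) require instead:
#   x-api-key: <admin-key>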


async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    global api_server_args
    api_server_args = args

    app = build_app(args)

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    model_config = await async_engine_client.get_model_config()

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path

    tokenizer = get_tokenizer(
        tokenizer_name=engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
        revision=engine_args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)

    return app


async def run_server(args, **uvicorn_kwargs) -> None:

    async with build_async_engine_client(args) as async_engine_client:
        # If None, creation of the client failed and we exit.
        if async_engine_client is None:
            return

        app = await init_app(async_engine_client, args)

        protocol = "https" if args.ssl_certfile else "http"
        root_path = args.root_path.rstrip("/") if args.root_path else ""
        host_name = args.host if args.host else "localhost"
        port_str = str(args.port)

        if SERVE_KOBOLD_LITE_UI:
            ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/"
            logger.info(f"Kobold Lite UI: {ui_url}")

        logger.info(f"Documentation: {protocol}://{host_name}:{port_str}{root_path}/redoc")  # noqa: E501
        logger.info(f"Completions API: {protocol}://{host_name}:{port_str}{root_path}/v1/completions")  # noqa: E501
        logger.info(f"Chat API: {protocol}://{host_name}:{port_str}{root_path}/v1/chat/completions")  # noqa: E501
        logger.info(f"Embeddings API: {protocol}://{host_name}:{port_str}{root_path}/v1/embeddings")  # noqa: E501
        logger.info(f"Tokenization API: {protocol}://{host_name}:{port_str}{root_path}/v1/tokenize")  # noqa: E501

        shutdown_task = await serve_http(
            app,
            engine=async_engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task


if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    asyncio.run(run_server(args))
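
# Example launch (flag names come from make_arg_parser; the model and port
# shown are illustrative):
#   python -m aphrodite.endpoints.openai.api_server \
#       --model mistralai/Mistral-7B-Instruct-v0.2 --port 2242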