api_server.py

import asyncio
import importlib
import inspect
import json
import multiprocessing
import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager, suppress
from distutils.util import strtobool
from http import HTTPStatus
from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple

from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from starlette.routing import Mount

import aphrodite.common.envs as envs
from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    get_open_zmq_ipc_path, in_windows,
                                    random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       PromptAdapterPath)
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION

if in_windows():
    import winloop as uvloop
else:
    import uvloop


TIMEOUT_KEEP_ALIVE = 5  # seconds
SERVE_KOBOLD_LITE_UI = strtobool(os.getenv("SERVE_KOBOLD_LITE_UI", "1"))

async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
prometheus_multiproc_dir: tempfile.TemporaryDirectory

_running_tasks: Set[asyncio.Task] = set()


def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=trust_remote_code,
                       seed=0,
                       dtype="auto").embedding_mode


@asynccontextmanager
async def lifespan(app: FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            with suppress(Exception):
                await async_engine_client.do_log_stats()

    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    yield


@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Create AsyncEngineClient, either:
        - in-process, using the AsyncAphrodite directly
        - multiprocess, using the AsyncAphrodite RPC

    Returns the client, or None if creation failed.
    """
    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncAphrodite in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between Aphrodite runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and Aphrodite will properly handle cleanup.")

        # Select random path for IPC.
        rpc_path = get_open_zmq_ipc_path()
        logger.info(
            f"Multiprocessing frontend to use {rpc_path} for RPC Path.")

        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client  # type: ignore

        # Start RPCServer in separate process (holds the AsyncAphrodite).
        context = multiprocessing.get_context("spawn")
        # The current process might have CUDA context,
        # so we need to spawn a new process.
        rpc_server_process = context.Process(
            target=run_rpc_server,
            args=(engine_args, rpc_path))
        rpc_server_process.start()
        logger.info(
            f"Started engine process with PID {rpc_server_process.pid}")

        try:
            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError:
                    if not rpc_server_process.is_alive():
                        logger.error(
                            "RPCServer process died before responding "
                            "to readiness probe")
                        yield None
                        return

            yield async_engine_client
        finally:
            # Ensure rpc server process was terminated
            rpc_server_process.terminate()

            # Close all open connections to the backend
            async_engine_client.close()

            # Wait for server process to join
            rpc_server_process.join()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(rpc_server_process.pid)


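# Usage sketch (illustration only, not part of the module): `run_server`
# below is the real consumer, but the context manager can also be driven
# directly. `probe` is a hypothetical name:
#
#     async def probe(args: Namespace):
#         async with build_async_engine_client(args) as client:
#             if client is not None:  # None means engine startup failed
#                 await client.check_health()

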
def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
                    "as PROMETHEUS_MULTIPROC_DIR")
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)


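# Once mounted, the metrics are plain-text scrapable, e.g. with
# `curl http://localhost:2242/metrics` (assuming the default port; the
# actual value comes from the --port argument).

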
@router.get("/health")
async def health() -> Response:
    """Health check."""
    await async_engine_client.check_health()
    return Response(status_code=200)


@router.post("/v1/tokenize")
async def tokenize(request: TokenizeRequest):
    generator = await openai_serving_tokenization.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, TokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/detokenize")
async def detokenize(request: DetokenizeRequest):
    generator = await openai_serving_tokenization.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, DetokenizeResponse)
        return JSONResponse(content=generator.model_dump())


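# Example request (a sketch; the exact TokenizeRequest fields may differ
# between versions):
#
#     curl -X POST http://localhost:2242/v1/tokenize \
#         -H "Content-Type: application/json" \
#         -d '{"model": "my-model", "prompt": "Hello world"}'

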
@router.get("/v1/models")
async def show_available_models():
    models = await openai_serving_completion.show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": APHRODITE_VERSION}
    return JSONResponse(content=ver)


@router.get("/.well-known/serviceinfo")
async def serviceinfo():
    """Return service information including version, API endpoints,
    and documentation URLs."""
    return JSONResponse(content={
        "version": 0.2,
        "software": {
            "name": "Aphrodite Engine",
            "version": APHRODITE_VERSION,
            "repository": "https://github.com/PygmalionAI/aphrodite-engine",
            "homepage": "https://aphrodite.pygmalion.chat",
            "logo": "https://pygmalion.chat/icons/favicon.ico",
        },
        "api": {
            "openai": {
                "name": "OpenAI API",
                "rel_url": "/v1",
                "documentation": "/redoc",
                "version": 1,
            },
            "koboldai": {
                "name": "KoboldAI API",
                "rel_url": "/api",
                "documentation": "/redoc",
                "version": 1,
            }
        }
    })


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    generator = await openai_serving_embedding.create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/lora/load")
async def load_lora(lora: LoRAModulePath):
    openai_serving_completion.add_lora(lora)
    if engine_args.enable_lora is False:
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/lora/unload")
async def unload_lora(lora_name: str):
    openai_serving_completion.remove_lora(lora_name)
    return JSONResponse(content={"status": "success"})


@router.post("/v1/soft_prompt/load")
async def load_soft_prompt(soft_prompt: PromptAdapterPath):
    openai_serving_completion.add_prompt_adapter(soft_prompt)
    if engine_args.enable_prompt_adapter is False:
        logger.error("Prompt Adapter is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-prompt-adapter flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/soft_prompt/unload")
async def unload_soft_prompt(soft_prompt_name: str):
    openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
    return JSONResponse(content={"status": "success"})


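# NOTE: when API keys are enabled, the authentication middleware installed in
# build_app() below treats /v1/lora and /v1/soft_prompt as admin-only routes:
# they require the admin key in the `x-api-key` header rather than the
# regular bearer token.

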
# ============ KoboldAI API ============ #

badwordsids: List[int] = []


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


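# For example, under the fallback branch above a vocab entry like "[INST]"
# contains "[" and "]" and so is banned; appending eos_token_id means that
# requests with use_default_badwordsids=True also suppress EOS (see
# custom_token_bans in prepare_engine_payload below).

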
def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
        xtc_probability=kai_payload.xtc_probability,
        xtc_threshold=kai_payload.xtc_threshold,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens


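# Truncation arithmetic, as a worked example: with max_context_length=4096
# and max_length=512, max_input_tokens = 4096 - 512 = 3584, so only the last
# 3584 prompt tokens are kept and generation cannot overflow the context.

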
@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    final_res: RequestOutput = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


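# A minimal KAI payload sketch (illustrative values; field names follow
# KAIGenerationInputSchema as used in prepare_engine_payload above):
#
#     {"prompt": "Once upon a time", "max_length": 80,
#      "max_context_length": 2048, "temperature": 0.7}

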
@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:

    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")


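# Each chunk is emitted as a Server-Sent Events frame, so a client sees e.g.:
#
#     event: message
#     data: {"token": " world"}
#
# followed by a blank line, exactly as stream_kobold yields above.

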
@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await async_engine_client.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: TokenizeRequest):
    """Tokenize string and return token count"""
    generator = await openai_serving_tokenization.create_tokenize(request)
    return JSONResponse({"value": generator.model_dump()["tokens"]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})


@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed. Can be disabled with SERVE_KOBOLD_LITE_UI=0."""
    if not SERVE_KOBOLD_LITE_UI:
        return JSONResponse(content={"error": "Kobold Lite UI is disabled"},
                            status_code=404)
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r", encoding="utf-8") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)


# ============ KoboldAI API ============ #


def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    if args.launch_kobold_api:
        logger.warning("Kobold API is now enabled by default. "
                       "This flag will be removed in the future.")
    app.include_router(kai_api, prefix="/api/v1")
    app.include_router(kai_api,
                       prefix="/api/latest",
                       include_in_schema=False)
    app.include_router(extra_api, prefix="/api/extra")

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_completion.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.APHRODITE_API_KEY or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key

        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith(("/v1", "/api")):
                return await call_next(request)

            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith(("/v1/lora", "/v1/soft_prompt")):
                if admin_key is not None and api_key_header == admin_key:
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if auth_header != "Bearer " + token and api_key_header != token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


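# Authenticated request sketch (only applies when APHRODITE_API_KEY or
# --api-keys is set): non-admin routes accept either header, e.g.
#
#     curl http://localhost:2242/v1/models -H "Authorization: Bearer $KEY"
#     curl http://localhost:2242/v1/models -H "x-api-key: $KEY"

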
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    global api_server_args
    api_server_args = args

    app = build_app(args)

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    model_config = await async_engine_client.get_model_config()

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path

    tokenizer = get_tokenizer(
        tokenizer_name=engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
        revision=engine_args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)

    return app


async def run_server(args, **uvicorn_kwargs) -> None:
    async with build_async_engine_client(args) as async_engine_client:
        # If None, creation of the client failed and we exit.
        if async_engine_client is None:
            return

        app = await init_app(async_engine_client, args)

        protocol = "https" if args.ssl_certfile else "http"
        root_path = args.root_path.rstrip("/") if args.root_path else ""
        host_name = args.host if args.host else "localhost"
        port_str = str(args.port)

        if SERVE_KOBOLD_LITE_UI:
            ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/"
            logger.info(f"Kobold Lite UI: {ui_url}")

        logger.info(f"Documentation: {protocol}://{host_name}:{port_str}{root_path}/redoc")  # noqa: E501
        logger.info(f"Completions API: {protocol}://{host_name}:{port_str}{root_path}/v1/completions")  # noqa: E501
        logger.info(f"Chat API: {protocol}://{host_name}:{port_str}{root_path}/v1/chat/completions")  # noqa: E501
        logger.info(f"Embeddings API: {protocol}://{host_name}:{port_str}{root_path}/v1/embeddings")  # noqa: E501
        logger.info(f"Tokenization API: {protocol}://{host_name}:{port_str}{root_path}/v1/tokenize")  # noqa: E501

        shutdown_task = await serve_http(
            app,
            engine=async_engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task


if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    asyncio.run(run_server(args))
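
# Launch sketch: run this file directly, e.g.
#     python api_server.py --model <model-name>
# or go through the CLI entrypoint mentioned in the NOTE above
# (aphrodite/endpoints/cli.py); exact CLI syntax may vary by version.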