# api_server.py

import asyncio
import copy
import importlib
import inspect
import json
import multiprocessing
import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from distutils.util import strtobool
from http import HTTPStatus
from typing import (Any, AsyncGenerator, AsyncIterator, Dict, List, Optional,
                    Set, Tuple)

import yaml
from fastapi import APIRouter, FastAPI, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from starlette.routing import Mount

import aphrodite.common.envs as envs
from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    get_open_zmq_ipc_path, in_windows,
                                    random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       PromptAdapterPath)
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.modeling.model_loader.weight_utils import get_model_config_yaml
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION

if in_windows():
    import winloop as uvloop
else:
    import uvloop
TIMEOUT_KEEP_ALIVE = 5  # seconds

SERVE_KOBOLD_LITE_UI = strtobool(os.getenv("SERVE_KOBOLD_LITE_UI", "1"))

async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
prometheus_multiproc_dir: tempfile.TemporaryDirectory

model_is_loaded = True
_running_tasks: Set[asyncio.Task] = set()

def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=trust_remote_code,
                       seed=0,
                       dtype="auto").embedding_mode

@asynccontextmanager
async def lifespan(app: FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    yield

@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Create an AsyncEngineClient, either:
        - in-process, using AsyncAphrodite directly
        - multiprocess, using the AsyncAphrodite RPC server
    Yields the client, or None if creation failed.
    """

    # Context manager to handle the async_engine_client lifecycle.
    # Ensures everything is shut down and cleaned up on error/exit.
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncAphrodite in process.
    # TODO: support embedding models via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing.
            # Note: global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between Aphrodite runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and Aphrodite will properly handle cleanup.")

        # Select random path for IPC.
        rpc_path = get_open_zmq_ipc_path()
        logger.info(
            f"Multiprocessing frontend to use {rpc_path} for RPC Path.")

        # Build RPCClient, which conforms to the AsyncEngineClient protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above).
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client  # type: ignore

        # Start RPCServer in a separate process (holds the AsyncAphrodite).
        # The current process might have CUDA context,
        # so we need to spawn a new process.
        context = multiprocessing.get_context("spawn")
        rpc_server_process = context.Process(
            target=run_rpc_server,
            args=(engine_args, rpc_path))
        rpc_server_process.start()
        logger.info(
            f"Started engine process with PID {rpc_server_process.pid}")

        try:
            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError:
                    if not rpc_server_process.is_alive():
                        logger.error(
                            "RPCServer process died before responding "
                            "to readiness probe")
                        yield None
                        return

            yield async_engine_client
        finally:
            # Ensure the RPC server process is terminated
            rpc_server_process.terminate()

            # Close all open connections to the backend
            async_engine_client.close()

            # Wait for the server process to join
            rpc_server_process.join()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(rpc_server_process.pid)
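
# Illustrative sketch (assumption: this comment block is not part of the
# original file): how the context manager above is typically consumed.
# `run_server` below does exactly this; the snippet only makes the two
# backend modes explicit.
#
#     async def example(args):
#         async with build_async_engine_client(args) as client:
#             if client is None:
#                 return  # the RPC engine process died during startup
#             # `client` is either an in-process AsyncAphrodite (embedding
#             # models, or --disable-frontend-multiprocessing) or an
#             # AsyncEngineRPCClient talking to a spawned engine process.
#             await client.check_health()
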
async def _maybe_switch_model(
        request_model: str, app_state,
        raw_request: Request) -> Optional[ErrorResponse]:
    """Switch to requested model if different from currently loaded one."""
    global model_is_loaded, async_engine_client, engine_args, \
        served_model_names
    if not model_is_loaded:
        return None
    if request_model in served_model_names:
        return None

    if not app_state.args.allow_inline_model_loading:
        return JSONResponse(
            content={
                "error": {
                    "message": "Requested model is not currently loaded. "
                    "Inline model loading is disabled. Enable it with "
                    "--allow-inline-model-loading.",
                    "type": "invalid_request_error",
                    "code": "model_not_loaded"
                }
            },
            status_code=400
        )  # type: ignore

    api_key = envs.APHRODITE_API_KEY or app_state.args.api_keys
    admin_key = envs.APHRODITE_ADMIN_KEY or app_state.args.admin_key

    if api_key:
        api_key_header = raw_request.headers.get("x-api-key")
        auth_header = raw_request.headers.get("Authorization")
        if not admin_key:
            return JSONResponse(
                content={
                    "error": {
                        "message": "Admin key not configured. "
                        "Inline model loading is disabled.",
                        "type": "invalid_request_error",
                        "code": "admin_key_required"
                    }
                },
                status_code=401
            )  # type: ignore
        if not (api_key_header == admin_key or
                auth_header == f"Bearer {admin_key}"):
            return JSONResponse(
                content={
                    "error": {
                        "message": "Admin privileges required for inline "
                        "model loading.",
                        "type": "invalid_request_error",
                        "code": "unauthorized"
                    }
                },
                status_code=401
            )  # type: ignore

    logger.info(f"Switching from {served_model_names[0]} to {request_model}")
    try:
        args = app_state.args
        if not args.disable_frontend_multiprocessing:
            await async_engine_client.kill()
        else:
            await async_engine_client.shutdown_background_loop()
        model_is_loaded = False

        yaml_config = get_model_config_yaml(request_model, args.download_dir)
        if yaml_config:
            parser = FlexibleArgumentParser()
            parser = make_arg_parser(parser)
            engine_args = parser.parse_args([])  # empty args
            for key, value in yaml_config.items():
                if hasattr(engine_args, key):
                    setattr(engine_args, key, value)
            engine_args.model = request_model
            engine_args = AsyncEngineArgs.from_cli_args(engine_args)
        else:
            # Fallback to minimal config
            engine_args = AsyncEngineArgs(model=request_model)

        if (model_is_embedding(engine_args.model,
                               engine_args.trust_remote_code)
                or args.disable_frontend_multiprocessing):
            async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
            await async_engine_client.setup()
        else:
            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
                global prometheus_multiproc_dir
                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
                os.environ[
                    "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name

            rpc_path = get_open_zmq_ipc_path()
            logger.info(
                f"Multiprocessing frontend to use {rpc_path} for RPC Path.")
            rpc_client = AsyncEngineRPCClient(rpc_path)
            async_engine_client = rpc_client
            context = multiprocessing.get_context("spawn")
            rpc_server_process = context.Process(
                target=run_rpc_server,
                args=(engine_args, rpc_path))
            rpc_server_process.start()
            logger.info(
                f"Started engine process with PID {rpc_server_process.pid}")
            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError as e:
                    if not rpc_server_process.is_alive():
                        raise RuntimeError(
                            "RPC Server died before responding to "
                            "readiness probe") from e

        new_args = copy.deepcopy(args)
        new_args.model = request_model
        app = await init_app(async_engine_client, new_args)  # noqa: F841
        served_model_names = [request_model]
        model_is_loaded = True
        return None
    except Exception as e:
        error_msg = f"Error while switching models: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={
                "error": {
                    "message": error_msg,
                    "type": "invalid_request_error",
                    "code": "model_load_error"
                }
            },
            status_code=500
        )  # type: ignore

def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
                    "as PROMETHEUS_MULTIPROC_DIR")
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
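
# Illustrative note (assumption: this comment block is not part of the
# original file): with the route mounted above, a plain GET serves the
# Prometheus text exposition format, and the path_regex workaround lets
# both path forms resolve without a 307 redirect:
#
#     curl http://localhost:2242/metrics
#     curl http://localhost:2242/metrics/
#
# Port 2242 is Aphrodite's default; substitute your --port value.
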
@router.delete("/v1/model/unload")
async def unload_model(request: Request):
    """Unload the current model and shut down the server."""
    logger.info("Received request to unload model.")
    try:
        args = request.app.state.args
        if not args.disable_frontend_multiprocessing:
            await async_engine_client.kill()
        else:
            await async_engine_client.shutdown_background_loop()
        global model_is_loaded
        model_is_loaded = False
        return JSONResponse(
            content={
                "status": "success",
                "message": "Model unloaded successfully"
            }
        )
    except Exception as e:
        error_msg = f"Error while unloading model: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={"status": "error", "message": error_msg},
            status_code=500
        )

@router.post("/v1/model/load")
async def load_model(config_file: UploadFile):
    """Load a model using a YAML configuration file."""
    global model_is_loaded, async_engine_client, engine_args
    if model_is_loaded:
        return JSONResponse(
            content={
                "error": {
                    "message": "A model is already loaded. "
                    "Please unload it first.",
                    "type": "invalid_request_error",
                    "code": "model_already_loaded"
                }
            },
            status_code=400
        )

    try:
        # Basically the same logic as in aphrodite.endpoints.cli:
        # flatten the YAML mapping into an equivalent CLI argument list.
        config_text = await config_file.read()
        config: Dict[Any, Any] = yaml.safe_load(config_text)

        args = []
        for key, value in config.items():
            key = key.replace('_', '-')
            if isinstance(value, bool):
                if value:
                    args.append(f"--{key}")
            elif isinstance(value, (list, tuple)):
                if key in ['lora-modules', 'prompt-adapters']:
                    for item in value:
                        args.append(f"--{key}")
                        args.append(f"{item['name']}={item['path']}")
                else:
                    for item in value:
                        args.append(f"--{key}")
                        args.append(str(item))
            else:
                args.append(f"--{key}")
                args.append(str(value))

        parser = FlexibleArgumentParser()
        parser = make_arg_parser(parser)
        parsed_args = parser.parse_args(args)
        # Build fresh engine args from the parsed config, rather than
        # reusing the stale global engine_args from the previous model.
        engine_args = AsyncEngineArgs.from_cli_args(parsed_args)

        if (model_is_embedding(parsed_args.model,
                               parsed_args.trust_remote_code)
                or parsed_args.disable_frontend_multiprocessing):
            async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
            await async_engine_client.setup()
        else:
            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
                global prometheus_multiproc_dir
                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
                os.environ[
                    "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
            rpc_path = get_open_zmq_ipc_path()
            logger.info(
                f"Multiprocessing frontend to use {rpc_path} for RPC Path.")
            rpc_client = AsyncEngineRPCClient(rpc_path)
            async_engine_client = rpc_client
            context = multiprocessing.get_context("spawn")
            rpc_server_process = context.Process(
                target=run_rpc_server,
                args=(engine_args, rpc_path))
            rpc_server_process.start()
            logger.info(
                f"Started engine process with PID {rpc_server_process.pid}")
            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError as e:
                    if not rpc_server_process.is_alive():
                        raise RuntimeError(
                            "RPC Server died before responding to readiness "
                            "probe") from e

        app = await init_app(async_engine_client, parsed_args)  # noqa: F841
        model_is_loaded = True
        return JSONResponse(
            content={
                "status": "success",
                "message": "Model loaded successfully"
            }
        )
    except Exception as e:
        error_msg = f"Error while loading model: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={
                "error": {
                    "message": error_msg,
                    "type": "invalid_request_error",
                    "code": "model_load_error"
                }
            },
            status_code=500
        )
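
# Illustrative sketch (assumption: this comment block is not part of the
# original file): a minimal YAML config accepted by the endpoint above, and
# how it is flattened into CLI-style flags before parsing. Keys use the same
# names as the server's command-line arguments; booleans become bare flags.
#
#     # config.yaml
#     model: mistralai/Mistral-7B-Instruct-v0.2
#     max_model_len: 8192
#     enforce_eager: true
#
# becomes ["--model", "mistralai/Mistral-7B-Instruct-v0.2",
#          "--max-model-len", "8192", "--enforce-eager"],
# and can be uploaded with:
#
#     curl -X POST http://localhost:2242/v1/model/load \
#          -H "x-api-key: $ADMIN_KEY" \
#          -F "config_file=@config.yaml"
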
@router.get("/health")
async def health() -> Response:
    """Health check."""
    await async_engine_client.check_health()
    return Response(status_code=200)

@router.post("/v1/tokenize")
async def tokenize(request: TokenizeRequest):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_tokenization.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, TokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/detokenize")
async def detokenize(request: DetokenizeRequest):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_tokenization.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, DetokenizeResponse)
        return JSONResponse(content=generator.model_dump())
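
# Illustrative requests (assumption: this comment block is not part of the
# original file, and the field names are taken from the pydantic models in
# aphrodite.endpoints.openai.protocol) for the two endpoints above; malformed
# bodies are rejected with 422 before the handlers run.
#
#     curl -X POST http://localhost:2242/v1/tokenize \
#          -H "Content-Type: application/json" \
#          -d '{"model": "my-model", "prompt": "Hello"}'
#
#     curl -X POST http://localhost:2242/v1/detokenize \
#          -H "Content-Type: application/json" \
#          -d '{"model": "my-model", "tokens": [1, 22557]}'
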
@router.get("/v1/models")
async def show_available_models():
    models = await openai_serving_completion.show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": APHRODITE_VERSION}
    return JSONResponse(content=ver)


@router.get("/.well-known/serviceinfo")
async def serviceinfo():
    """Return service information including version, API endpoints,
    and documentation URLs."""
    return JSONResponse(content={
        "version": 0.2,
        "software": {
            "name": "Aphrodite Engine",
            "version": APHRODITE_VERSION,
            "repository": "https://github.com/PygmalionAI/aphrodite-engine",
            "homepage": "https://aphrodite.pygmalion.chat",
            "logo": "https://pygmalion.chat/icons/favicon.ico",
        },
        "api": {
            "openai": {
                "name": "OpenAI API",
                "rel_url": "/v1",
                "documentation": "/redoc",
                "version": 1,
            },
            "koboldai": {
                "name": "KoboldAI API",
                "rel_url": "/api",
                "documentation": "/redoc",
                "version": 1,
            }
        }
    })

@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    error_check_ret = await _maybe_switch_model(
        request.model, raw_request.app.state, raw_request)
    if error_check_ret is not None:
        return error_check_ret
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    error_check_ret = await _maybe_switch_model(
        request.model, raw_request.app.state, raw_request)
    if error_check_ret is not None:
        return error_check_ret
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())

@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_embedding.create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())

@router.post("/v1/lora/load")
async def load_lora(lora: LoRAModulePath):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.add_lora(lora)
    if engine_args.enable_lora is False:
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/lora/unload")
async def unload_lora(lora_name: str):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.remove_lora(lora_name)
    return JSONResponse(content={"status": "success"})


@router.post("/v1/soft_prompt/load")
async def load_soft_prompt(soft_prompt: PromptAdapterPath):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.add_prompt_adapter(soft_prompt)
    if engine_args.enable_prompt_adapter is False:
        logger.error("Prompt Adapter is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-prompt-adapter flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/soft_prompt/unload")
async def unload_soft_prompt(soft_prompt_name: str):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
    return JSONResponse(content={"status": "success"})
# ============ KoboldAI API ============ #

badwordsids: List[int] = []


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)

def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
        xtc_probability=kai_payload.xtc_probability,
        xtc_threshold=kai_payload.xtc_threshold,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
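
# Illustrative mapping (assumption: this comment block is not part of the
# original file): given a KoboldAI payload with max_context_length=4096 and
# max_length=512, the function above keeps at most 4096 - 512 = 3584 prompt
# tokens (always at least 1), taken from the *end* of the tokenized prompt so
# the most recent context survives truncation. A top_k of 0 is normalized to
# -1 (disabled), and a temperature below _SAMPLING_EPS collapses sampling to
# near-greedy settings (n=1, top_p=1.0, top_k=-1).
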
@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    final_res: Optional[RequestOutput] = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})

@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:

    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")
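
# Illustrative wire format (assumption: this comment block is not part of the
# original file): each iteration of stream_kobold above emits one SSE event
# carrying only the newly generated text, e.g.:
#
#     event: message
#     data: {"token": " world"}
#
#     event: message
#     data: {"token": "!"}
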
@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await async_engine_client.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: TokenizeRequest):
    """Tokenize string and return token count"""
    generator = await openai_serving_tokenization.create_tokenize(request)
    return JSONResponse({"value": generator.model_dump()["tokens"]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})

@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    # init_app saves the server args globally; a bare `args` name only
    # exists when this module is run as __main__.
    max_length = api_server_args.max_length
    return JSONResponse({"value": max_length})

@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})

@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed. Can be disabled with SERVE_KOBOLD_LITE_UI=0."""
    if not SERVE_KOBOLD_LITE_UI:
        return JSONResponse(content={"error": "Kobold Lite UI is disabled"},
                            status_code=404)
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r", encoding="utf-8") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)


# ============ KoboldAI API ============ #

def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    app.state.args = args
    if args.launch_kobold_api:
        logger.warning("Kobold API is now enabled by default. "
                       "This flag will be removed in the future.")
    app.include_router(kai_api, prefix="/api/v1")
    app.include_router(kai_api,
                       prefix="/api/latest",
                       include_in_schema=False)
    app.include_router(extra_api, prefix="/api/extra")

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_completion.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.APHRODITE_API_KEY or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key

        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith(("/v1", "/api")):
                return await call_next(request)

            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith(
                    ("/v1/lora", "/v1/soft_prompt", "/v1/model")):
                if admin_key is not None and (
                        api_key_header == admin_key or
                        auth_header == "Bearer " + admin_key):
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if (auth_header == f"Bearer {token}" or api_key_header == token or
                    (admin_key is not None and
                     (api_key_header == admin_key or
                      auth_header == f"Bearer {admin_key}"))):
                return await call_next(request)

            return JSONResponse(
                content={"error": "Unauthorized"}, status_code=401)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app
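
# Illustrative requests (assumption: this comment block is not part of the
# original file): with API keys enabled, the middleware above accepts either
# header form for regular endpoints, while /v1/lora, /v1/soft_prompt and
# /v1/model paths require the admin key:
#
#     curl http://localhost:2242/v1/models \
#          -H "Authorization: Bearer $API_KEY"
#     curl http://localhost:2242/v1/models -H "x-api-key: $API_KEY"
#     curl -X DELETE http://localhost:2242/v1/model/unload \
#          -H "x-api-key: $ADMIN_KEY"
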
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    global api_server_args
    api_server_args = args

    app = build_app(args)

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    model_config = await async_engine_client.get_model_config()

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path

    tokenizer = get_tokenizer(
        tokenizer_name=engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
        revision=engine_args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)

    return app

async def run_server(args, **uvicorn_kwargs) -> None:
    async with build_async_engine_client(args) as async_engine_client:
        # If None, creation of the client failed and we exit.
        if async_engine_client is None:
            return

        app = await init_app(async_engine_client, args)

        protocol = "https" if args.ssl_certfile else "http"
        root_path = args.root_path.rstrip("/") if args.root_path else ""
        host_name = args.host if args.host else "localhost"
        port_str = str(args.port)

        if SERVE_KOBOLD_LITE_UI:
            ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/"
            logger.info(f"Kobold Lite UI: {ui_url}")

        logger.info(f"Documentation: {protocol}://{host_name}:{port_str}{root_path}/redoc")  # noqa: E501
        logger.info(f"Completions API: {protocol}://{host_name}:{port_str}{root_path}/v1/completions")  # noqa: E501
        logger.info(f"Chat API: {protocol}://{host_name}:{port_str}{root_path}/v1/chat/completions")  # noqa: E501
        logger.info(f"Embeddings API: {protocol}://{host_name}:{port_str}{root_path}/v1/embeddings")  # noqa: E501
        logger.info(f"Tokenization API: {protocol}://{host_name}:{port_str}{root_path}/v1/tokenize")  # noqa: E501

        shutdown_task = await serve_http(
            app,
            engine=async_engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task

if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    asyncio.run(run_server(args))
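
# Illustrative launch (assumption: this comment block is not part of the
# original file; the module path is inferred from the imports above, and the
# flag names follow the args referenced in this file, e.g. args.api_keys and
# args.admin_key):
#
#     python -m aphrodite.endpoints.openai.api_server \
#         --model mistralai/Mistral-7B-Instruct-v0.2 \
#         --host 0.0.0.0 --port 2242 \
#         --api-keys "$API_KEY" --admin-key "$ADMIN_KEY"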