# api_server.py

import asyncio
import copy
import importlib
import inspect
import json
import multiprocessing
import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from distutils.util import strtobool
from http import HTTPStatus
from typing import (Any, AsyncGenerator, AsyncIterator, Dict, List, Optional,
                    Set, Tuple)

import yaml
from fastapi import APIRouter, FastAPI, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from starlette.routing import Mount

import aphrodite.common.envs as envs
from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    get_open_zmq_ipc_path, in_windows,
                                    random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       PromptAdapterPath)
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.modeling.model_loader.weight_utils import get_model_config_yaml
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION

if in_windows():
    import winloop as uvloop
else:
    import uvloop

TIMEOUT_KEEP_ALIVE = 5  # seconds

SERVE_KOBOLD_LITE_UI = strtobool(os.getenv("SERVE_KOBOLD_LITE_UI", "1"))

async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
prometheus_multiproc_dir: tempfile.TemporaryDirectory
model_is_loaded = True
_running_tasks: Set[asyncio.Task] = set()


def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=trust_remote_code,
                       seed=0,
                       dtype="auto").embedding_mode


@asynccontextmanager
async def lifespan(app: FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    yield


@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Create an AsyncEngineClient, either:
        - in-process, using AsyncAphrodite directly
        - multiprocess, using the AsyncAphrodite RPC server

    Yields the client, or None if creation failed.
    """
    # Context manager to handle the async_engine_client lifecycle.
    # Ensures everything is shut down and cleaned up on error/exit.
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncAphrodite in process.
    # TODO: support embedding models via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make a TemporaryDirectory for prometheus multiprocessing.
            # Note: the global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between Aphrodite runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and Aphrodite will properly handle cleanup.")

        # Select a random path for IPC.
        rpc_path = get_open_zmq_ipc_path()
        logger.info(
            f"Multiprocessing frontend to use {rpc_path} for RPC path.")

        # Build the RPCClient, which conforms to the AsyncEngineClient
        # protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above).
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client  # type: ignore

        # Start the RPCServer in a separate process (holds the
        # AsyncAphrodite). The current process might have a CUDA context,
        # so we need to spawn a new process.
        context = multiprocessing.get_context("spawn")
        rpc_server_process = context.Process(
            target=run_rpc_server,
            args=(engine_args, rpc_path))
        rpc_server_process.start()
        logger.info(
            f"Started engine process with PID {rpc_server_process.pid}")

        try:
            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError:
                    if not rpc_server_process.is_alive():
                        logger.error(
                            "RPCServer process died before responding "
                            "to readiness probe")
                        yield None
                        return

            yield async_engine_client
        finally:
            # Ensure the RPC server process is terminated.
            rpc_server_process.terminate()

            # Close all open connections to the backend.
            async_engine_client.close()

            # Wait for the server process to join.
            rpc_server_process.join()

            # Lazy import for prometheus multiprocessing.
            # We need to set the PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(rpc_server_process.pid)


async def _maybe_switch_model(
        request_model: str, app_state,
        raw_request: Request) -> Optional[ErrorResponse]:
    """Switch to the requested model if it differs from the loaded one."""
    global model_is_loaded, async_engine_client, engine_args, \
        served_model_names
    if not model_is_loaded:
        return None

    models = await openai_serving_completion.show_available_models()
    for model in models.data:
        if request_model in (model.id, model.root):
            return None

    if not app_state.args.allow_inline_model_loading:
        return JSONResponse(
            content={
                "error": {
                    "message": "Requested model is not currently loaded. "
                    "Inline model loading is disabled. Enable it with "
                    "--allow-inline-model-loading.",
                    "type": "invalid_request_error",
                    "code": "model_not_loaded"
                }
            },
            status_code=400
        )  # type: ignore

    api_key = envs.APHRODITE_API_KEY or app_state.args.api_keys
    admin_key = envs.APHRODITE_ADMIN_KEY or app_state.args.admin_key

    if api_key:
        api_key_header = raw_request.headers.get("x-api-key")
        auth_header = raw_request.headers.get("Authorization")

        if not admin_key:
            return JSONResponse(
                content={
                    "error": {
                        "message": "Admin key not configured. "
                        "Inline model loading is disabled.",
                        "type": "invalid_request_error",
                        "code": "admin_key_required"
                    }
                },
                status_code=401
            )  # type: ignore

        if not (api_key_header == admin_key or
                auth_header == f"Bearer {admin_key}"):
            return JSONResponse(
                content={
                    "error": {
                        "message": "Admin privileges required for inline "
                        "model loading.",
                        "type": "invalid_request_error",
                        "code": "unauthorized"
                    }
                },
                status_code=401
            )  # type: ignore

    logger.info(f"Switching from {served_model_names[0]} to {request_model}")
    try:
        args = app_state.args
        if not args.disable_frontend_multiprocessing:
            await async_engine_client.kill()
        else:
            await async_engine_client.shutdown_background_loop()

        model_is_loaded = False

        yaml_config = get_model_config_yaml(request_model, args.download_dir)
        if yaml_config:
            parser = FlexibleArgumentParser()
            parser = make_arg_parser(parser)
            engine_args = parser.parse_args([])  # empty args
            for key, value in yaml_config.items():
                if hasattr(engine_args, key):
                    setattr(engine_args, key, value)
            engine_args.model = request_model
            engine_args = AsyncEngineArgs.from_cli_args(engine_args)
        else:
            # Fall back to a minimal config.
            engine_args = AsyncEngineArgs(model=request_model)

        if (model_is_embedding(engine_args.model,
                               engine_args.trust_remote_code)
                or args.disable_frontend_multiprocessing):
            async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
            await async_engine_client.setup()
        else:
            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
                global prometheus_multiproc_dir
                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
                os.environ[
                    "PROMETHEUS_MULTIPROC_DIR"] = \
                    prometheus_multiproc_dir.name

            rpc_path = get_open_zmq_ipc_path()
            logger.info(
                f"Multiprocessing frontend to use {rpc_path} for RPC path.")
            rpc_client = AsyncEngineRPCClient(rpc_path)
            async_engine_client = rpc_client

            context = multiprocessing.get_context("spawn")
            rpc_server_process = context.Process(
                target=run_rpc_server,
                args=(engine_args, rpc_path))
            rpc_server_process.start()
            logger.info(
                f"Started engine process with PID {rpc_server_process.pid}")

            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError as e:
                    if not rpc_server_process.is_alive():
                        raise RuntimeError(
                            "RPC Server died before responding to "
                            "readiness probe") from e

        new_args = copy.deepcopy(args)
        new_args.model = request_model

        app = await init_app(async_engine_client, new_args)  # noqa: F841
        served_model_names = [request_model]
        model_is_loaded = True
        return None
    except Exception as e:
        error_msg = f"Error while switching models: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={
                "error": {
                    "message": error_msg,
                    "type": "invalid_request_error",
                    "code": "model_load_error"
                }
            },
            status_code=500
        )  # type: ignore
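
# A request whose "model" field names a model other than the loaded one
# triggers the switch above, provided --allow-inline-model-loading is set and
# the admin key is supplied. Illustrative request (placeholder model name and
# key; assumes the default port):
#
#   curl http://localhost:2242/v1/completions \
#     -H "Authorization: Bearer $APHRODITE_ADMIN_KEY" \
#     -d '{"model": "org/other-model", "prompt": "Hi", "max_tokens": 8}'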


def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set the PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
                    "as PROMETHEUS_MULTIPROC_DIR")
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)


@router.delete("/v1/model/unload")
async def unload_model(request: Request):
    """Unload the current model and shut down the server."""
    logger.info("Received request to unload model.")
    try:
        args = request.app.state.args
        if not args.disable_frontend_multiprocessing:
            await async_engine_client.kill()
        else:
            await async_engine_client.shutdown_background_loop()

        global model_is_loaded
        model_is_loaded = False
        return JSONResponse(
            content={
                "status": "success",
                "message": "Model unloaded successfully"
            }
        )
    except Exception as e:
        error_msg = f"Error while unloading model: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={"status": "error", "message": error_msg},
            status_code=500
        )


@router.post("/v1/model/load")
async def load_model(config_file: UploadFile):
    """Load a model using a YAML configuration file."""
    global model_is_loaded, async_engine_client, engine_args
    if model_is_loaded:
        return JSONResponse(
            content={
                "error": {
                    "message": "A model is already loaded. "
                    "Please unload it first.",
                    "type": "invalid_request_error",
                    "code": "model_already_loaded"
                }
            },
            status_code=400
        )

    try:
        # basically the same logic as the one in aphrodite.endpoints.cli
        config_text = await config_file.read()
        config: Dict[Any, Any] = yaml.safe_load(config_text)
        args = []
        for key, value in config.items():
            key = key.replace('_', '-')
            if isinstance(value, bool):
                if value:
                    args.append(f"--{key}")
            elif isinstance(value, (list, tuple)):
                if key in ['lora-modules', 'prompt-adapters']:
                    for item in value:
                        args.append(f"--{key}")
                        args.append(f"{item['name']}={item['path']}")
                else:
                    for item in value:
                        args.append(f"--{key}")
                        args.append(str(item))
            else:
                args.append(f"--{key}")
                args.append(str(value))
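
        # Illustrative YAML (hypothetical values) and the argv it becomes:
        #   model: org/some-model
        #   tensor_parallel_size: 2
        #   enable_lora: true
        # -> ["--model", "org/some-model",
        #     "--tensor-parallel-size", "2", "--enable-lora"]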

        parser = FlexibleArgumentParser()
        parser = make_arg_parser(parser)
        parsed_args = parser.parse_args(args)
        # Rebuild the global engine args from the parsed config before
        # constructing the engine; otherwise the stale args of the previously
        # loaded model would be reused.
        engine_args = AsyncEngineArgs.from_cli_args(parsed_args)

        if (model_is_embedding(parsed_args.model,
                               parsed_args.trust_remote_code)
                or parsed_args.disable_frontend_multiprocessing):
            async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
            await async_engine_client.setup()
        else:
            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
                global prometheus_multiproc_dir
                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
                os.environ[
                    "PROMETHEUS_MULTIPROC_DIR"] = \
                    prometheus_multiproc_dir.name

            rpc_path = get_open_zmq_ipc_path()
            logger.info(
                f"Multiprocessing frontend to use {rpc_path} for RPC path.")
            rpc_client = AsyncEngineRPCClient(rpc_path)
            async_engine_client = rpc_client

            context = multiprocessing.get_context("spawn")
            rpc_server_process = context.Process(
                target=run_rpc_server,
                args=(engine_args, rpc_path))
            rpc_server_process.start()
            logger.info(
                f"Started engine process with PID {rpc_server_process.pid}")

            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError as e:
                    if not rpc_server_process.is_alive():
                        raise RuntimeError(
                            "RPC Server died before responding to readiness "
                            "probe") from e

        app = await init_app(async_engine_client, parsed_args)  # noqa: F841
        model_is_loaded = True
        return JSONResponse(
            content={
                "status": "success",
                "message": "Model loaded successfully"
            }
        )
    except Exception as e:
        error_msg = f"Error while loading model: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={
                "error": {
                    "message": error_msg,
                    "type": "invalid_request_error",
                    "code": "model_load_error"
                }
            },
            status_code=500
        )


@router.get("/health")
async def health() -> Response:
    """Health check."""
    await async_engine_client.check_health()
    return Response(status_code=200)


@router.post("/v1/tokenize")
async def tokenize(request: TokenizeRequest):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_tokenization.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, TokenizeResponse)
        return JSONResponse(content=generator.model_dump())
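
# Illustrative exchange (request fields are those of TokenizeRequest; the
# response body is the serialized TokenizeResponse, including the token list):
#   POST /v1/tokenize  {"model": "...", "prompt": "Hello world"}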


@router.post("/v1/detokenize")
async def detokenize(request: DetokenizeRequest):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_tokenization.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, DetokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.get("/v1/models")
async def show_available_models():
    models = await openai_serving_completion.show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": APHRODITE_VERSION}
    return JSONResponse(content=ver)


@router.get("/.well-known/serviceinfo")
async def serviceinfo():
    """Return service information including version, API endpoints,
    and documentation URLs."""
    return JSONResponse(content={
        "version": 0.2,
        "software": {
            "name": "Aphrodite Engine",
            "version": APHRODITE_VERSION,
            "repository": "https://github.com/PygmalionAI/aphrodite-engine",
            "homepage": "https://aphrodite.pygmalion.chat",
            "logo": "https://pygmalion.chat/icons/favicon.ico",
        },
        "api": {
            "openai": {
                "name": "OpenAI API",
                "rel_url": "/v1",
                "documentation": "/redoc",
                "version": 1,
            },
            "koboldai": {
                "name": "KoboldAI API",
                "rel_url": "/api",
                "documentation": "/redoc",
                "version": 1,
            }
        }
    })


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    error_check_ret = await _maybe_switch_model(
        request.model, raw_request.app.state, raw_request)
    if error_check_ret is not None:
        return error_check_ret

    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    error_check_ret = await _maybe_switch_model(
        request.model, raw_request.app.state, raw_request)
    if error_check_ret is not None:
        return error_check_ret

    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_embedding.create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/lora/load")
async def load_lora(lora: LoRAModulePath):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.add_lora(lora)
    if engine_args.enable_lora is False:
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/lora/unload")
async def unload_lora(lora_name: str):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.remove_lora(lora_name)
    return JSONResponse(content={"status": "success"})


@router.post("/v1/soft_prompt/load")
async def load_soft_prompt(soft_prompt: PromptAdapterPath):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.add_prompt_adapter(soft_prompt)
    if engine_args.enable_prompt_adapter is False:
        logger.error("Prompt Adapter is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-prompt-adapter flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/soft_prompt/unload")
async def unload_soft_prompt(soft_prompt_name: str):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
    return JSONResponse(content={"status": "success"})


# ============ KoboldAI API ============ #

badwordsids: List[int] = []


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)
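
# Example of the fallback above: with a BERT-style vocab and no
# hf_config.bad_words_ids, special tokens such as "[PAD]" or "[CLS]" contain
# "[" or "]" and are banned; the pad token is then removed from the ban list
# and the EOS token id is appended.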


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
        xtc_probability=kai_payload.xtc_probability,
        xtc_threshold=kai_payload.xtc_threshold,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
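
# Worked example of the truncation above: with max_context_length=4096 and
# max_length=512, max_input_tokens is 3584, so only the last 3584 prompt
# tokens are kept, leaving room for the requested completion.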


@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    final_res: Optional[RequestOutput] = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")
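
# On the wire, each chunk yielded above is a server-sent event, e.g.:
#   event: message
#   data: {"token": " Hello"}
# followed by a blank line separating events.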


@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await async_engine_client.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: TokenizeRequest):
    """Tokenize string and return token count"""
    generator = await openai_serving_tokenization.create_tokenize(request)
    return JSONResponse({"value": generator.model_dump()["tokens"]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    # Use the args stored by init_app; a bare `args` global only exists when
    # this module is run as a script.
    max_length = api_server_args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})


@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed. Can be disabled with SERVE_KOBOLD_LITE_UI=0."""
    if not SERVE_KOBOLD_LITE_UI:
        return JSONResponse(content={"error": "Kobold Lite UI is disabled"},
                            status_code=404)
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r", encoding="utf-8") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error(f"Kobold Lite UI not found at {klitepath}")
    return HTMLResponse(content=kobold_lite_ui)

# ============ KoboldAI API ============ #


def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    app.state.args = args
    if args.launch_kobold_api:
        logger.warning("Kobold API is now enabled by default. "
                       "This flag will be removed in the future.")
    app.include_router(kai_api, prefix="/api/v1")
    app.include_router(kai_api,
                       prefix="/api/latest",
                       include_in_schema=False)
    app.include_router(extra_api, prefix="/api/extra")

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_completion.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.APHRODITE_API_KEY or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key

        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith(("/v1", "/api")):
                return await call_next(request)

            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith(
                    ("/v1/lora", "/v1/soft_prompt", "/v1/model")):
                if admin_key is not None and (
                        api_key_header == admin_key or
                        auth_header == "Bearer " + admin_key
                ):
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if (auth_header == f"Bearer {token}" or api_key_header == token or
                    (admin_key is not None and
                     (api_key_header == admin_key or
                      auth_header == f"Bearer {admin_key}"))):
                return await call_next(request)

            return JSONResponse(
                content={"error": "Unauthorized"}, status_code=401)
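
        # Clients then authenticate with either header (illustrative values):
        #   Authorization: Bearer <api-key>
        #   x-api-key: <api-key>
        # The admin routes (/v1/lora, /v1/soft_prompt, /v1/model) accept only
        # the admin key, in the same two headers.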

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")
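
    # Each --middleware entry is an import path to either an ASGI middleware
    # class or an async function, e.g. (hypothetical module):
    #   --middleware mypackage.middlewares.RequestTimer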

    return app


async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    global api_server_args
    api_server_args = args

    app = build_app(args)

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    model_config = await async_engine_client.get_model_config()

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path

    tokenizer = get_tokenizer(
        tokenizer_name=engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
        revision=engine_args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)

    return app


async def run_server(args, **uvicorn_kwargs) -> None:
    async with build_async_engine_client(args) as async_engine_client:
        # If None, creation of the client failed and we exit.
        if async_engine_client is None:
            return

        app = await init_app(async_engine_client, args)

        protocol = "https" if args.ssl_certfile else "http"
        root_path = args.root_path.rstrip("/") if args.root_path else ""
        host_name = args.host if args.host else "localhost"
        port_str = str(args.port)

        if SERVE_KOBOLD_LITE_UI:
            ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/"
            logger.info(f"Kobold Lite UI: {ui_url}")

        logger.info(f"Documentation: {protocol}://{host_name}:{port_str}{root_path}/redoc")  # noqa: E501
        logger.info(f"Completions API: {protocol}://{host_name}:{port_str}{root_path}/v1/completions")  # noqa: E501
        logger.info(f"Chat API: {protocol}://{host_name}:{port_str}{root_path}/v1/chat/completions")  # noqa: E501
        logger.info(f"Embeddings API: {protocol}://{host_name}:{port_str}{root_path}/v1/embeddings")  # noqa: E501
        logger.info(f"Tokenization API: {protocol}://{host_name}:{port_str}{root_path}/v1/tokenize")  # noqa: E501

        shutdown_task = await serve_http(
            app,
            engine=async_engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task


if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    asyncio.run(run_server(args))
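
# Example launch (model name is a placeholder; the full flag set comes from
# make_arg_parser):
#   python -m aphrodite.endpoints.openai.api_server \
#       --model org/some-model --port 2242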