api_server.py

import asyncio
import copy
import importlib
import inspect
import json
import multiprocessing
import os
import re
import signal
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from distutils.util import strtobool
from functools import partial
from http import HTTPStatus
from typing import (Any, AsyncGenerator, AsyncIterator, Dict, List, Optional,
                    Set, Tuple)

import yaml
from fastapi import APIRouter, FastAPI, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from starlette.datastructures import State
from starlette.routing import Mount

import aphrodite.common.envs as envs
from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    get_open_zmq_ipc_path, in_windows,
                                    random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       PromptAdapterPath)
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.modeling.model_loader.weight_utils import get_model_config_yaml
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION

if in_windows():
    import winloop as uvloop
else:
    import uvloop

TIMEOUT_KEEP_ALIVE = 5  # seconds

SERVE_KOBOLD_LITE_UI = strtobool(os.getenv("SERVE_KOBOLD_LITE_UI", "1"))

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
prometheus_multiproc_dir: tempfile.TemporaryDirectory

model_is_loaded = True
_running_tasks: Set[asyncio.Task] = set()


def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: Optional[str]) -> bool:
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=trust_remote_code,
                       quantization=quantization,
                       seed=0,
                       dtype="auto").embedding_mode
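
# NOTE: `model_is_embedding` builds a throwaway ModelConfig just to read its
# `embedding_mode` attribute; the seed/dtype values passed above are dummies
# and are never used for serving.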


@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        if app.state.log_stats:
            async_engine_client = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(10)
                    await async_engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.remove)
        else:
            task = None
        try:
            yield
        finally:
            if task is not None:
                task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state


@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    engine_args = AsyncEngineArgs.from_cli_args(args)

    async with build_async_engine_client_from_engine_args(
            engine_args, args.disable_frontend_multiprocessing) as engine:
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Create an AsyncEngineClient, either:
        - in-process, using the AsyncAphrodite directly
        - multiprocess, using the AsyncAphrodite RPC

    Returns the client, or None if creation failed.
    """

    # If manually triggered or embedding model, use AsyncAphrodite in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
                           engine_args.quantization)
            or disable_frontend_multiprocessing):
        engine_config = engine_args.create_engine_config()
        uses_ray = getattr(AsyncAphrodite._get_executor_cls(engine_config),
                           "uses_ray", False)
        build_engine = partial(AsyncAphrodite.from_engine_args,
                               engine_args=engine_args,
                               engine_config=engine_config)
        if uses_ray:
            # Must run in main thread with ray for its signal handlers to work
            engine_client = build_engine()
        else:
            engine_client = await asyncio.get_running_loop().run_in_executor(
                None, build_engine)

        yield engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between Aphrodite runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and Aphrodite will properly handle cleanup.")

        # Select random path for IPC.
        rpc_path = get_open_zmq_ipc_path()
        logger.info(
            f"Multiprocessing frontend to use {rpc_path} for RPC Path.")

        # Build RPCClient, which conforms to the AsyncEngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above).
        rpc_client = AsyncEngineRPCClient(rpc_path)

        # Start RPCServer in a separate process (holds the AsyncAphrodite).
        context = multiprocessing.get_context("spawn")
        # The current process might have CUDA context,
        # so we need to spawn a new process.
        rpc_server_process = context.Process(
            target=run_rpc_server,
            args=(engine_args, rpc_path))
        rpc_server_process.start()
        logger.info(
            f"Started engine process with PID {rpc_server_process.pid}")

        try:
            while True:
                try:
                    await rpc_client.setup()
                    break
                except TimeoutError:
                    if not rpc_server_process.is_alive():
                        logger.error(
                            "RPCServer process died before responding "
                            "to readiness probe")
                        yield None
                        return

            yield rpc_client  # type: ignore[misc]
        finally:
            # Ensure rpc server process was terminated
            rpc_server_process.terminate()

            # Close all open connections to the backend
            rpc_client.close()

            # Wait for server process to join
            rpc_server_process.join()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(rpc_server_process.pid)
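
# Usage sketch for the context managers above (illustrative only, not
# executed here; `args`, `prompt`, `sampling_params`, and `request_id` are
# placeholders):
#
#     async with build_async_engine_client(args) as engine:
#         if engine is None:
#             return  # engine process failed its readiness probe
#         async for output in engine.generate(prompt, sampling_params,
#                                             request_id):
#             ...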


async def _maybe_switch_model(
        request_model: str, app_state,
        raw_request: Request) -> Optional[ErrorResponse]:
    """Switch to requested model if different from currently loaded one."""
    global model_is_loaded, async_engine_client, engine_args, \
        served_model_names

    if not model_is_loaded:
        return None

    models = await completion(raw_request).show_available_models()
    for model in models.data:
        if request_model in (model.id, model.root):
            return None

    if not app_state.args.allow_inline_model_loading:
        return JSONResponse(
            content={
                "error": {
                    "message": "Requested model is not currently loaded. "
                    "Inline model loading is disabled. Enable it with "
                    "--allow-inline-model-loading.",
                    "type": "invalid_request_error",
                    "code": "model_not_loaded"
                }
            },
            status_code=400
        )  # type: ignore

    # Authentication checks
    api_key = envs.APHRODITE_API_KEY or app_state.args.api_keys
    admin_key = envs.APHRODITE_ADMIN_KEY or app_state.args.admin_key

    if api_key:
        api_key_header = raw_request.headers.get("x-api-key")
        auth_header = raw_request.headers.get("Authorization")

        if not admin_key:
            return JSONResponse(
                content={
                    "error": {
                        "message": "Admin key not configured. "
                        "Inline model loading is disabled.",
                        "type": "invalid_request_error",
                        "code": "admin_key_required"
                    }
                },
                status_code=401
            )  # type: ignore

        if not (api_key_header == admin_key or
                auth_header == f"Bearer {admin_key}"):
            return JSONResponse(
                content={
                    "error": {
                        "message": "Admin privileges required for inline "
                        "model loading.",
                        "type": "invalid_request_error",
                        "code": "unauthorized"
                    }
                },
                status_code=401
            )  # type: ignore

    logger.info(f"Switching from {served_model_names[0]} to {request_model}")
    try:
        args = app_state.args

        # First shut down the current engine
        current_client = engine_client(raw_request)
        if not args.disable_frontend_multiprocessing:
            await current_client.kill()
        else:
            await current_client.shutdown_background_loop()
        model_is_loaded = False

        yaml_config = get_model_config_yaml(request_model, args.download_dir)
        if yaml_config:
            parser = FlexibleArgumentParser()
            parser = make_arg_parser(parser)
            engine_args = parser.parse_args([])  # empty args
            for key, value in yaml_config.items():
                if hasattr(engine_args, key):
                    setattr(engine_args, key, value)
            engine_args.model = request_model
            engine_args = AsyncEngineArgs.from_cli_args(engine_args)
        else:
            # Fallback to minimal config
            engine_args = AsyncEngineArgs(model=request_model)

        # Create new engine client without context manager
        if (model_is_embedding(engine_args.model,
                               engine_args.trust_remote_code,
                               engine_args.quantization)
                or args.disable_frontend_multiprocessing):
            new_engine_client = AsyncAphrodite.from_engine_args(engine_args)
            await new_engine_client.setup()
        else:
            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
                global prometheus_multiproc_dir
                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
                os.environ[
                    "PROMETHEUS_MULTIPROC_DIR"] = \
                    prometheus_multiproc_dir.name

            rpc_path = get_open_zmq_ipc_path()
            logger.info(
                f"Multiprocessing frontend to use {rpc_path} for RPC Path.")
            rpc_client = AsyncEngineRPCClient(rpc_path)

            context = multiprocessing.get_context("spawn")
            rpc_server_process = context.Process(
                target=run_rpc_server,
                args=(engine_args, rpc_path))
            rpc_server_process.start()
            logger.info(
                f"Started engine process with PID {rpc_server_process.pid}")

            try:
                while True:
                    try:
                        await rpc_client.setup()
                        break
                    except TimeoutError as e:
                        if not rpc_server_process.is_alive():
                            raise RuntimeError(
                                "RPC Server died before responding to "
                                "readiness probe") from e

                new_engine_client = rpc_client
                model_config = await new_engine_client.get_model_config()
                new_args = copy.deepcopy(args)
                new_args.model = request_model
                init_app_state(
                    new_engine_client, model_config,
                    raw_request.app.state, new_args)
                served_model_names = [request_model]
                model_is_loaded = True
                return None
            except Exception as e:
                # Clean up RPC resources on error
                rpc_server_process.terminate()
                rpc_client.close()
                rpc_server_process.join()
                raise e
    except Exception as e:
        error_msg = f"Error while switching models: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={
                "error": {
                    "message": error_msg,
                    "type": "invalid_request_error",
                    "code": "model_load_error"
                }
            },
            status_code=500
        )  # type: ignore
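
# Summary of the inline-switch flow above: the current engine is shut down,
# new AsyncEngineArgs are built (from a per-model YAML found via
# `get_model_config_yaml` when available, otherwise a minimal config), a new
# engine client is brought up, and `init_app_state` rebinds the serving
# objects on `app.state`. When an API key is configured, admin credentials
# are required.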


def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
                    "as PROMETHEUS_MULTIPROC_DIR")
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
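
# The mount serves the Prometheus text format at GET /metrics, e.g.
# (host/port are hypothetical):
#
#     curl http://localhost:2242/metrics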


def chat(request: Request) -> OpenAIServingChat:
    return request.app.state.openai_serving_chat


def completion(request: Request) -> OpenAIServingCompletion:
    return request.app.state.openai_serving_completion


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization


def embedding(request: Request) -> OpenAIServingEmbedding:
    return request.app.state.openai_serving_embedding


def engine_client(request: Request) -> AsyncEngineClient:
    return request.app.state.engine_client


@router.delete("/v1/model/unload")
async def unload_model(raw_request: Request):
    """Unload the current model and shut down the server."""
    logger.info("Received request to unload model.")

    try:
        args = raw_request.app.state.args
        if not args.disable_frontend_multiprocessing:
            await engine_client(raw_request).kill()
        else:
            await engine_client(raw_request).shutdown_background_loop()

        global model_is_loaded
        model_is_loaded = False
        return JSONResponse(
            content={
                "status": "success",
                "message": "Model unloaded successfully"
            }
        )
    except Exception as e:
        error_msg = f"Error while unloading model: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={"status": "error", "message": error_msg},
            status_code=500
        )


@router.post("/v1/model/load")
async def load_model(config_file: UploadFile, raw_request: Request):
    """Load a model using a YAML configuration file."""
    global model_is_loaded, async_engine_client, engine_args
    if model_is_loaded:
        return JSONResponse(
            content={
                "error": {
                    "message": "A model is already loaded. "
                    "Please unload it first.",
                    "type": "invalid_request_error",
                    "code": "model_already_loaded"
                }
            },
            status_code=400
        )

    try:
        # Translate the YAML mapping into an equivalent CLI argument list.
        config_text = await config_file.read()
        config: Dict[Any, Any] = yaml.safe_load(config_text)

        args = []
        for key, value in config.items():
            key = key.replace('_', '-')
            if isinstance(value, bool):
                if value:
                    args.append(f"--{key}")
            elif isinstance(value, (list, tuple)):
                if key in ['lora-modules', 'prompt-adapters']:
                    for item in value:
                        args.append(f"--{key}")
                        args.append(f"{item['name']}={item['path']}")
                else:
                    for item in value:
                        args.append(f"--{key}")
                        args.append(str(item))
            else:
                args.append(f"--{key}")
                args.append(str(value))

        parser = FlexibleArgumentParser()
        parser = make_arg_parser(parser)
        parsed_args = parser.parse_args(args)
        engine_args = AsyncEngineArgs.from_cli_args(parsed_args)

        # Create new engine client without context manager
        if (model_is_embedding(engine_args.model,
                               engine_args.trust_remote_code,
                               engine_args.quantization)
                or parsed_args.disable_frontend_multiprocessing):
            new_engine_client = AsyncAphrodite.from_engine_args(engine_args)
            await new_engine_client.setup()
        else:
            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
                global prometheus_multiproc_dir
                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
                os.environ[
                    "PROMETHEUS_MULTIPROC_DIR"] = \
                    prometheus_multiproc_dir.name

            rpc_path = get_open_zmq_ipc_path()
            logger.info(
                f"Multiprocessing frontend to use {rpc_path} for RPC Path.")
            rpc_client = AsyncEngineRPCClient(rpc_path)
            new_engine_client = rpc_client

            context = multiprocessing.get_context("spawn")
            rpc_server_process = context.Process(
                target=run_rpc_server,
                args=(engine_args, rpc_path))
            rpc_server_process.start()
            logger.info(
                f"Started engine process with PID {rpc_server_process.pid}")

            while True:
                try:
                    await new_engine_client.setup()
                    break
                except TimeoutError as e:
                    if not rpc_server_process.is_alive():
                        raise RuntimeError(
                            "RPC Server died before responding to readiness "
                            "probe") from e

        # Rebind app state to the newly created engine client; the previous
        # client was shut down by /v1/model/unload.
        model_config = await new_engine_client.get_model_config()
        init_app_state(new_engine_client, model_config,
                       raw_request.app.state, parsed_args)
        model_is_loaded = True
        return JSONResponse(
            content={
                "status": "success",
                "message": "Model loaded successfully"
            }
        )
    except Exception as e:
        error_msg = f"Error while loading model: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={
                "error": {
                    "message": error_msg,
                    "type": "invalid_request_error",
                    "code": "model_load_error"
                }
            },
            status_code=500
        )
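
# Example YAML accepted by POST /v1/model/load (values are illustrative; keys
# map to CLI flags, with underscores turned into dashes and booleans becoming
# bare flags):
#
#     model: mistralai/Mistral-7B-Instruct-v0.2
#     max_model_len: 8192
#     enforce_eager: true
#
# Uploaded as multipart form data under the `config_file` field, e.g.
# (host/port are hypothetical):
#
#     curl -X POST -F "config_file=@config.yaml" \
#          http://localhost:2242/v1/model/load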


@router.get("/health")
async def health(raw_request: Request) -> Response:
    """Health check."""
    await engine_client(raw_request).check_health()
    return Response(status_code=200)


@router.post("/v1/tokenize")
async def tokenize(request: TokenizeRequest, raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await tokenization(raw_request).create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, TokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/detokenize")
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await tokenization(raw_request).create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, DetokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    models = await completion(raw_request).show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": APHRODITE_VERSION}
    return JSONResponse(content=ver)


@router.get("/.well-known/serviceinfo")
async def serviceinfo():
    """Return service information including version, API endpoints,
    and documentation URLs."""
    return JSONResponse(content={
        "version": 0.2,
        "software": {
            "name": "Aphrodite Engine",
            "version": APHRODITE_VERSION,
            "repository": "https://github.com/PygmalionAI/aphrodite-engine",
            "homepage": "https://aphrodite.pygmalion.chat",
            "logo": "https://pygmalion.chat/icons/favicon.ico",
        },
        "api": {
            "openai": {
                "name": "OpenAI API",
                "rel_url": "/v1",
                "documentation": "/redoc",
                "version": 1,
            },
            "koboldai": {
                "name": "KoboldAI API",
                "rel_url": "/api",
                "documentation": "/redoc",
                "version": 1,
            }
        }
    })


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    error_check_ret = await _maybe_switch_model(
        request.model, raw_request.app.state, raw_request)
    if error_check_ret is not None:
        return error_check_ret

    generator = await chat(raw_request).create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    error_check_ret = await _maybe_switch_model(
        request.model, raw_request.app.state, raw_request)
    if error_check_ret is not None:
        return error_check_ret

    generator = await completion(raw_request).create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await embedding(raw_request).create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/lora/load")
async def load_lora(lora: LoRAModulePath, raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    completion(raw_request).add_lora(lora)
    if engine_args.enable_lora is False:
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/lora/unload")
async def unload_lora(lora_name: str, raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    completion(raw_request).remove_lora(lora_name)
    return JSONResponse(content={"status": "success"})


@router.post("/v1/soft_prompt/load")
async def load_soft_prompt(soft_prompt: PromptAdapterPath,
                           raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    completion(raw_request).add_prompt_adapter(soft_prompt)
    if engine_args.enable_prompt_adapter is False:
        logger.error("Prompt Adapter is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-prompt-adapter flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/soft_prompt/unload")
async def unload_soft_prompt(soft_prompt_name: str, raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    completion(raw_request).remove_prompt_adapter(soft_prompt_name)
    return JSONResponse(content={"status": "success"})


# ============ KoboldAI API ============ #

badwordsids: List[int] = []


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
        xtc_probability=kai_payload.xtc_probability,
        xtc_threshold=kai_payload.xtc_threshold,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
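
# Worked example of the truncation arithmetic above: with
# max_context_length=4096 and max_length=512, max_input_tokens is 3584, so
# only the last 3584 prompt tokens are kept, leaving room for 512 generated
# tokens.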


@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema,
                   raw_request: Request) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine_client(raw_request).generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    final_res: Optional[RequestOutput] = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema,
        raw_request: Request) -> StreamingResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine_client(raw_request).generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")
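
# Each SSE frame emitted by `stream_kobold` looks like (token text varies):
#
#     event: message
#     data: {"token": " Hello"}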


@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(raw_request: Request):
    try:
        request_dict = await raw_request.json()
        if "genkey" in request_dict:
            await engine_client(raw_request).abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: TokenizeRequest, raw_request: Request):
    """Tokenize string and return token count"""
    generator = await tokenization(raw_request).create_tokenize(request)
    return JSONResponse({"value": generator.model_dump()["tokens"]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    # `api_server_args` is the global set in init_app_state(); a bare `args`
    # global only exists when this module is run as a script.
    max_length = api_server_args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})


@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed. Can be disabled with SERVE_KOBOLD_LITE_UI=0."""
    if not SERVE_KOBOLD_LITE_UI:
        return JSONResponse(content={"error": "Kobold Lite UI is disabled"},
                            status_code=404)
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r", encoding="utf-8") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)


# ============ KoboldAI API ============ #


def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    app.state.args = args
    if args.launch_kobold_api:
        logger.warning("Kobold API is now enabled by default. "
                       "This flag will be removed in the future.")
    app.include_router(kai_api, prefix="/api/v1")
    app.include_router(kai_api,
                       prefix="/api/latest",
                       include_in_schema=False)
    app.include_router(extra_api, prefix="/api/extra")

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        chat = app.state.openai_serving_chat
        err = chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.APHRODITE_API_KEY or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key

        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith(("/v1", "/api")):
                return await call_next(request)

            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith(
                    ("/v1/lora", "/v1/soft_prompt", "/v1/model")):
                if admin_key is not None and (
                        api_key_header == admin_key or
                        auth_header == "Bearer " + admin_key):
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if (auth_header == f"Bearer {token}" or api_key_header == token or
                    (admin_key is not None and
                     (api_key_header == admin_key or
                      auth_header == f"Bearer {admin_key}"))):
                return await call_next(request)

            return JSONResponse(
                content={"error": "Unauthorized"}, status_code=401)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             "Must be a function or a class.")

    return app
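
# When API keys are configured, clients authenticate with either header
# (values and host/port are placeholders):
#
#     curl -H "Authorization: Bearer $API_KEY" http://localhost:2242/v1/models
#     curl -H "x-api-key: $API_KEY" http://localhost:2242/v1/models
#
# Admin-only routes (/v1/lora, /v1/soft_prompt, /v1/model) accept only the
# admin key.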


def init_app_state(
    async_engine_client: AsyncEngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
) -> None:
    global api_server_args
    api_server_args = args

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    state.engine_client = async_engine_client
    state.log_stats = not args.disable_log_stats

    state.openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser
    )
    state.openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    state.openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    tokenizer = get_tokenizer(
        tokenizer_name=args.tokenizer if args.tokenizer else args.model,
        tokenizer_mode=args.tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)


async def run_server(args, **uvicorn_kwargs) -> None:

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_async_engine_client(args) as async_engine_client:
        # If None, creation of the client failed and we exit.
        if async_engine_client is None:
            return

        app = build_app(args)

        model_config = await async_engine_client.get_model_config()
        init_app_state(async_engine_client, model_config, app.state, args)

        protocol = "https" if args.ssl_certfile else "http"
        root_path = args.root_path.rstrip("/") if args.root_path else ""
        host_name = args.host if args.host else "localhost"
        port_str = str(args.port)

        if SERVE_KOBOLD_LITE_UI:
            ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/"
            logger.info(f"Kobold Lite UI: {ui_url}")

        logger.info(f"Documentation: {protocol}://{host_name}:{port_str}{root_path}/redoc")  # noqa: E501
        logger.info(f"Completions API: {protocol}://{host_name}:{port_str}{root_path}/v1/completions")  # noqa: E501
        logger.info(f"Chat API: {protocol}://{host_name}:{port_str}{root_path}/v1/chat/completions")  # noqa: E501
        logger.info(f"Embeddings API: {protocol}://{host_name}:{port_str}{root_path}/v1/embeddings")  # noqa: E501
        logger.info(f"Tokenization API: {protocol}://{host_name}:{port_str}{root_path}/v1/tokenize")  # noqa: E501

        shutdown_task = await serve_http(
            app,
            limit_concurrency=async_engine_client.limit_concurrency,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task


if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    uvloop.run(run_server(args))
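
# Example invocation (flags are illustrative; `make_arg_parser` defines the
# authoritative set):
#
#     python -m aphrodite.endpoints.openai.api_server \
#         --model mistralai/Mistral-7B-Instruct-v0.2 --port 2242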