# api_server.py

import asyncio
import importlib
import inspect
import json
import multiprocessing
import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from distutils.util import strtobool
from http import HTTPStatus
from typing import (Any, AsyncGenerator, AsyncIterator, Dict, List, Optional,
                    Set, Tuple)

import yaml
from fastapi import APIRouter, FastAPI, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from starlette.routing import Mount

import aphrodite.common.envs as envs
from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    get_open_zmq_ipc_path, in_windows,
                                    random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       PromptAdapterPath)
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION

if in_windows():
    import winloop as uvloop
else:
    import uvloop

TIMEOUT_KEEP_ALIVE = 5  # seconds

SERVE_KOBOLD_LITE_UI = strtobool(os.getenv("SERVE_KOBOLD_LITE_UI", "1"))

async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()

kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
prometheus_multiproc_dir: tempfile.TemporaryDirectory

model_is_loaded = True

_running_tasks: Set[asyncio.Task] = set()


def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=trust_remote_code,
                       seed=0,
                       dtype="auto").embedding_mode


@asynccontextmanager
async def lifespan(app: FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    yield


@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Create AsyncEngineClient, either:
        - in-process using the AsyncAphrodite directly
        - multiprocess using AsyncAphrodite RPC

    Returns the Client or None if the creation failed.
    """

    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncAphrodite in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between Aphrodite runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and Aphrodite will properly handle cleanup.")

        # Select random path for IPC.
        rpc_path = get_open_zmq_ipc_path()
        logger.info(
            f"Multiprocessing frontend to use {rpc_path} for RPC Path.")

        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client  # type: ignore

        # Start RPCServer in separate process (holds the AsyncAphrodite).
        context = multiprocessing.get_context("spawn")
        # the current process might have CUDA context,
        # so we need to spawn a new process
        rpc_server_process = context.Process(
            target=run_rpc_server,
            args=(engine_args, rpc_path))
        rpc_server_process.start()
        logger.info(
            f"Started engine process with PID {rpc_server_process.pid}")

        try:
            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError:
                    if not rpc_server_process.is_alive():
                        logger.error(
                            "RPCServer process died before responding "
                            "to readiness probe")
                        yield None
                        return

            yield async_engine_client
        finally:
            # Ensure rpc server process was terminated
            rpc_server_process.terminate()

            # Close all open connections to the backend
            async_engine_client.close()

            # Wait for server process to join
            rpc_server_process.join()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(rpc_server_process.pid)


def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
                    "as PROMETHEUS_MULTIPROC_DIR")
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)

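# Illustrative scrape of the mounted route above (a sketch; <host> and <port>
# are placeholders, not values defined in this file):
#
#   curl http://<host>:<port>/metrics
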
@router.delete("/v1/model/unload")
async def unload_model(request: Request):
    """Unload the current model and shut down the server."""
    logger.info("Received request to unload model.")

    try:
        args = request.app.state.args
        if not args.disable_frontend_multiprocessing:
            await async_engine_client.kill()
        else:
            await async_engine_client.shutdown_background_loop()

        global model_is_loaded
        model_is_loaded = False
        return JSONResponse(
            content={
                "status": "success",
                "message": "Model unloaded successfully"
            }
        )
    except Exception as e:
        error_msg = f"Error while unloading model: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={"status": "error", "message": error_msg},
            status_code=500
        )

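# Illustrative request for the endpoint above (a sketch; <host> and <port> are
# placeholders). When API keys are enabled, the auth middleware in build_app
# requires the admin key for /v1/model routes:
#
#   curl -X DELETE http://<host>:<port>/v1/model/unload \
#        -H "x-api-key: $APHRODITE_ADMIN_KEY"
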
@router.post("/v1/model/load")
async def load_model(config_file: UploadFile):
    """Load a model using a YAML configuration file."""
    global model_is_loaded, async_engine_client, engine_args

    if model_is_loaded:
        return JSONResponse(
            content={
                "error": {
                    "message": "A model is already loaded. "
                    "Please unload it first.",
                    "type": "invalid_request_error",
                    "code": "model_already_loaded"
                }
            },
            status_code=400
        )

    try:
        # basically the same logic as the one in aphrodite.endpoints.cli
        config_text = await config_file.read()
        config: Dict[Any, Any] = yaml.safe_load(config_text)

        args = []
        for key, value in config.items():
            key = key.replace('_', '-')
            if isinstance(value, bool):
                if value:
                    args.append(f"--{key}")
            elif isinstance(value, (list, tuple)):
                if key in ['lora-modules', 'prompt-adapters']:
                    for item in value:
                        args.append(f"--{key}")
                        args.append(f"{item['name']}={item['path']}")
                else:
                    for item in value:
                        args.append(f"--{key}")
                        args.append(str(item))
            else:
                args.append(f"--{key}")
                args.append(str(value))

        parser = FlexibleArgumentParser()
        parser = make_arg_parser(parser)
        parsed_args = parser.parse_args(args)
        # Rebuild the global engine args from the freshly parsed config so the
        # newly requested model (not the previously loaded one) is handed to
        # the engine below.
        engine_args = AsyncEngineArgs.from_cli_args(parsed_args)

        if (model_is_embedding(parsed_args.model,
                               parsed_args.trust_remote_code)
                or parsed_args.disable_frontend_multiprocessing):
            async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
            await async_engine_client.setup()
        else:
            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
                global prometheus_multiproc_dir
                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
                os.environ[
                    "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name

            rpc_path = get_open_zmq_ipc_path()
            logger.info(
                f"Multiprocessing frontend to use {rpc_path} for RPC Path.")

            rpc_client = AsyncEngineRPCClient(rpc_path)
            async_engine_client = rpc_client

            context = multiprocessing.get_context("spawn")
            rpc_server_process = context.Process(
                target=run_rpc_server,
                args=(engine_args, rpc_path))
            rpc_server_process.start()
            logger.info(
                f"Started engine process with PID {rpc_server_process.pid}")

            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError as e:
                    if not rpc_server_process.is_alive():
                        raise RuntimeError(
                            "RPC Server died before responding to readiness "
                            "probe") from e

        app = await init_app(async_engine_client, parsed_args)  # noqa: F841
        model_is_loaded = True
        return JSONResponse(
            content={
                "status": "success",
                "message": "Model loaded successfully"
            }
        )
    except Exception as e:
        error_msg = f"Error while loading model: {str(e)}"
        logger.error(error_msg)
        return JSONResponse(
            content={
                "error": {
                    "message": error_msg,
                    "type": "invalid_request_error",
                    "code": "model_load_error"
                }
            },
            status_code=500
        )

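# Illustrative request for the endpoint above (a sketch). YAML keys are turned
# into CLI flags ("_" becomes "-"), so any option accepted by make_arg_parser
# may appear; the model name and max_model_len value below are assumed
# examples, and <host>/<port> are placeholders:
#
#   # config.yaml
#   #   model: mistralai/Mistral-7B-Instruct-v0.2
#   #   max_model_len: 8192
#
#   curl -X POST http://<host>:<port>/v1/model/load \
#        -H "x-api-key: $APHRODITE_ADMIN_KEY" \
#        -F "config_file=@config.yaml"
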
@router.get("/health")
async def health() -> Response:
    """Health check."""
    await async_engine_client.check_health()
    return Response(status_code=200)


@router.post("/v1/tokenize")
async def tokenize(request: TokenizeRequest):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_tokenization.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, TokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/detokenize")
async def detokenize(request: DetokenizeRequest):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_tokenization.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, DetokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.get("/v1/models")
async def show_available_models():
    models = await openai_serving_completion.show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": APHRODITE_VERSION}
    return JSONResponse(content=ver)


@router.get("/.well-known/serviceinfo")
async def serviceinfo():
    """Return service information including version, API endpoints,
    and documentation URLs."""
    return JSONResponse(content={
        "version": 0.2,
        "software": {
            "name": "Aphrodite Engine",
            "version": APHRODITE_VERSION,
            "repository": "https://github.com/PygmalionAI/aphrodite-engine",
            "homepage": "https://aphrodite.pygmalion.chat",
            "logo": "https://pygmalion.chat/icons/favicon.ico",
        },
        "api": {
            "openai": {
                "name": "OpenAI API",
                "rel_url": "/v1",
                "documentation": "/redoc",
                "version": 1,
            },
            "koboldai": {
                "name": "KoboldAI API",
                "rel_url": "/api",
                "documentation": "/redoc",
                "version": 1,
            }
        }
    })


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    generator = await openai_serving_embedding.create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/lora/load")
async def load_lora(lora: LoRAModulePath):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.add_lora(lora)
    if engine_args.enable_lora is False:
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/lora/unload")
async def unload_lora(lora_name: str):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.remove_lora(lora_name)
    return JSONResponse(content={"status": "success"})


@router.post("/v1/soft_prompt/load")
async def load_soft_prompt(soft_prompt: PromptAdapterPath):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.add_prompt_adapter(soft_prompt)
    if engine_args.enable_prompt_adapter is False:
        logger.error("Prompt Adapter is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-prompt-adapter flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/soft_prompt/unload")
async def unload_soft_prompt(soft_prompt_name: str):
    if not model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
    return JSONResponse(content={"status": "success"})

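# Illustrative LoRA load request for /v1/lora/load (a sketch; <host>, <port>
# and the adapter values are placeholders). The JSON field names mirror the
# name/path pairs used for `lora-modules` in load_model above and are assumed
# to match LoRAModulePath:
#
#   curl -X POST http://<host>:<port>/v1/lora/load \
#        -H "x-api-key: $APHRODITE_ADMIN_KEY" \
#        -H "Content-Type: application/json" \
#        -d '{"name": "my-lora", "path": "/path/to/adapter"}'
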
# ============ KoboldAI API ============ #

badwordsids: List[int] = []


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
        xtc_probability=kai_payload.xtc_probability,
        xtc_threshold=kai_payload.xtc_threshold,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens

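# Worked example of the truncation above (the numbers are assumed, not
# defaults from this file): with max_context_length=4096 and max_length=512,
# max_input_tokens = 4096 - 512 = 3584, so only the last 3584 prompt tokens
# are kept and up to 512 new tokens may be generated.
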
@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    final_res: Optional[RequestOutput] = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:

    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")


@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await async_engine_client.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: TokenizeRequest):
    """Tokenize string and return token count"""
    generator = await openai_serving_tokenization.create_tokenize(request)
    return JSONResponse({"value": generator.model_dump()["tokens"]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    # Use the args captured in init_app so this also works when the server is
    # launched through run_server rather than the __main__ block.
    max_length = api_server_args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})


@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed. Can be disabled with SERVE_KOBOLD_LITE_UI=0."""
    if not SERVE_KOBOLD_LITE_UI:
        return JSONResponse(content={"error": "Kobold Lite UI is disabled"},
                            status_code=404)
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r", encoding="utf-8") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)


# ============ KoboldAI API ============ #


def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    app.state.args = args
    if args.launch_kobold_api:
        logger.warning("Kobold API is now enabled by default. "
                       "This flag will be removed in the future.")
    app.include_router(kai_api, prefix="/api/v1")
    app.include_router(kai_api,
                       prefix="/api/latest",
                       include_in_schema=False)
    app.include_router(extra_api, prefix="/api/extra")

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_completion.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.APHRODITE_API_KEY or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key

        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith(("/v1", "/api")):
                return await call_next(request)

            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith(
                    ("/v1/lora", "/v1/soft_prompt", "/v1/model")):
                if admin_key is not None and (
                    api_key_header == admin_key or
                    auth_header == "Bearer " + admin_key
                ):
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if auth_header != "Bearer " + token and api_key_header != token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app

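# Illustrative authenticated requests (a sketch; <host>, <port> and the key
# value are placeholders). The middleware above accepts either header form for
# regular /v1 and /api routes when API keys are enabled:
#
#   curl http://<host>:<port>/v1/models -H "Authorization: Bearer $API_KEY"
#   curl http://<host>:<port>/v1/models -H "x-api-key: $API_KEY"
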
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    global api_server_args
    api_server_args = args

    app = build_app(args)

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    model_config = await async_engine_client.get_model_config()

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path

    tokenizer = get_tokenizer(
        tokenizer_name=engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
        revision=engine_args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)

    return app


async def run_server(args, **uvicorn_kwargs) -> None:
    async with build_async_engine_client(args) as async_engine_client:
        # If None, creation of the client failed and we exit.
        if async_engine_client is None:
            return

        app = await init_app(async_engine_client, args)

        protocol = "https" if args.ssl_certfile else "http"
        root_path = args.root_path.rstrip("/") if args.root_path else ""
        host_name = args.host if args.host else "localhost"
        port_str = str(args.port)

        if SERVE_KOBOLD_LITE_UI:
            ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/"
            logger.info(f"Kobold Lite UI: {ui_url}")
        logger.info(f"Documentation: {protocol}://{host_name}:{port_str}{root_path}/redoc")  # noqa: E501
        logger.info(f"Completions API: {protocol}://{host_name}:{port_str}{root_path}/v1/completions")  # noqa: E501
        logger.info(f"Chat API: {protocol}://{host_name}:{port_str}{root_path}/v1/chat/completions")  # noqa: E501
        logger.info(f"Embeddings API: {protocol}://{host_name}:{port_str}{root_path}/v1/embeddings")  # noqa: E501
        logger.info(f"Tokenization API: {protocol}://{host_name}:{port_str}{root_path}/v1/tokenize")  # noqa: E501

        shutdown_task = await serve_http(
            app,
            engine=async_engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task


if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))

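# Illustrative launch command (a sketch; the module path is inferred from the
# imports above, and the model name and port value are placeholders):
#
#   python -m aphrodite.endpoints.openai.api_server \
#       --model mistralai/Mistral-7B-Instruct-v0.2 \
#       --port 2242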