api_server.py

import asyncio
import importlib
import inspect
import json
import multiprocessing
import os
import pickle
import re
import signal
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from distutils.util import strtobool
from functools import partial
from http import HTTPStatus
from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple

import yaml
from fastapi import APIRouter, FastAPI, Form, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from starlette.datastructures import State
from starlette.routing import Mount

import aphrodite.common.envs as envs
from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    get_open_zmq_ipc_path, in_windows,
                                    random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import (BaseModelPath,
                                                       LoRAModulePath,
                                                       PromptAdapterPath)
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.multiprocessing import (APHRODITE_RPC_SUCCESS_STR,
                                              RPCShutdownRequest)
from aphrodite.engine.multiprocessing.client import MQAphroditeEngineClient
from aphrodite.engine.multiprocessing.engine import run_mp_engine
from aphrodite.engine.protocol import EngineClient
from aphrodite.modeling.model_loader.weight_utils import get_model_config_yaml
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION

if in_windows():
    import winloop as uvloop
else:
    import uvloop

TIMEOUT_KEEP_ALIVE = 5  # seconds

SERVE_KOBOLD_LITE_UI = strtobool(os.getenv("SERVE_KOBOLD_LITE_UI", "1"))

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
prometheus_multiproc_dir: tempfile.TemporaryDirectory

_running_tasks: Set[asyncio.Task] = set()

@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(10.)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.remove)
        else:
            task = None
        try:
            yield
        finally:
            if task is not None:
                task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state

@asynccontextmanager
async def build_engine_client(
        args: Namespace) -> AsyncIterator[EngineClient]:

    # Context manager to handle engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    engine_args = AsyncEngineArgs.from_cli_args(args)

    async with build_engine_client_from_engine_args(
            engine_args, args.disable_frontend_multiprocessing) as engine:
        yield engine

@asynccontextmanager
async def build_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[EngineClient]:
    """
    Create EngineClient, either:
        - in-process, using AsyncAphrodite directly
        - multiprocess, using the AsyncAphrodite RPC client

    Yields the client; raises RuntimeError if the engine process fails
    to start.
    """

    # Fall back
    # TODO: fill out feature matrix.
    if (MQAphroditeEngineClient.is_unsupported_config(engine_args)
            or disable_frontend_multiprocessing):
        engine_config = engine_args.create_engine_config()
        uses_ray = getattr(AsyncAphrodite._get_executor_cls(engine_config),
                           "uses_ray", False)
        build_engine = partial(AsyncAphrodite.from_engine_args,
                               engine_args=engine_args,
                               engine_config=engine_config)
        if uses_ray:
            # Must run in main thread with ray for its signal handlers to work
            engine_client = build_engine()
        else:
            engine_client = await asyncio.get_running_loop().run_in_executor(
                None, build_engine)

        yield engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            #       cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between Aphrodite runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and Aphrodite will properly handle cleanup.")

        # Select random path for IPC.
        ipc_path = get_open_zmq_ipc_path()
        logger.info(
            f"Multiprocessing frontend to use {ipc_path} for IPC Path.")

        # Start RPCServer in separate process (holds the LLMEngine).
        # The current process might have CUDA context,
        # so we need to spawn a new process.
        context = multiprocessing.get_context("spawn")

        engine_process = context.Process(target=run_mp_engine,
                                         args=(engine_args, ipc_path))
        engine_process.start()
        logger.info(f"Started engine process with PID {engine_process.pid}")

        # Build RPCClient, which conforms to EngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        engine_config = engine_args.create_engine_config()
        mp_engine_client = MQAphroditeEngineClient(ipc_path, engine_config)

        try:
            while True:
                try:
                    await mp_engine_client.setup()
                    break
                except TimeoutError:
                    if not engine_process.is_alive():
                        raise RuntimeError(
                            "Engine process failed to start") from None

            yield mp_engine_client  # type: ignore[misc]
        finally:
            # Ensure rpc server process was terminated
            engine_process.terminate()

            # Close all open connections to the backend
            mp_engine_client.close()

            # Wait for engine process to join
            engine_process.join(4)
            if engine_process.exitcode is None:
                # Kill if it takes longer than 4 seconds to stop
                engine_process.kill()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(engine_process.pid)

def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
                    "as PROMETHEUS_MULTIPROC_DIR")
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
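
# A minimal scrape of the mounted Prometheus app (illustrative sketch; the
# host and port are assumptions, substitute your own --host/--port values):
#
#   curl http://localhost:2242/metrics
#
# The metric names exposed here come from the engine's stat loggers, not
# from this module.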

async def _handle_model_switch(
        raw_request: Request,
        requested_model: str) -> Optional[JSONResponse]:
    """Helper function to handle model switching if needed.

    Returns error response if something went wrong, None if successful."""
    if not raw_request.app.state.args.allow_inline_model_loading:
        return None

    if not raw_request.app.state.model_is_loaded:
        config = get_model_config_yaml(requested_model)
        request_data = {"model": requested_model}
        if config:
            config.pop("model", None)
            request_data.update(config)
        load_response = await load_model(
            raw_request, request=json.dumps(request_data))
        if load_response.status_code != 200:
            return load_response
        return None

    current_model = raw_request.app.state.current_model
    if current_model == requested_model:
        return None

    unload_response = await unload_model(raw_request)
    if unload_response.status_code != 200:
        return unload_response

    config = get_model_config_yaml(requested_model)
    request_data = {"model": requested_model}
    if config:
        config.pop("model", None)
        request_data.update(config)
    load_response = await load_model(
        raw_request, request=json.dumps(request_data))
    if load_response.status_code != 200:
        return load_response
    return None

def chat(request: Request) -> OpenAIServingChat:
    return request.app.state.openai_serving_chat


def completion(request: Request) -> OpenAIServingCompletion:
    return request.app.state.openai_serving_completion


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization


def embedding(request: Request) -> OpenAIServingEmbedding:
    return request.app.state.openai_serving_embedding


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client

@router.delete("/v1/model/unload")
async def unload_model(raw_request: Request):
    """Unload the model and shut down the engine process."""
    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500)
    client = raw_request.app.state.engine_client
    if isinstance(client, MQAphroditeEngineClient):
        try:
            shutdown_req = RPCShutdownRequest()
            await client.input_socket.send_multipart(
                (pickle.dumps(shutdown_req),), copy=False)
            response = await client.output_socket.recv_multipart()
            if pickle.loads(response[0]) != APHRODITE_RPC_SUCCESS_STR:
                raise RuntimeError("Engine shutdown failed")

            client.output_loop.cancel()
            if client.health_loop is not None:
                client.health_loop.cancel()
            client.close()

            raw_request.app.state.engine_client = None
            raw_request.app.state.openai_serving_chat = None
            raw_request.app.state.openai_serving_completion = None
            raw_request.app.state.openai_serving_embedding = None
            raw_request.app.state.openai_serving_tokenization = None
            raw_request.app.state.model_is_loaded = False
            return JSONResponse(content={"status": "success"})
        except Exception as e:
            return JSONResponse(
                content={
                    "status": "error",
                    "message": f"Failed to shutdown engine: {str(e)}"
                },
                status_code=500)
    else:
        return JSONResponse(
            content={
                "status": "error",
                "message": "Model unloading only supported with "
                "multiprocessing backend"
            },
            status_code=400)
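
# Example admin call (a sketch, not authoritative: assumes an admin key was
# configured via --admin-key or APHRODITE_ADMIN_KEY, and the default
# host/port):
#
#   curl -X DELETE http://localhost:2242/v1/model/unload \
#        -H "x-api-key: $APHRODITE_ADMIN_KEY"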

@router.post("/v1/model/load")
async def load_model(raw_request: Request,
                     config_file: Optional[UploadFile] = None,
                     request: Optional[str] = Form(None)):
    """Load a new model after unloading the previous one.

    Accept either a config file, a JSON request body, or both."""
    if raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "A model is already loaded. Please unload it first."
            },
            status_code=400)
    try:
        parser = FlexibleArgumentParser()
        parser = make_arg_parser(parser)
        new_args = parser.parse_args([])

        original_args = api_server_args
        essential_params = [
            'host', 'port', 'api_keys', 'admin_key',
            'disable_frontend_multiprocessing', 'root_path',
            'ssl_keyfile', 'ssl_certfile'
        ]
        for param in essential_params:
            if hasattr(original_args, param):
                setattr(new_args, param, getattr(original_args, param))

        if config_file:
            yaml_content = await config_file.read()
            config_args = yaml.safe_load(yaml_content)
            if config_args:
                for key, value in config_args.items():
                    if hasattr(new_args, key):
                        setattr(new_args, key, value)

        json_args = None
        if request:
            try:
                json_args = json.loads(request)
            except json.JSONDecodeError:
                return JSONResponse(
                    content={
                        "status": "error",
                        "message": "Invalid JSON in request form field."
                    },
                    status_code=400)
        else:
            try:
                json_args = await raw_request.json()
            except Exception:
                if not config_file:
                    return JSONResponse(
                        content={
                            "status": "error",
                            "message": "Must provide either config_file or "
                            "valid JSON request body."
                        },
                        status_code=400)

        if json_args:
            for key, value in json_args.items():
                if hasattr(new_args, key):
                    setattr(new_args, key, value)

        if not hasattr(new_args, 'model') or not new_args.model:
            return JSONResponse(
                content={
                    "status": "error",
                    "message": "No model specified in config or request body."
                },
                status_code=400)

        engine_args = AsyncEngineArgs.from_cli_args(new_args)
        if (MQAphroditeEngineClient.is_unsupported_config(engine_args)
                or new_args.disable_frontend_multiprocessing):
            return JSONResponse(
                content={
                    "status": "error",
                    "message": "Model loading only supported with "
                    "multiprocessing backend."
                },
                status_code=400)

        ipc_path = get_open_zmq_ipc_path()
        context = multiprocessing.get_context("spawn")
        engine_process = context.Process(target=run_mp_engine,
                                         args=(engine_args, ipc_path))
        engine_process.start()

        engine_config = engine_args.create_engine_config()
        engine_client = MQAphroditeEngineClient(ipc_path, engine_config)
        try:
            while True:
                try:
                    await engine_client.setup()
                    break
                except TimeoutError:
                    if not engine_process.is_alive():
                        return JSONResponse(
                            content={
                                "status": "error",
                                "message": "Engine process died before "
                                "responding to readiness probe."
                            },
                            status_code=500)

            model_config = await engine_client.get_model_config()
            init_app_state(engine_client, model_config,
                           raw_request.app.state, new_args)
            raw_request.app.state.model_is_loaded = True
            raw_request.app.state.current_model = new_args.model
            return JSONResponse(content={"status": "success"})
        except Exception as e:
            engine_process.terminate()
            engine_client.close()
            raise e
    except Exception as e:
        return JSONResponse(
            content={
                "status": "error",
                "message": f"Failed to load model: {str(e)}"
            },
            status_code=500)
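
# Example admin call (a hedged sketch: the model name, host/port, and key
# are placeholders; any field accepted by the server's CLI arguments can be
# passed in the JSON body):
#
#   curl -X POST http://localhost:2242/v1/model/load \
#        -H "x-api-key: $APHRODITE_ADMIN_KEY" \
#        -F 'request={"model": "EleutherAI/pythia-70m", "max_model_len": 2048}'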
  444. @router.get("/health")
  445. async def health(raw_request: Request) -> Response:
  446. """Health check."""
  447. await engine_client(raw_request).check_health()
  448. return Response(status_code=200)
  449. @router.post("/v1/tokenize")
  450. async def tokenize(request: TokenizeRequest, raw_request: Request):
  451. if hasattr(request, "model"):
  452. error_response = await _handle_model_switch(raw_request, request.model)
  453. if error_response is not None:
  454. return error_response
  455. if not raw_request.app.state.model_is_loaded:
  456. return JSONResponse(
  457. content={
  458. "status": "error",
  459. "message": "No model loaded."
  460. },
  461. status_code=500
  462. )
  463. generator = await tokenization(raw_request).create_tokenize(request)
  464. if isinstance(generator, ErrorResponse):
  465. return JSONResponse(content=generator.model_dump(),
  466. status_code=generator.code)
  467. else:
  468. assert isinstance(generator, TokenizeResponse)
  469. return JSONResponse(content=generator.model_dump())
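
# Example request (hypothetical shape; see TokenizeRequest in
# aphrodite.endpoints.openai.protocol for the authoritative field list):
#
#   curl -X POST http://localhost:2242/v1/tokenize \
#        -H "Content-Type: application/json" \
#        -d '{"model": "my-model", "prompt": "Hello, world!"}'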
  470. @router.post("/v1/detokenize")
  471. async def detokenize(request: DetokenizeRequest, raw_request: Request):
  472. if hasattr(request, "model"):
  473. error_response = await _handle_model_switch(
  474. raw_request, request.model)
  475. if error_response is not None:
  476. return error_response
  477. if not raw_request.app.state.model_is_loaded:
  478. return JSONResponse(
  479. content={
  480. "status": "error",
  481. "message": "No model loaded."
  482. },
  483. status_code=500
  484. )
  485. generator = await tokenization(raw_request).create_detokenize(request)
  486. if isinstance(generator, ErrorResponse):
  487. return JSONResponse(content=generator.model_dump(),
  488. status_code=generator.code)
  489. else:
  490. assert isinstance(generator, DetokenizeResponse)
  491. return JSONResponse(content=generator.model_dump())
  492. @router.get("/v1/models")
  493. async def show_available_models(raw_request: Request):
  494. if not raw_request.app.state.model_is_loaded:
  495. return JSONResponse(
  496. content={
  497. "status": "error",
  498. "message": "No model loaded."
  499. },
  500. status_code=500
  501. )
  502. models = await completion(raw_request).show_available_models()
  503. return JSONResponse(content=models.model_dump())
  504. @router.get("/version")
  505. async def show_version():
  506. ver = {"version": APHRODITE_VERSION}
  507. return JSONResponse(content=ver)
  508. @router.get("/.well-known/serviceinfo")
  509. async def serviceinfo():
  510. """Return service information including version, API endpoints,
  511. and documentation URLs."""
  512. return JSONResponse(content={
  513. "version": 0.2,
  514. "software": {
  515. "name": "Aphrodite Engine",
  516. "version": APHRODITE_VERSION,
  517. "repository": "https://github.com/PygmalionAI/aphrodite-engine",
  518. "homepage": "https://aphrodite.pygmalion.chat",
  519. "logo": "https://pygmalion.chat/icons/favicon.ico",
  520. },
  521. "api": {
  522. "openai": {
  523. "name": "OpenAI API",
  524. "rel_url": "/v1",
  525. "documentation": "/redoc",
  526. "version": 1,
  527. },
  528. "koboldai": {
  529. "name": "KoboldAI API",
  530. "rel_url": "/api",
  531. "documentation": "/redoc",
  532. "version": 1,
  533. }
  534. }
  535. })
  536. @router.post("/v1/chat/completions")
  537. async def create_chat_completion(request: ChatCompletionRequest,
  538. raw_request: Request):
  539. if hasattr(request, "model"):
  540. error_response = await _handle_model_switch(raw_request, request.model)
  541. if error_response is not None:
  542. return error_response
  543. if not raw_request.app.state.model_is_loaded:
  544. return JSONResponse(
  545. content={
  546. "status": "error",
  547. "message": "No model loaded."
  548. },
  549. status_code=500
  550. )
  551. generator = await chat(raw_request).create_chat_completion(
  552. request, raw_request)
  553. if isinstance(generator, ErrorResponse):
  554. return JSONResponse(content=generator.model_dump(),
  555. status_code=generator.code)
  556. if request.stream:
  557. return StreamingResponse(content=generator,
  558. media_type="text/event-stream")
  559. else:
  560. assert isinstance(generator, ChatCompletionResponse)
  561. return JSONResponse(content=generator.model_dump())
  562. @router.post("/v1/completions")
  563. async def create_completion(request: CompletionRequest, raw_request: Request):
  564. if hasattr(request, "model"):
  565. error_response = await _handle_model_switch(raw_request, request.model)
  566. if error_response is not None:
  567. return error_response
  568. if not raw_request.app.state.model_is_loaded:
  569. return JSONResponse(
  570. content={
  571. "status": "error",
  572. "message": "No model loaded."
  573. },
  574. status_code=500
  575. )
  576. generator = await completion(raw_request).create_completion(
  577. request, raw_request)
  578. if isinstance(generator, ErrorResponse):
  579. return JSONResponse(content=generator.model_dump(),
  580. status_code=generator.code)
  581. if request.stream:
  582. return StreamingResponse(content=generator,
  583. media_type="text/event-stream")
  584. else:
  585. return JSONResponse(content=generator.model_dump())
  586. @router.post("/v1/embeddings")
  587. async def create_embedding(request: EmbeddingRequest, raw_request: Request):
  588. if hasattr(request, "model"):
  589. error_response = await _handle_model_switch(raw_request, request.model)
  590. if error_response is not None:
  591. return error_response
  592. if not raw_request.app.state.model_is_loaded:
  593. return JSONResponse(
  594. content={
  595. "status": "error",
  596. "message": "No model loaded."
  597. },
  598. status_code=500
  599. )
  600. generator = await embedding(raw_request).create_embedding(
  601. request, raw_request)
  602. if isinstance(generator, ErrorResponse):
  603. return JSONResponse(content=generator.model_dump(),
  604. status_code=generator.code)
  605. else:
  606. return JSONResponse(content=generator.model_dump())
  607. @router.post("/v1/lora/load")
  608. async def load_lora(lora: LoRAModulePath, raw_request: Request):
  609. if not raw_request.app.state.model_is_loaded:
  610. return JSONResponse(
  611. content={
  612. "status": "error",
  613. "message": "No model loaded."
  614. },
  615. status_code=500
  616. )
  617. completion(raw_request).add_lora(lora)
  618. if args.enable_lora is False:
  619. logger.error("LoRA is not enabled in the engine. "
  620. "Please start the server with the "
  621. "--enable-lora flag!")
  622. return JSONResponse(content={"status": "success"})
  623. @router.delete("/v1/lora/unload")
  624. async def unload_lora(lora_name: str, raw_request: Request):
  625. if not raw_request.app.state.model_is_loaded:
  626. return JSONResponse(
  627. content={
  628. "status": "error",
  629. "message": "No model loaded."
  630. },
  631. status_code=500
  632. )
  633. completion(raw_request).remove_lora(lora_name)
  634. return JSONResponse(content={"status": "success"})
  635. @router.post("/v1/soft_prompt/load")
  636. async def load_soft_prompt(soft_prompt: PromptAdapterPath,
  637. raw_request: Request):
  638. if not raw_request.app.state.model_is_loaded:
  639. return JSONResponse(
  640. content={
  641. "status": "error",
  642. "message": "No model loaded."
  643. },
  644. status_code=500
  645. )
  646. completion(raw_request).add_prompt_adapter(soft_prompt)
  647. if args.enable_prompt_adapter is False:
  648. logger.error("Prompt Adapter is not enabled in the engine. "
  649. "Please start the server with the "
  650. "--enable-prompt-adapter flag!")
  651. return JSONResponse(content={"status": "success"})
  652. @router.delete("/v1/soft_prompt/unload")
  653. async def unload_soft_prompt(soft_prompt_name: str, raw_request: Request):
  654. if not raw_request.app.state.model_is_loaded:
  655. return JSONResponse(
  656. content={
  657. "status": "error",
  658. "message": "No model loaded."
  659. },
  660. status_code=500
  661. )
  662. completion(raw_request).remove_prompt_adapter(soft_prompt_name)
  663. return JSONResponse(content={"status": "success"})

# ============ KoboldAI API ============ #

badwordsids: List[int] = []


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)

def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""

    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
        xtc_probability=kai_payload.xtc_probability,
        xtc_threshold=kai_payload.xtc_threshold,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
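
# Worked example of the truncation above (illustrative numbers): with
# max_context_length=4096 and max_length=512, max_input_tokens is
# max(1, 4096 - 512) = 3584, so only the last 3584 prompt tokens are kept,
# leaving room for 512 generated tokens inside the context window.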

@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema,
                   raw_request: Request) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine_client(raw_request).generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    final_res: Optional[RequestOutput] = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})

@extra_api.post("/generate/stream")
async def generate_stream(kai_payload: KAIGenerationInputSchema,
                          raw_request: Request) -> StreamingResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine_client(raw_request).generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")
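
# Each SSE frame emitted by stream_kobold() looks like this on the wire
# (derived from the yields above; the token value is the newly generated
# text since the previous frame):
#
#   event: message
#   data: {"token": "<new text chunk>"}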
  770. @extra_api.post("/generate/check")
  771. @extra_api.get("/generate/check")
  772. async def check_generation(request: Request):
  773. text = ""
  774. try:
  775. request_dict = await request.json()
  776. if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
  777. text = gen_cache[request_dict["genkey"]]
  778. except json.JSONDecodeError:
  779. pass
  780. return JSONResponse({"results": [{"text": text}]})
  781. @extra_api.post("/abort")
  782. async def abort_generation(raw_request: Request):
  783. try:
  784. request_dict = await raw_request.json()
  785. if "genkey" in request_dict:
  786. await engine_client(raw_request).abort(request_dict["genkey"])
  787. except json.JSONDecodeError:
  788. pass
  789. return JSONResponse({})
  790. @extra_api.post("/tokencount")
  791. async def count_tokens(request: TokenizeRequest, raw_request: Request):
  792. """Tokenize string and return token count"""
  793. generator = await tokenization(raw_request).create_tokenize(request)
  794. return JSONResponse({"value": generator.model_dump()["tokens"]})
  795. @kai_api.get("/info/version")
  796. async def get_version():
  797. """Impersonate KAI"""
  798. return JSONResponse({"result": "1.2.4"})
  799. @kai_api.get("/model")
  800. async def get_model():
  801. return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})
  802. @kai_api.get("/config/soft_prompts_list")
  803. async def get_available_softprompts():
  804. """Stub for compatibility"""
  805. return JSONResponse({"values": []})
  806. @kai_api.get("/config/soft_prompt")
  807. async def get_current_softprompt():
  808. """Stub for compatibility"""
  809. return JSONResponse({"value": ""})
  810. @kai_api.put("/config/soft_prompt")
  811. async def set_current_softprompt():
  812. """Stub for compatibility"""
  813. return JSONResponse({})
  814. @kai_api.get("/config/max_length")
  815. async def get_max_length() -> JSONResponse:
  816. max_length = args.max_length
  817. return JSONResponse({"value": max_length})
  818. @kai_api.get("/config/max_context_length")
  819. @extra_api.get("/true_max_context_length")
  820. async def get_max_context_length() -> JSONResponse:
  821. max_context_length = args.max_model_len
  822. return JSONResponse({"value": max_context_length})
  823. @extra_api.get("/preloadstory")
  824. async def get_preloaded_story() -> JSONResponse:
  825. """Stub for compatibility"""
  826. return JSONResponse({})
  827. @extra_api.get("/version")
  828. async def get_extra_version():
  829. """Impersonate KoboldCpp"""
  830. return JSONResponse({"result": "KoboldCpp", "version": "1.63"})
  831. @router.get("/")
  832. async def get_kobold_lite_ui():
  833. """Serves a cached copy of the Kobold Lite UI, loading it from disk
  834. on demand if needed. Can be disabled with SERVE_KOBOLD_LITE_UI=0."""
  835. if not SERVE_KOBOLD_LITE_UI:
  836. return JSONResponse(content={"error": "Kobold Lite UI is disabled"},
  837. status_code=404)
  838. global kobold_lite_ui
  839. if kobold_lite_ui == "":
  840. scriptpath = os.path.dirname(os.path.abspath(__file__))
  841. klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
  842. klitepath = os.path.normpath(klitepath) # Normalize the path
  843. if os.path.exists(klitepath):
  844. with open(klitepath, "r", encoding="utf-8") as f:
  845. kobold_lite_ui = f.read()
  846. else:
  847. logger.error("Kobold Lite UI not found at " + klitepath)
  848. return HTMLResponse(content=kobold_lite_ui)
  849. # ============ KoboldAI API ============ #

def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    app.state.args = args
    app.state.model_is_loaded = False

    if args.launch_kobold_api:
        logger.warning("Kobold API is now enabled by default. "
                       "This flag will be removed in the future.")
    app.include_router(kai_api, prefix="/api/v1")
    app.include_router(kai_api, prefix="/api/latest", include_in_schema=False)
    app.include_router(extra_api, prefix="/api/extra")

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        chat = app.state.openai_serving_chat
        err = chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.APHRODITE_API_KEY or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith(("/v1", "/api")):
                return await call_next(request)

            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith(
                    ("/v1/lora", "/v1/soft_prompt", "/v1/model")):
                if admin_key is not None and (
                        api_key_header == admin_key
                        or auth_header == "Bearer " + admin_key):
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if (auth_header == f"Bearer {token}" or api_key_header == token
                    or (admin_key is not None and
                        (api_key_header == admin_key
                         or auth_header == f"Bearer {admin_key}"))):
                return await call_next(request)

            return JSONResponse(content={"error": "Unauthorized"},
                                status_code=401)
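
    # Clients authenticate with either header below (illustrative values;
    # the real keys come from --api-keys / APHRODITE_API_KEY and
    # --admin-key / APHRODITE_ADMIN_KEY):
    #
    #   Authorization: Bearer <api-key>
    #   x-api-key: <api-key>
    #
    # Admin-only routes (/v1/lora, /v1/soft_prompt, /v1/model) accept only
    # the admin key.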

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app
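

# A custom middleware is registered via --middleware with a dotted path,
# e.g. --middleware my_pkg.my_module.log_requests (hypothetical names). A
# minimal sketch of a coroutine middleware that would satisfy the
# inspect.iscoroutinefunction() branch above:
#
#   async def log_requests(request, call_next):
#       response = await call_next(request)
#       print(request.url.path, response.status_code)
#       return response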

def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
) -> None:
    global api_server_args
    api_server_args = args

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.current_model = args.model

    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,
        base_model_paths,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
    )
    state.openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
        base_model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    state.openai_serving_embedding = OpenAIServingEmbedding(
        engine_client,
        model_config,
        base_model_paths,
        request_logger=request_logger,
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        model_config,
        base_model_paths,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    tokenizer = get_tokenizer(
        tokenizer_name=args.tokenizer if args.tokenizer else args.model,
        tokenizer_mode=args.tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)

async def run_server(args, **uvicorn_kwargs) -> None:

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_engine_client(args) as engine_client:
        app = build_app(args)

        model_config = await engine_client.get_model_config()
        init_app_state(engine_client, model_config, app.state, args)

        protocol = "https" if args.ssl_certfile else "http"
        root_path = args.root_path.rstrip("/") if args.root_path else ""
        host_name = args.host if args.host else "localhost"
        port_str = str(args.port)

        app.state.model_is_loaded = True

        if SERVE_KOBOLD_LITE_UI:
            ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/"
            logger.info(f"Kobold Lite UI: {ui_url}")
        logger.info(f"Documentation: {protocol}://{host_name}:{port_str}{root_path}/redoc")  # noqa: E501
        logger.info(f"Completions API: {protocol}://{host_name}:{port_str}{root_path}/v1/completions")  # noqa: E501
        logger.info(f"Chat API: {protocol}://{host_name}:{port_str}{root_path}/v1/chat/completions")  # noqa: E501
        logger.info(f"Embeddings API: {protocol}://{host_name}:{port_str}{root_path}/v1/embeddings")  # noqa: E501
        logger.info(f"Tokenization API: {protocol}://{host_name}:{port_str}{root_path}/v1/tokenize")  # noqa: E501

        shutdown_task = await serve_http(
            app,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task

if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    uvloop.run(run_server(args))
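
# Example invocation (a sketch; the module path and model name are
# assumptions, and all flags come from make_arg_parser):
#
#   python -m aphrodite.endpoints.openai.api_server \
#       --model EleutherAI/pythia-70m --port 2242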