import asyncio
import importlib
import inspect
import json
import multiprocessing
import os
import pickle
import re
import signal
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from distutils.util import strtobool
from functools import partial
from http import HTTPStatus
from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple

import yaml
from fastapi import APIRouter, FastAPI, Form, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from starlette.datastructures import State
from starlette.routing import Mount

import aphrodite.common.envs as envs
from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    get_open_zmq_ipc_path, in_windows,
                                    random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       PromptAdapterPath)
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.multiprocessing import (APHRODITE_RPC_SUCCESS_STR,
                                              RPCShutdownRequest)
from aphrodite.engine.multiprocessing.client import MQAphroditeEngineClient
from aphrodite.engine.multiprocessing.engine import run_mp_engine
from aphrodite.engine.protocol import EngineClient
from aphrodite.modeling.model_loader.weight_utils import get_model_config_yaml
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION

if in_windows():
    import winloop as uvloop
else:
    import uvloop

TIMEOUT_KEEP_ALIVE = 5  # seconds

SERVE_KOBOLD_LITE_UI = strtobool(os.getenv("SERVE_KOBOLD_LITE_UI", "1"))

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
prometheus_multiproc_dir: tempfile.TemporaryDirectory

_running_tasks: Set[asyncio.Task] = set()

@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(10.)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.remove)
        else:
            task = None
        try:
            yield
        finally:
            if task is not None:
                task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
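
# NOTE: `_running_tasks` keeps a strong reference to the stats-logging task
# created in lifespan() above. The event loop only holds weak references to
# tasks, so without this set the task could be garbage-collected mid-flight.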

@asynccontextmanager
async def build_engine_client(
        args: Namespace) -> AsyncIterator[Optional[EngineClient]]:

    # Context manager to handle engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    engine_args = AsyncEngineArgs.from_cli_args(args)

    async with build_engine_client_from_engine_args(
            engine_args, args.disable_frontend_multiprocessing) as engine:
        yield engine
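
# Illustrative usage (mirrors run_server() below; `my_args` stands in for a
# parsed CLI namespace):
#
#     async with build_engine_client(my_args) as client:
#         if client is None:
#             return  # engine startup failed
#         ...      # serve requests against `client`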

@asynccontextmanager
async def build_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[EngineClient]]:
    """
    Create EngineClient, either:
        - in-process, using AsyncAphrodite directly
        - multiprocess, using the AsyncAphrodite RPC backend

    Returns the client, or None if creation failed.
    """

    # Fall back to the in-process engine if multiprocessing is
    # unsupported for this config, or was explicitly disabled.
    # TODO: fill out feature matrix.
    if (MQAphroditeEngineClient.is_unsupported_config(engine_args)
            or disable_frontend_multiprocessing):
        engine_config = engine_args.create_engine_config()
        uses_ray = getattr(AsyncAphrodite._get_executor_cls(engine_config),
                           "uses_ray", False)

        build_engine = partial(AsyncAphrodite.from_engine_args,
                               engine_args=engine_args,
                               engine_config=engine_config)
        if uses_ray:
            # Must run in main thread with ray for its signal handlers to work
            engine_client = build_engine()
        else:
            engine_client = await asyncio.get_running_loop().run_in_executor(
                None, build_engine)

        yield engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between Aphrodite runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and Aphrodite will properly handle cleanup.")

        # Select random path for IPC.
        ipc_path = get_open_zmq_ipc_path()
        logger.info(
            f"Multiprocessing frontend to use {ipc_path} for IPC Path.")

        # Start RPCServer in separate process (holds the LLMEngine).
        # The current process might have CUDA context,
        # so we need to spawn a new process.
        context = multiprocessing.get_context("spawn")

        engine_process = context.Process(target=run_mp_engine,
                                         args=(engine_args,
                                               ipc_path))
        engine_process.start()
        logger.info(f"Started engine process with PID {engine_process.pid}")

        # Build RPCClient, which conforms to EngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        engine_config = engine_args.create_engine_config()
        mp_engine_client = MQAphroditeEngineClient(ipc_path, engine_config)

        try:
            while True:
                try:
                    await mp_engine_client.setup()
                    break
                except TimeoutError:
                    if not engine_process.is_alive():
                        logger.error("Engine process died before responding "
                                     "to readiness probe")
                        yield None
                        return

            yield mp_engine_client  # type: ignore[misc]
        finally:
            # Ensure rpc server process was terminated
            engine_process.terminate()

            # Close all open connections to the backend
            mp_engine_client.close()

            # Wait for engine process to join
            engine_process.join(4)
            if engine_process.exitcode is None:
                # Kill if taking longer than 4 seconds to stop
                engine_process.kill()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(engine_process.pid)

def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
                    "as PROMETHEUS_MULTIPROC_DIR")
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
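
# The mounted route is a plain Prometheus scrape target and can be checked
# by hand (host and port are placeholders for the configured --host/--port):
#
#     curl http://localhost:2242/metrics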

async def _handle_model_switch(
    raw_request: Request,
    requested_model: str
) -> Optional[JSONResponse]:
    """Helper function to handle model switching if needed.
    Returns error response if something went wrong, None if successful."""

    if not raw_request.app.state.args.allow_inline_model_loading:
        return None

    if not raw_request.app.state.model_is_loaded:
        config = get_model_config_yaml(requested_model)
        request_data = {"model": requested_model}
        if config:
            config.pop("model", None)
            request_data.update(config)
        load_response = await load_model(
            raw_request,
            request=json.dumps(request_data)
        )
        if load_response.status_code != 200:
            return load_response
        return None

    current_model = raw_request.app.state.current_model
    if current_model == requested_model:
        return None

    unload_response = await unload_model(raw_request)
    if unload_response.status_code != 200:
        return unload_response

    config = get_model_config_yaml(requested_model)
    request_data = {"model": requested_model}
    if config:
        config.pop("model", None)
        request_data.update(config)
    load_response = await load_model(
        raw_request,
        request=json.dumps(request_data)
    )
    if load_response.status_code != 200:
        return load_response
    return None
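
# Illustrative flow (assumes the server was started with inline loading
# enabled via the `allow_inline_model_loading` arg): a request naming a model
# other than `app.state.current_model` triggers unload_model() followed by
# load_model(), merging in any YAML config found for the requested model.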

def chat(request: Request) -> OpenAIServingChat:
    return request.app.state.openai_serving_chat


def completion(request: Request) -> OpenAIServingCompletion:
    return request.app.state.openai_serving_completion


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization


def embedding(request: Request) -> OpenAIServingEmbedding:
    return request.app.state.openai_serving_embedding


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client

@router.delete("/v1/model/unload")
async def unload_model(raw_request: Request):
    """Unload the model and shut down the engine process."""
    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )

    client = raw_request.app.state.engine_client
    if isinstance(client, MQAphroditeEngineClient):
        try:
            shutdown_req = RPCShutdownRequest()
            await client.input_socket.send_multipart(
                (pickle.dumps(shutdown_req),), copy=False
            )

            response = await client.output_socket.recv_multipart()
            if pickle.loads(response[0]) != APHRODITE_RPC_SUCCESS_STR:
                raise RuntimeError("Engine shutdown failed")

            client.output_loop.cancel()
            if client.health_loop is not None:
                client.health_loop.cancel()
            client.close()

            raw_request.app.state.engine_client = None
            raw_request.app.state.openai_serving_chat = None
            raw_request.app.state.openai_serving_completion = None
            raw_request.app.state.openai_serving_embedding = None
            raw_request.app.state.openai_serving_tokenization = None
            raw_request.app.state.model_is_loaded = False
            return JSONResponse(content={"status": "success"})
        except Exception as e:
            return JSONResponse(
                content={
                    "status": "error",
                    "message": f"Failed to shutdown engine: {str(e)}"
                },
                status_code=500
            )
    else:
        return JSONResponse(
            content={
                "status": "error",
                "message": "Model unloading only supported with multiprocessing"
                " backend"
            },
            status_code=400
        )

@router.post("/v1/model/load")
async def load_model(
    raw_request: Request,
    config_file: Optional[UploadFile] = None,
    request: Optional[str] = Form(None)
):
    """Load a new model after unloading the previous one.
    Accept either a config file, a JSON request body, or both."""
    if raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "A model is already loaded. Please unload it first."
            },
            status_code=400
        )

    try:
        parser = FlexibleArgumentParser()
        parser = make_arg_parser(parser)
        new_args = parser.parse_args([])

        original_args = api_server_args
        essential_params = [
            'host', 'port', 'api_keys', 'admin_key',
            'disable_frontend_multiprocessing', 'root_path',
            'ssl_keyfile', 'ssl_certfile'
        ]
        for param in essential_params:
            if hasattr(original_args, param):
                setattr(new_args, param, getattr(original_args, param))

        if config_file:
            yaml_content = await config_file.read()
            config_args = yaml.safe_load(yaml_content)
            if config_args:
                for key, value in config_args.items():
                    if hasattr(new_args, key):
                        setattr(new_args, key, value)

        json_args = None
        if request:
            try:
                json_args = json.loads(request)
            except json.JSONDecodeError:
                return JSONResponse(
                    content={
                        "status": "error",
                        "message": "Invalid JSON in request form field."
                    },
                    status_code=400
                )
        else:
            try:
                json_args = await raw_request.json()
            except Exception:
                if not config_file:
                    return JSONResponse(
                        content={
                            "status": "error",
                            "message": "Must provide either config_file or "
                            "valid JSON request body."
                        },
                        status_code=400
                    )

        if json_args:
            for key, value in json_args.items():
                if hasattr(new_args, key):
                    setattr(new_args, key, value)

        if not hasattr(new_args, 'model') or not new_args.model:
            return JSONResponse(
                content={
                    "status": "error",
                    "message": "No model specified in config or request body."
                },
                status_code=400
            )

        engine_args = AsyncEngineArgs.from_cli_args(new_args)
        if (MQAphroditeEngineClient.is_unsupported_config(engine_args)
                or new_args.disable_frontend_multiprocessing):
            return JSONResponse(
                content={
                    "status": "error",
                    "message": "Model loading only supported with "
                    "multiprocessing backend."
                },
                status_code=400
            )

        ipc_path = get_open_zmq_ipc_path()
        context = multiprocessing.get_context("spawn")
        engine_process = context.Process(
            target=run_mp_engine,
            args=(engine_args, ipc_path)
        )
        engine_process.start()

        engine_config = engine_args.create_engine_config()
        engine_client = MQAphroditeEngineClient(ipc_path, engine_config)

        try:
            while True:
                try:
                    await engine_client.setup()
                    break
                except TimeoutError:
                    if not engine_process.is_alive():
                        return JSONResponse(
                            content={
                                "status": "error",
                                "message": "Engine process died before "
                                "responding to readiness probe."
                            },
                            status_code=500
                        )

            model_config = await engine_client.get_model_config()
            init_app_state(
                engine_client, model_config, raw_request.app.state, new_args)

            raw_request.app.state.model_is_loaded = True
            raw_request.app.state.current_model = new_args.model
            return JSONResponse(content={"status": "success"})
        except Exception as e:
            engine_process.terminate()
            engine_client.close()
            raise e

    except Exception as e:
        return JSONResponse(
            content={
                "status": "error",
                "message": f"Failed to load model: {str(e)}"
            },
            status_code=500
        )
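
# Hand-testing the pair of endpoints above (host, port, and the model name
# are placeholders; an admin key may be required, see the authentication
# middleware in build_app below):
#
#     curl -X POST http://localhost:2242/v1/model/load \
#          -H "Content-Type: application/json" \
#          -d '{"model": "meta-llama/Llama-3.1-8B-Instruct"}'
#
#     curl -X DELETE http://localhost:2242/v1/model/unload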

@router.get("/health")
async def health(raw_request: Request) -> Response:
    """Health check."""
    await engine_client(raw_request).check_health()
    return Response(status_code=200)

@router.post("/v1/tokenize")
async def tokenize(request: TokenizeRequest, raw_request: Request):
    if hasattr(request, "model"):
        error_response = await _handle_model_switch(raw_request, request.model)
        if error_response is not None:
            return error_response

    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )

    generator = await tokenization(raw_request).create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, TokenizeResponse)
        return JSONResponse(content=generator.model_dump())

@router.post("/v1/detokenize")
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    if hasattr(request, "model"):
        error_response = await _handle_model_switch(
            raw_request, request.model)
        if error_response is not None:
            return error_response

    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )

    generator = await tokenization(raw_request).create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, DetokenizeResponse)
        return JSONResponse(content=generator.model_dump())

@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    models = await completion(raw_request).show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": APHRODITE_VERSION}
    return JSONResponse(content=ver)

@router.get("/.well-known/serviceinfo")
async def serviceinfo():
    """Return service information including version, API endpoints,
    and documentation URLs."""
    return JSONResponse(content={
        "version": 0.2,
        "software": {
            "name": "Aphrodite Engine",
            "version": APHRODITE_VERSION,
            "repository": "https://github.com/PygmalionAI/aphrodite-engine",
            "homepage": "https://aphrodite.pygmalion.chat",
            "logo": "https://pygmalion.chat/icons/favicon.ico",
        },
        "api": {
            "openai": {
                "name": "OpenAI API",
                "rel_url": "/v1",
                "documentation": "/redoc",
                "version": 1,
            },
            "koboldai": {
                "name": "KoboldAI API",
                "rel_url": "/api",
                "documentation": "/redoc",
                "version": 1,
            }
        }
    })

@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    if hasattr(request, "model"):
        error_response = await _handle_model_switch(raw_request, request.model)
        if error_response is not None:
            return error_response

    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )

    generator = await chat(raw_request).create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())
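
# This endpoint is wire-compatible with the OpenAI chat API, e.g. (host,
# port, key, and model name are placeholders; the Authorization header is
# only needed when API keys are configured):
#
#     curl http://localhost:2242/v1/chat/completions \
#          -H "Content-Type: application/json" \
#          -H "Authorization: Bearer $API_KEY" \
#          -d '{"model": "my-model",
#               "messages": [{"role": "user", "content": "Hello!"}]}'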

@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    if hasattr(request, "model"):
        error_response = await _handle_model_switch(raw_request, request.model)
        if error_response is not None:
            return error_response

    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )

    generator = await completion(raw_request).create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())

@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    if hasattr(request, "model"):
        error_response = await _handle_model_switch(raw_request, request.model)
        if error_response is not None:
            return error_response

    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )

    generator = await embedding(raw_request).create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())

@router.post("/v1/lora/load")
async def load_lora(lora: LoRAModulePath, raw_request: Request):
    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    completion(raw_request).add_lora(lora)
    if args.enable_lora is False:
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/lora/unload")
async def unload_lora(lora_name: str, raw_request: Request):
    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    completion(raw_request).remove_lora(lora_name)
    return JSONResponse(content={"status": "success"})


@router.post("/v1/soft_prompt/load")
async def load_soft_prompt(soft_prompt: PromptAdapterPath,
                           raw_request: Request):
    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    completion(raw_request).add_prompt_adapter(soft_prompt)
    if args.enable_prompt_adapter is False:
        logger.error("Prompt Adapter is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-prompt-adapter flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/soft_prompt/unload")
async def unload_soft_prompt(soft_prompt_name: str, raw_request: Request):
    if not raw_request.app.state.model_is_loaded:
        return JSONResponse(
            content={
                "status": "error",
                "message": "No model loaded."
            },
            status_code=500
        )
    completion(raw_request).remove_prompt_adapter(soft_prompt_name)
    return JSONResponse(content={"status": "success"})

# ============ KoboldAI API ============ #

badwordsids: List[int] = []


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)
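
# NOTE: banning every vocab entry containing "[" or "]" mirrors KoboldAI's
# default badwords behavior: square brackets delimit KoboldAI metadata such
# as author's notes, so generations should not emit them. (Rationale
# inferred from KoboldAI conventions; not stated in this file.)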

def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""

    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
        xtc_probability=kai_payload.xtc_probability,
        xtc_threshold=kai_payload.xtc_threshold,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
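
# Truncation arithmetic, worked through: with max_context_length=4096 and
# max_length=512, max_input_tokens = 4096 - 512 = 3584, so only the last
# 3584 prompt tokens are kept and prompt + generation always fit within the
# context window.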

@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema,
                   raw_request: Request) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine_client(raw_request).generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    final_res: Optional[RequestOutput] = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})

@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema,
        raw_request: Request) -> StreamingResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine_client(raw_request).generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")
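
# Each incremental chunk is emitted as a Server-Sent Event, so the wire
# format produced by stream_kobold() looks like:
#
#     event: message
#     data: {"token": " world"}
#
# with a blank line terminating each event, per SSE framing rules.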

@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass
    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(raw_request: Request):
    try:
        request_dict = await raw_request.json()
        if "genkey" in request_dict:
            await engine_client(raw_request).abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass
    return JSONResponse({})

@extra_api.post("/tokencount")
async def count_tokens(request: TokenizeRequest, raw_request: Request):
    """Tokenize string and return token count"""
    generator = await tokenization(raw_request).create_tokenize(request)
    return JSONResponse({"value": generator.model_dump()["tokens"]})

@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})

@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed. Can be disabled with SERVE_KOBOLD_LITE_UI=0."""
    if not SERVE_KOBOLD_LITE_UI:
        return JSONResponse(content={"error": "Kobold Lite UI is disabled"},
                            status_code=404)
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r", encoding="utf-8") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error(f"Kobold Lite UI not found at {klitepath}")
    return HTMLResponse(content=kobold_lite_ui)

# ============ KoboldAI API ============ #

def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    app.state.args = args
    app.state.model_is_loaded = False
    if args.launch_kobold_api:
        logger.warning("Kobold API is now enabled by default. "
                       "This flag will be removed in the future.")
    app.include_router(kai_api, prefix="/api/v1")
    app.include_router(kai_api,
                       prefix="/api/latest",
                       include_in_schema=False)
    app.include_router(extra_api, prefix="/api/extra")
    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        chat = app.state.openai_serving_chat
        err = chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.APHRODITE_API_KEY or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith(("/v1", "/api")):
                return await call_next(request)

            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith(
                    ("/v1/lora", "/v1/soft_prompt", "/v1/model")):
                if admin_key is not None and (
                    api_key_header == admin_key or
                    auth_header == "Bearer " + admin_key
                ):
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if (auth_header == f"Bearer {token}" or api_key_header == token or
                    (admin_key is not None and
                     (api_key_header == admin_key or
                      auth_header == f"Bearer {admin_key}"))):
                return await call_next(request)

            return JSONResponse(
                content={"error": "Unauthorized"}, status_code=401)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app
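
# Authentication summary for the middleware above: when API keys are
# configured, requests under /v1 or /api must carry either header form
#
#     Authorization: Bearer <api-key>
#     x-api-key: <api-key>
#
# and the admin endpoints (/v1/lora, /v1/soft_prompt, /v1/model) accept
# only the admin key.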

def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
) -> None:
    global api_server_args
    api_server_args = args

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.current_model = args.model

    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser
    )
    state.openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    state.openai_serving_embedding = OpenAIServingEmbedding(
        engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    tokenizer = get_tokenizer(
        tokenizer_name=args.tokenizer if args.tokenizer else args.model,
        tokenizer_mode=args.tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)

async def run_server(args, **uvicorn_kwargs) -> None:

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_engine_client(args) as engine_client:
        # If None, creation of the client failed and we exit.
        if engine_client is None:
            return

        app = build_app(args)

        model_config = await engine_client.get_model_config()
        init_app_state(engine_client, model_config, app.state, args)

        protocol = "https" if args.ssl_certfile else "http"
        root_path = args.root_path.rstrip("/") if args.root_path else ""
        host_name = args.host if args.host else "localhost"
        port_str = str(args.port)
        app.state.model_is_loaded = True

        if SERVE_KOBOLD_LITE_UI:
            ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/"
            logger.info(f"Kobold Lite UI: {ui_url}")

        logger.info(f"Documentation: {protocol}://{host_name}:{port_str}{root_path}/redoc")  # noqa: E501
        logger.info(f"Completions API: {protocol}://{host_name}:{port_str}{root_path}/v1/completions")  # noqa: E501
        logger.info(f"Chat API: {protocol}://{host_name}:{port_str}{root_path}/v1/chat/completions")  # noqa: E501
        logger.info(f"Embeddings API: {protocol}://{host_name}:{port_str}{root_path}/v1/embeddings")  # noqa: E501
        logger.info(f"Tokenization API: {protocol}://{host_name}:{port_str}{root_path}/v1/tokenize")  # noqa: E501

        shutdown_task = await serve_http(
            app,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task

if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    uvloop.run(run_server(args))