# api_server.py

import argparse
import asyncio
import json
from contextlib import asynccontextmanager
import os
import importlib
import inspect
from typing import List, Tuple, AsyncGenerator, Optional
from prometheus_client import make_asgi_app
import fastapi
import uvicorn
from http import HTTPStatus
from fastapi import Request, APIRouter, Header
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (JSONResponse, StreamingResponse, Response,
                               HTMLResponse)
from loguru import logger

import aphrodite
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.endpoints.openai.protocol import (CompletionRequest,
                                                 ChatCompletionRequest,
                                                 ErrorResponse, Prompt)
from aphrodite.common.logger import UVICORN_LOG_CONFIG
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import SamplingParams, _SAMPLING_EPS
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.protocol import KAIGenerationInputSchema
from aphrodite.endpoints.openai.serving_engine import LoRA
from aphrodite.transformers_utils.tokenizer import get_tokenizer

TIMEOUT_KEEP_ALIVE = 5  # seconds

openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    if not engine_args.disable_log_stats:
        asyncio.create_task(_force_log())

    yield
app = fastapi.FastAPI(title="Aphrodite Engine",
                      summary="Serving language models at scale",
                      description=("A RESTful API server compatible with "
                                   "OpenAI and KoboldAI clients."),
                      lifespan=lifespan)
class LoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        lora_list = []
        for item in values:
            name, path = item.split('=')
            lora_list.append(LoRA(name, path))
        setattr(namespace, self.dest, lora_list)
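
# Illustrative usage (model name and adapter paths are hypothetical): each
# `name=path` pair passed to --lora-modules becomes one LoRA entry via
# LoRAParserAction above, e.g.
#   python api_server.py --model mistralai/Mistral-7B-v0.1 \
#       --lora-modules alpaca=/path/to/alpaca sql=/path/to/sql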
def parse_args():
    parser = argparse.ArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default=None, help="host name")
    parser.add_argument("--port", type=int, default=2242, help="port number")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins",
                        type=json.loads,
                        default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods",
                        type=json.loads,
                        default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers",
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument(
        "--api-keys",
        type=str,
        default=None,
        help="If provided, the server will require this key to be presented "
        "in the header.")
    parser.add_argument(
        "--admin-key",
        type=str,
        default=None,
        help="If provided, the server will require this key to be presented "
        "in the header for admin operations.")
    parser.add_argument(
        "--launch-kobold-api",
        action="store_true",
        help="Launch the Kobold API server in addition to the OpenAI API "
        "server.")
    parser.add_argument("--max-length",
                        type=int,
                        default=256,
                        help="The maximum length of the generated response. "
                        "For use with Kobold Horde.")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. If not "
                        "specified, the model name will be the same as "
                        "the huggingface name.")
    parser.add_argument(
        "--lora-modules",
        type=str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        help="LoRA module configurations in the format name=path. "
        "Multiple modules can be specified.")
    parser.add_argument("--chat-template",
                        type=str,
                        default=None,
                        help="The file path to the chat template, "
                        "or the template in single-line form "
                        "for the specified model.")
    parser.add_argument("--response-role",
                        type=str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=true`.")
    parser.add_argument("--ssl-keyfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL key file.")
    parser.add_argument("--ssl-certfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL cert file.")
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path-based routing "
        "proxy.")
    parser.add_argument(
        "--middleware",
        type=str,
        action="append",
        default=[],
        help="Additional ASGI middleware to apply to the app. "
        "We accept multiple --middleware arguments. "
        "The value should be an import path. "
        "If a function is provided, Aphrodite will add it to the server "
        "using @app.middleware('http'). "
        "If a class is provided, Aphrodite will add it to the server "
        "using app.add_middleware().")
    parser = AsyncEngineArgs.add_cli_args(parser)
    return parser.parse_args()
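
# Illustrative only: a --middleware value is an import path resolved at
# startup (see the import loop in __main__ below). Assuming a hypothetical
# module `my_project.middlewares` exposing an async function
# `log_requests(request, call_next)`, the flag would be
#   --middleware my_project.middlewares.log_requests
# A class at the same kind of path would instead be registered via
# app.add_middleware().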
# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    err = openai_serving_chat.create_error_response(message=str(exc))
    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.get("/v1/models")
async def show_available_models(x_api_key: Optional[str] = Header(None)):
    models = await openai_serving_chat.show_available_models()
    return JSONResponse(content=models.model_dump())


@app.post("/v1/tokenize")
@app.post("/v1/token/encode")
async def tokenize(request: Request,
                   prompt: Prompt,
                   x_api_key: Optional[str] = Header(None)):
    tokenized = await openai_serving_chat.tokenize(prompt)
    return JSONResponse(content=tokenized)


@app.post("/v1/detokenize")
@app.post("/v1/token/decode")
async def detokenize(request: Request,
                     token_ids: List[int],
                     x_api_key: Optional[str] = Header(None)):
    detokenized = await openai_serving_chat.detokenize(token_ids)
    return JSONResponse(content=detokenized)
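
# Illustrative requests against the two routes above (default port 2242;
# the token ids are placeholders, not real values for any tokenizer):
#   curl -X POST http://localhost:2242/v1/tokenize \
#       -H "Content-Type: application/json" -d '{"prompt": "Hello world"}'
#   curl -X POST http://localhost:2242/v1/detokenize \
#       -H "Content-Type: application/json" -d '[15496, 995]'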
@app.get("/version", description="Fetch the Aphrodite Engine version.")
async def show_version(x_api_key: Optional[str] = Header(None)):
    ver = {"version": aphrodite.__version__}
    return JSONResponse(content=ver)


@app.get("/v1/samplers")
async def show_samplers(x_api_key: Optional[str] = Header(None)):
    """Get the available samplers."""
    global sampler_json
    if not sampler_json:
        jsonpath = os.path.dirname(os.path.abspath(__file__))
        samplerpath = os.path.join(jsonpath, "./samplers.json")
        samplerpath = os.path.normpath(samplerpath)  # Normalize the path
        if os.path.exists(samplerpath):
            with open(samplerpath, "r") as f:
                sampler_json = json.load(f)
        else:
            logger.error("Sampler JSON not found at " + samplerpath)
    return sampler_json
@app.post("/v1/lora/load")
async def load_lora(lora: LoRA, x_api_key: Optional[str] = Header(None)):
    if engine_args.enable_lora is False:
        # Fail early instead of adding the adapter and reporting success.
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag.")
        return JSONResponse(content={"error": "LoRA is not enabled"},
                            status_code=400)
    openai_serving_chat.add_lora(lora)
    openai_serving_completion.add_lora(lora)
    return JSONResponse(content={"result": "success"})
@app.delete("/v1/lora/unload")
async def unload_lora(lora_name: str, x_api_key: Optional[str] = Header(None)):
    openai_serving_chat.remove_lora(lora_name)
    openai_serving_completion.remove_lora(lora_name)
    return JSONResponse(content={"result": "success"})
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request,
                                 x_api_key: Optional[str] = Header(None)):
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest,
                            raw_request: Request,
                            x_api_key: Optional[str] = Header(None)):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())
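
# Illustrative OpenAI-style request (model name is hypothetical and must
# match the served model name configured below in __main__):
#   curl -X POST http://localhost:2242/v1/completions \
#       -H "Content-Type: application/json" \
#       -d '{"model": "mistralai/Mistral-7B-v0.1", "prompt": "Once upon",
#            "max_tokens": 16, "stream": false}'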
# ============ KoboldAI API ============ #


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    # Fall back to banning every vocab entry whose string contains "[" or
    # "]" (typically special tokens such as [PAD] or [CLS]), then allow the
    # pad token and additionally ban EOS.
    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)
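
# Worked example of the fallback (hypothetical vocab): for
# {"hello": 0, "[PAD]": 1, "[CLS]": 2} with pad_token_id=1 and
# eos_token_id=3, the list comprehension yields [1, 2], the pad token is
# removed, and EOS is appended, so badwordsids == [2, 3].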
def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    # if kai_payload.max_context_length > engine_args.max_model_len:
    #     raise ValueError(
    #         f"max_context_length ({kai_payload.max_context_length}) "
    #         "must be less than or equal to "
    #         f"max_model_len ({engine_args.max_model_len})")

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        # Treat temperatures below the sampling epsilon as greedy decoding.
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    # Guard against a missing dynatemp_range: previously dynatemp_min/max
    # were computed only when the field was not None, but compared against
    # 0 unconditionally, which raised for payloads that omitted it.
    dynatemp_min = dynatemp_max = 0.0
    if (kai_payload.dynatemp_range is not None
            and kai_payload.dynatemp_range > 0):
        dynatemp_min = kai_payload.temperature - kai_payload.dynatemp_range
        dynatemp_max = kai_payload.temperature + kai_payload.dynatemp_range

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        dynatemp_min=dynatemp_min,
        dynatemp_max=dynatemp_max,
        dynatemp_exponent=kai_payload.dynatemp_exponent,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        mirostat_mode=kai_payload.mirostat,
        mirostat_tau=kai_payload.mirostat_tau,
        mirostat_eta=kai_payload.mirostat_eta,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
    )

    # Reserve room for the response: truncate the prompt from the left so
    # that prompt plus generation fit in the requested context window.
    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
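
# Worked example of the truncation above: with max_context_length=4096 and
# max_length=256, max_input_tokens = 4096 - 256 = 3840, so only the last
# 3840 prompt tokens are kept and any earlier tokens are dropped.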
@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine.generate(None, sampling_params,
                                       kai_payload.genkey, input_tokens)
    final_res: RequestOutput = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})
@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine.generate(None, sampling_params,
                                        kai_payload.genkey, input_tokens)

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")
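
# The generator above emits server-sent events; each decoded chunk arrives
# on the wire framed as, e.g.:
#   event: message
#   data: {"token": " world"}
# followed by a blank line, per the two yields in stream_kobold().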
@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass
    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await engine.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass
    return JSONResponse({})
@extra_api.post("/tokencount")
async def count_tokens(request: Request):
    """Tokenize string and return token count"""
    request_dict = await request.json()
    tokenizer_result = await openai_serving_chat.tokenize(
        request_dict["prompt"])
    return JSONResponse({"value": len(tokenizer_result)})
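
# Illustrative exchange: POSTing {"prompt": "Hello world"} to /tokencount
# returns {"value": N}, where N is the token count under the served
# tokenizer.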
@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.60.1"})
@app.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed."""
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)


# ============ KoboldAI API ============ #
if __name__ == "__main__":
    args = parse_args()

    if args.launch_kobold_api:
        logger.warning("Launching Kobold API server in addition to OpenAI. "
                       "Keep in mind that the Kobold API routes are NOT "
                       "protected via the API key.")
        app.include_router(kai_api, prefix="/api/v1")
        app.include_router(kai_api,
                           prefix="/api/latest",
                           include_in_schema=False)
        app.include_router(extra_api, prefix="/api/extra")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )
    if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key

        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            excluded_paths = ["/api"]
            if any(
                    request.url.path.startswith(path)
                    for path in excluded_paths):
                return await call_next(request)
            if not request.url.path.startswith("/v1"):
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            # LoRA load/unload routes require the admin key rather than the
            # regular API key.
            if request.url.path.startswith("/v1/lora"):
                if admin_key is not None and api_key_header == admin_key:
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if auth_header != "Bearer " + token and api_key_header != token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            return await call_next(request)
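
    # Illustrative authenticated request (key value hypothetical): either
    # header form below satisfies the check above.
    #   curl http://localhost:2242/v1/models -H "Authorization: Bearer $KEY"
    #   curl http://localhost:2242/v1/models -H "x-api-key: $KEY"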
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. Must be a "
                             "function or a class.")
    logger.debug(f"args: {args}")

    if args.served_model_name is not None:
        served_model = args.served_model_name
    else:
        served_model = args.model

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)
    tokenizer = get_tokenizer(
        engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
    )
    # Resolve the chat template once, falling back to the tokenizer's own
    # template, and pass the resolved value (the original passed
    # args.chat_template, silently discarding the fallback).
    chat_template = args.chat_template
    if chat_template is None and tokenizer.chat_template is not None:
        chat_template = tokenizer.chat_template

    openai_serving_chat = OpenAIServingChat(engine, served_model,
                                            args.response_role,
                                            args.lora_modules,
                                            chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model, args.lora_modules)
    engine_model_config = asyncio.run(engine.get_model_config())
    if args.launch_kobold_api:
        _set_badwords(tokenizer, engine_model_config.hf_config)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                log_config=UVICORN_LOG_CONFIG)