# api_server.py

import argparse
import asyncio
import json
from contextlib import asynccontextmanager
import os
import importlib
import inspect
from typing import List, Tuple, AsyncGenerator, Optional

from prometheus_client import make_asgi_app
import fastapi
import uvicorn
from http import HTTPStatus
from fastapi import Request, APIRouter, Header
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (JSONResponse, StreamingResponse, Response,
                               HTMLResponse)
from loguru import logger

import aphrodite
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.endpoints.openai.protocol import (CompletionRequest,
                                                 ChatCompletionRequest,
                                                 ErrorResponse, Prompt,
                                                 KAIGenerationInputSchema)
from aphrodite.common.logger import UVICORN_LOG_CONFIG
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import SamplingParams, _SAMPLING_EPS
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_engine import LoRA
from aphrodite.transformers_utils.tokenizer import get_tokenizer

TIMEOUT_KEEP_ALIVE = 5  # seconds

openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}


@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    if not engine_args.disable_log_stats:
        asyncio.create_task(_force_log())

    yield
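
# NOTE: when engine_args.disable_log_stats is unset, _force_log runs as a
# fire-and-forget task that logs engine stats every 10 seconds for the
# lifetime of the server.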


app = fastapi.FastAPI(title="Aphrodite Engine",
                      summary="Serving language models at scale",
                      description=("A RESTful API server compatible with "
                                   "OpenAI and KoboldAI clients."),
                      lifespan=lifespan)


class LoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        lora_list = []
        for item in values:
            name, path = item.split('=')
            lora_list.append(LoRA(name, path))
        setattr(namespace, self.dest, lora_list)
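
# Example (adapter names and paths hypothetical, for illustration only):
#   --lora-modules sql-lora=/path/to/sql-lora chat-lora=/path/to/chat-lora
# Each name=path entry is parsed into a LoRA(name, path) object by the
# action above.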


def parse_args():
    parser = argparse.ArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default=None, help="host name")
    parser.add_argument("--port", type=int, default=2242, help="port number")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins",
                        type=json.loads,
                        default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods",
                        type=json.loads,
                        default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers",
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument(
        "--api-keys",
        type=str,
        default=None,
        help="If provided, the server will require this key to be presented "
        "in the header.")
    parser.add_argument(
        "--launch-kobold-api",
        action="store_true",
        help="Launch the Kobold API server in addition to the OpenAI "
        "API server.")
    parser.add_argument("--max-length",
                        type=int,
                        default=256,
                        help="The maximum length of the generated response. "
                        "For use with Kobold Horde.")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. If not "
                        "specified, the model name will be the same as "
                        "the huggingface name.")
    parser.add_argument(
        "--lora-modules",
        type=str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        help="LoRA module configurations in the format name=path. "
        "Multiple modules can be specified.")
    parser.add_argument("--chat-template",
                        type=str,
                        default=None,
                        help="The file path to the chat template, "
                        "or the template in single-line form "
                        "for the specified model.")
    parser.add_argument("--response-role",
                        type=str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=true`.")
    parser.add_argument("--ssl-keyfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL key file.")
    parser.add_argument("--ssl-certfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL cert file.")
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path-based "
        "routing proxy.")
    parser.add_argument(
        "--middleware",
        type=str,
        action="append",
        default=[],
        help="Additional ASGI middleware to apply to the app. "
        "We accept multiple --middleware arguments. "
        "The value should be an import path. "
        "If a function is provided, Aphrodite will add it to the server "
        "using @app.middleware('http'). "
        "If a class is provided, Aphrodite will add it to the server "
        "using app.add_middleware().")

    parser = AsyncEngineArgs.add_cli_args(parser)
    return parser.parse_args()
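
# Example launch command (model name illustrative; the module path is
# assumed from the package layout of the imports above, and --model along
# with the other engine flags is contributed by AsyncEngineArgs.add_cli_args):
#   python -m aphrodite.endpoints.openai.api_server \
#       --model mistralai/Mistral-7B-v0.1 --port 2242 --launch-kobold-api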

# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    err = openai_serving_chat.create_error_response(message=str(exc))
    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.get("/v1/models")
async def show_available_models(x_api_key: Optional[str] = Header(None)):
    models = await openai_serving_chat.show_available_models()
    return JSONResponse(content=models.model_dump())


@app.post("/v1/tokenize")
@app.post("/v1/token/encode")
async def tokenize(request: Request,
                   prompt: Prompt,
                   x_api_key: Optional[str] = Header(None)):
    tokenized = await openai_serving_chat.tokenize(prompt)
    return JSONResponse(content=tokenized)


@app.post("/v1/detokenize")
@app.post("/v1/token/decode")
async def detokenize(request: Request,
                     token_ids: List[int],
                     x_api_key: Optional[str] = Header(None)):
    detokenized = await openai_serving_chat.detokenize(token_ids)
    return JSONResponse(content=detokenized)
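
# Example requests (body shapes assumed from the handler signatures above;
# the exact field name for tokenize lives in the Prompt schema, and the
# token ids are illustrative):
#   POST /v1/tokenize   {"prompt": "Hello world"}
#   POST /v1/detokenize [9906, 1917]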


@app.get("/version", description="Fetch the Aphrodite Engine version.")
async def show_version(x_api_key: Optional[str] = Header(None)):
    ver = {"version": aphrodite.__version__}
    return JSONResponse(content=ver)


@app.get("/v1/samplers")
async def show_samplers(x_api_key: Optional[str] = Header(None)):
    """Get the available samplers."""
    global sampler_json
    if not sampler_json:
        jsonpath = os.path.dirname(os.path.abspath(__file__))
        samplerpath = os.path.join(jsonpath, "./samplers.json")
        samplerpath = os.path.normpath(samplerpath)  # normalize the path
        if os.path.exists(samplerpath):
            with open(samplerpath, "r") as f:
                sampler_json = json.load(f)
        else:
            logger.error("Sampler JSON not found at " + samplerpath)
    return sampler_json


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request,
                                 x_api_key: Optional[str] = Header(None)):
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest,
                            raw_request: Request,
                            x_api_key: Optional[str] = Header(None)):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())
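
# Example request (curl; model name illustrative, fields follow the usual
# OpenAI completion schema, and the key header is only checked when auth
# is enabled below):
#   curl http://localhost:2242/v1/completions \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer $APHRODITE_API_KEY" \
#     -d '{"model": "mistralai/Mistral-7B-v0.1",
#          "prompt": "Once upon a time", "max_tokens": 32}'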


# ============ KoboldAI API ============ #


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    # Ban every vocab entry containing "[" or "]", drop the pad token from
    # the list, and ban the EOS token as well.
    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


def prepare_engine_payload(
    kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"
    # if kai_payload.max_context_length > engine_args.max_model_len:
    #     raise ValueError(
    #         f"max_context_length ({kai_payload.max_context_length}) "
    #         "must be less than or equal to "
    #         f"max_model_len ({engine_args.max_model_len})")

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        # Greedy sampling: disable the top-p/top-k filters.
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    # Dynamic temperature stays disabled (0.0) unless a positive
    # dynatemp_range is provided; guarding here also avoids referencing
    # undefined bounds when dynatemp_range is None.
    dynatemp_min = 0.0
    dynatemp_max = 0.0
    if kai_payload.dynatemp_range is not None and kai_payload.dynatemp_range > 0:
        dynatemp_min = kai_payload.temperature - kai_payload.dynatemp_range
        dynatemp_max = kai_payload.temperature + kai_payload.dynatemp_range

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        dynatemp_min=dynatemp_min,
        dynatemp_max=dynatemp_max,
        dynatemp_exponent=kai_payload.dynatemp_exponent,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        mirostat_mode=kai_payload.mirostat,
        mirostat_tau=kai_payload.mirostat_tau,
        mirostat_eta=kai_payload.mirostat_eta,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
    )

    # Leave room for the response: keep only the last
    # (max_context_length - max_length) prompt tokens.
    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
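
# Worked example of the truncation above: with max_context_length=4096 and
# max_length=256, max_input_tokens is 3840, so only the last 3840 prompt
# tokens are passed to the engine.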


@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine.generate(None, sampling_params,
                                       kai_payload.genkey, input_tokens)

    final_res: RequestOutput = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]
    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine.generate(None, sampling_params,
                                        kai_payload.genkey, input_tokens)

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")
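
# Each chunk arrives as a server-sent event, e.g. (token text illustrative):
#   event: message
#   data: {"token": " world"}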


@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass
    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await engine.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass
    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: Request):
    """Tokenize string and return token count"""
    request_dict = await request.json()
    tokenizer_result = await openai_serving_chat.tokenize(
        request_dict["prompt"])
    return JSONResponse({"value": len(tokenizer_result)})
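
# Example (route prefix applies once extra_api is mounted at /api/extra in
# __main__ below):
#   POST /api/extra/tokencount {"prompt": "Hello world"}
# returns {"value": <token count>}.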


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.60.1"})


@app.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed."""
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)


# ============ KoboldAI API ============ #


if __name__ == "__main__":
    args = parse_args()

    if args.launch_kobold_api:
        logger.warning("Launching Kobold API server in addition to OpenAI. "
                       "Keep in mind that the Kobold API routes are NOT "
                       "protected via the API key.")
        app.include_router(kai_api, prefix="/api/v1")
        app.include_router(kai_api,
                           prefix="/api/latest",
                           include_in_schema=False)
        app.include_router(extra_api, prefix="/api/extra")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            excluded_paths = ["/api"]
            if any(
                    request.url.path.startswith(path)
                    for path in excluded_paths):
                return await call_next(request)
            if not request.url.path.startswith("/v1"):
                return await call_next(request)
            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")
            if auth_header != "Bearer " + token and api_key_header != token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)
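
    # Example of an authorized request once a key is set (either header is
    # accepted by the check above; $KEY is a placeholder):
    #   curl http://localhost:2242/v1/models -H "Authorization: Bearer $KEY"
    #   curl http://localhost:2242/v1/models -H "x-api-key: $KEY"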

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             "Must be a function or a class.")
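
    # Example (import path hypothetical): --middleware my_pkg.logging.log_asgi
    # must resolve to either a coroutine function or a middleware class.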

    logger.debug(f"args: {args}")

    if args.served_model_name is not None:
        served_model = args.served_model_name
    else:
        served_model = args.model

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)
    tokenizer = get_tokenizer(
        engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
    )

    chat_template = args.chat_template
    if chat_template is None and tokenizer.chat_template is not None:
        chat_template = tokenizer.chat_template

    # Pass the resolved template (with the tokenizer fallback applied),
    # not the raw CLI value.
    openai_serving_chat = OpenAIServingChat(engine, served_model,
                                            args.response_role,
                                            args.lora_modules,
                                            chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model, args.lora_modules)
    engine_model_config = asyncio.run(engine.get_model_config())

    if args.launch_kobold_api:
        _set_badwords(tokenizer, engine_model_config.hf_config)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                log_config=UVICORN_LOG_CONFIG)