# api_server.py
import asyncio
import importlib
import inspect
import json
import os
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import AsyncGenerator, List, Optional, Tuple

import fastapi
import uvicorn
from fastapi import APIRouter, Header, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from prometheus_client import make_asgi_app

import aphrodite
import aphrodite.endpoints.openai.embeddings as OAIembeddings
from aphrodite.common.logger import UVICORN_LOG_CONFIG
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.openai.args import make_arg_parser
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionRequest, CompletionRequest, EmbeddingsRequest,
    EmbeddingsResponse, ErrorResponse, KAIGenerationInputSchema, Prompt)
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import \
    OpenAIServingCompletion
from aphrodite.endpoints.openai.serving_engine import LoRA
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.transformers_utils.tokenizer import get_tokenizer

TIMEOUT_KEEP_ALIVE = 5  # seconds

engine: Optional[AsyncAphrodite] = None
engine_args: Optional[AsyncEngineArgs] = None
openai_serving_chat: Optional[OpenAIServingChat] = None
openai_serving_completion: Optional[OpenAIServingCompletion] = None
router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}


@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    if not engine_args.disable_log_stats:
        asyncio.create_task(_force_log())

    yield


# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
router.mount("/metrics", metrics_app)


@router.get("/health")
async def health() -> Response:
    """Health check."""
    await openai_serving_chat.engine.check_health()
    await openai_serving_completion.engine.check_health()
    return Response(status_code=200)
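
# Illustrative health probe (host/port depend on launch flags such as
# --host/--port; 2242 below is just an example):
#   curl -i http://localhost:2242/health
# An empty 200 response means both serving engines report healthy.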


@router.get("/v1/models")
async def show_available_models(x_api_key: Optional[str] = Header(None)):
    models = await openai_serving_chat.show_available_models()
    return JSONResponse(content=models.model_dump())


@router.post("/v1/tokenize")
@router.post("/v1/token/encode")
async def tokenize(request: Request,
                   prompt: Prompt,
                   x_api_key: Optional[str] = Header(None)):
    tokenized = await openai_serving_chat.tokenize(prompt)
    return JSONResponse(content=tokenized)


@router.post("/v1/detokenize")
@router.post("/v1/token/decode")
async def detokenize(request: Request,
                     token_ids: List[int],
                     x_api_key: Optional[str] = Header(None)):
    detokenized = await openai_serving_chat.detokenize(token_ids)
    return JSONResponse(content=detokenized)
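
# Illustrative requests against the token endpoints (adjust host/port to
# your deployment; the token ids shown are made up):
#   curl http://localhost:2242/v1/tokenize \
#        -H "Content-Type: application/json" -d '{"prompt": "Hello, world"}'
#   curl http://localhost:2242/v1/detokenize \
#        -H "Content-Type: application/json" -d '[15496, 11, 995]'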


@router.post("/v1/embeddings", response_model=EmbeddingsResponse)
async def handle_embeddings(request: EmbeddingsRequest,
                            x_api_key: Optional[str] = Header(None)):
    if not request.input:
        return JSONResponse(
            status_code=400,
            content={"error": "Missing required argument input"})
    model = request.model if request.model else None
    response = await OAIembeddings.embeddings(request.input,
                                              request.encoding_format, model)
    return JSONResponse(response)
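
# Illustrative embeddings request, mirroring the OpenAI embeddings API
# (the model field may be omitted to fall back to the served model):
#   curl http://localhost:2242/v1/embeddings \
#        -H "Content-Type: application/json" \
#        -d '{"input": "The food was delicious.", "encoding_format": "float"}'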


@router.get("/version", description="Fetch the Aphrodite Engine version.")
async def show_version(x_api_key: Optional[str] = Header(None)):
    ver = {"version": aphrodite.__version__}
    return JSONResponse(content=ver)


@router.get("/v1/samplers")
async def show_samplers(x_api_key: Optional[str] = Header(None)):
    """Get the available samplers."""
    global sampler_json
    if not sampler_json:
        jsonpath = os.path.dirname(os.path.abspath(__file__))
        samplerpath = os.path.join(jsonpath, "./samplers.json")
        samplerpath = os.path.normpath(samplerpath)  # Normalize the path
        if os.path.exists(samplerpath):
            with open(samplerpath, "r") as f:
                sampler_json = json.load(f)
        else:
            logger.error("Sampler JSON not found at " + samplerpath)
    return sampler_json


@router.post("/v1/lora/load")
async def load_lora(lora: LoRA, x_api_key: Optional[str] = Header(None)):
    if engine_args.enable_lora is False:
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag.")
        return JSONResponse(content={"error": "LoRA is not enabled."},
                            status_code=400)
    openai_serving_chat.add_lora(lora)
    openai_serving_completion.add_lora(lora)
    return JSONResponse(content={"result": "success"})


@router.delete("/v1/lora/unload")
async def unload_lora(lora_name: str, x_api_key: Optional[str] = Header(None)):
    openai_serving_chat.remove_lora(lora_name)
    openai_serving_completion.remove_lora(lora_name)
    return JSONResponse(content={"result": "success"})
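
# Illustrative LoRA load request. When API keys are enabled, /v1/lora routes
# require the admin key (see the authentication middleware in build_app);
# the field names below assume the LoRA schema exposes a name and a path:
#   curl http://localhost:2242/v1/lora/load \
#        -H "Content-Type: application/json" -H "x-api-key: $ADMIN_KEY" \
#        -d '{"name": "my-lora", "local_path": "/path/to/adapter"}'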


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request,
                                 x_api_key: Optional[str] = Header(None)):
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())
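
# Illustrative OpenAI-compatible chat request; set "stream": true to receive
# a text/event-stream instead of a single JSON body (model name is an example):
#   curl http://localhost:2242/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -H "Authorization: Bearer $API_KEY" \
#        -d '{"model": "my-model",
#             "messages": [{"role": "user", "content": "Hello!"}]}'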


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest,
                            raw_request: Request,
                            x_api_key: Optional[str] = Header(None)):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


# ============ KoboldAI API ============ #


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    # Fall back to banning every vocab entry containing square brackets,
    # matching KoboldAI's convention for out-of-band text markers.
    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    # if kai_payload.max_context_length > engine_args.max_model_len:
    #     raise ValueError(
    #         f"max_context_length ({kai_payload.max_context_length}) "
    #         "must be less than or equal to "
    #         f"max_model_len ({engine_args.max_model_len})")

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        # A near-zero temperature means greedy sampling.
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    dynatemp_min = dynatemp_max = 0.0
    if kai_payload.dynatemp_range:
        dynatemp_min = kai_payload.temperature - kai_payload.dynatemp_range
        dynatemp_max = kai_payload.temperature + kai_payload.dynatemp_range

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        dynatemp_min=dynatemp_min,
        dynatemp_max=dynatemp_max,
        dynatemp_exponent=kai_payload.dynatemp_exponent,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        mirostat_mode=kai_payload.mirostat,
        mirostat_tau=kai_payload.mirostat_tau,
        mirostat_eta=kai_payload.mirostat_eta,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
    )

    # Truncate the prompt from the left so it fits in the context window
    # alongside the requested completion length.
    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
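
# Worked example of the truncation rule above: with max_context_length=4096
# and max_length=512, at most 4096 - 512 = 3584 prompt tokens are kept, and
# they are taken from the *end* of the tokenized prompt so the most recent
# context survives.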


@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine.generate(None, sampling_params,
                                       kai_payload.genkey, input_tokens)

    final_res: Optional[RequestOutput] = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        # Cache the partial output so /generate/check can poll progress.
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]
    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine.generate(None, sampling_params,
                                        kai_payload.genkey, input_tokens)

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")
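
# Minimal client sketch for consuming the stream above (assumes the
# third-party `requests` package; the payload fields are illustrative):
#   import json, requests
#   with requests.post("http://localhost:2242/api/extra/generate/stream",
#                      json={"prompt": "Once upon a time", "max_length": 64},
#                      stream=True) as resp:
#       for line in resp.iter_lines():
#           if line.startswith(b"data: "):
#               print(json.loads(line[6:])["token"], end="", flush=True)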


@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass
    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await engine.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass
    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: Request):
    """Tokenize string and return token count"""
    request_dict = await request.json()
    tokenizer_result = await openai_serving_chat.tokenize(
        Prompt(**request_dict))
    return JSONResponse({"value": tokenizer_result["value"]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})


@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed."""
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)


# ============ KoboldAI API ============ #


def build_app(args):
    app = fastapi.FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    if args.launch_kobold_api:
        logger.warning("Launching Kobold API server in addition to OpenAI. "
                       "Keep in mind that the Kobold API routes are NOT "
                       "protected via the API key.")
        app.include_router(kai_api, prefix="/api/v1")
        app.include_router(kai_api,
                           prefix="/api/latest",
                           include_in_schema=False)
        app.include_router(extra_api, prefix="/api/extra")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_completion.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            excluded_paths = ["/api"]
            if any(
                    request.url.path.startswith(path)
                    for path in excluded_paths):
                return await call_next(request)
            if not request.url.path.startswith("/v1"):
                return await call_next(request)
            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)
            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")
            if request.url.path.startswith("/v1/lora"):
                if admin_key is not None and api_key_header == admin_key:
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            if auth_header != "Bearer " + token and api_key_header != token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             "Must be a function or a class.")
    return app
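
# Sketch of a custom middleware usable with --middleware. Given a
# hypothetical module `my_middleware.py` on PYTHONPATH containing:
#
#   async def log_requests(request, call_next):
#       response = await call_next(request)
#       print(request.url.path, response.status_code)
#       return response
#
# passing `--middleware my_middleware.log_requests` registers it via the
# iscoroutinefunction branch above.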


def run_server(args):
    app = build_app(args)
    logger.debug(f"args: {args}")

    global engine, engine_args, openai_serving_chat, \
        openai_serving_completion, tokenizer, served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)
    tokenizer = get_tokenizer(
        engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
        revision=engine_args.revision,
    )

    # Prefer an explicitly passed chat template; otherwise fall back to the
    # template bundled with the tokenizer, if any.
    chat_template = args.chat_template
    if chat_template is None and tokenizer.chat_template is not None:
        chat_template = tokenizer.chat_template

    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model_names, args.lora_modules)
    engine_model_config = asyncio.run(engine.get_model_config())

    if args.launch_kobold_api:
        _set_badwords(tokenizer, engine_model_config.hf_config)

    try:
        uvicorn.run(app,
                    host=args.host,
                    port=args.port,
                    log_level="info",
                    timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                    ssl_keyfile=args.ssl_keyfile,
                    ssl_certfile=args.ssl_certfile,
                    log_config=UVICORN_LOG_CONFIG)
    except KeyboardInterrupt:
        logger.info("API server stopped by user. Exiting.")
    except asyncio.exceptions.CancelledError:
        logger.info("API server stopped due to a cancelled request. Exiting.")


if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = make_arg_parser()
    args = parser.parse_args()
    run_server(args)
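
# Illustrative launch command (flag names follow make_arg_parser; the model
# and port are examples only):
#   python -m aphrodite.endpoints.openai.api_server \
#       --model mistralai/Mistral-7B-Instruct-v0.2 \
#       --port 2242 \
#       --launch-kobold-api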