# api_server.py
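"""OpenAI-compatible API server for the Aphrodite Engine, with an optional
KoboldAI-compatible API for Kobold and KoboldCpp clients."""
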
import asyncio
import importlib
import inspect
import json
import os
import re
import socket
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import AsyncGenerator, List, Optional, Set, Tuple

import fastapi
import uvicorn
import uvloop
from fastapi import APIRouter, Header, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from prometheus_client import make_asgi_app
from starlette.routing import Mount

import aphrodite
from aphrodite.common.logger import UVICORN_LOG_CONFIG
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.openai.args import make_arg_parser
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ErrorResponse,
    KAIGenerationInputSchema, Prompt)
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import \
    OpenAIServingCompletion
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import LoRA
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.transformers_utils.tokenizer import get_tokenizer

TIMEOUT_KEEP_ALIVE = 5  # seconds

engine: Optional[AsyncAphrodite] = None
engine_args: Optional[AsyncEngineArgs] = None
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
openai_serving_embedding: OpenAIServingEmbedding = None
router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
_running_tasks: Set[asyncio.Task] = set()


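# Lifespan hook: when stat logging is enabled, spawn a background task that
# logs engine stats every 10 seconds. A strong reference to the task is kept
# in _running_tasks so it is not garbage-collected prematurely (the event
# loop holds only weak references to tasks).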
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    yield


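# Probe ports upward from the requested one until a free port is found,
# wrapping back to port 2242 once the valid port range is exhausted.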
def find_available_port(starting_port):
    port = starting_port
    while True:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("", port))
                return port
        except OSError:
            logger.warning(f"Port {port} is already in use. Trying next port.")
            port += 1
            if port > 65535:
                port = 2242


@router.get("/health")
async def health() -> Response:
    """Health check."""
    await openai_serving_chat.engine.check_health()
    await openai_serving_completion.engine.check_health()
    return Response(status_code=200)


@router.get("/v1/models")
async def show_available_models(x_api_key: Optional[str] = Header(None)):
    models = await openai_serving_chat.show_available_models()
    return JSONResponse(content=models.model_dump())


@router.post("/v1/tokenize")
@router.post("/v1/token/encode")
async def tokenize(request: Request,
                   prompt: Prompt,
                   x_api_key: Optional[str] = Header(None)):
    tokenized = await openai_serving_chat.tokenize(prompt)
    return JSONResponse(content=tokenized)


@router.post("/v1/detokenize")
@router.post("/v1/token/decode")
async def detokenize(request: Request,
                     token_ids: List[int],
                     x_api_key: Optional[str] = Header(None)):
    detokenized = await openai_serving_chat.detokenize(token_ids)
    return JSONResponse(content=detokenized)


@router.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingRequest,
                            raw_request: Request,
                            x_api_key: Optional[str] = Header(None)):
    generator = await openai_serving_embedding.create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())


@router.get("/version", description="Fetch the Aphrodite Engine version.")
async def show_version(x_api_key: Optional[str] = Header(None)):
    ver = {"version": aphrodite.__version__}
    return JSONResponse(content=ver)


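# Sampler metadata is read from samplers.json next to this file once, then
# cached in the module-level sampler_json for subsequent requests.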
@router.get("/v1/samplers")
async def show_samplers(x_api_key: Optional[str] = Header(None)):
    """Get the available samplers."""
    global sampler_json
    if not sampler_json:
        jsonpath = os.path.dirname(os.path.abspath(__file__))
        samplerpath = os.path.join(jsonpath, "./samplers.json")
        samplerpath = os.path.normpath(samplerpath)  # Normalize the path
        if os.path.exists(samplerpath):
            with open(samplerpath, "r") as f:
                sampler_json = json.load(f)
        else:
            logger.error("Sampler JSON not found at " + samplerpath)
    return sampler_json


@router.post("/v1/lora/load")
async def load_lora(lora: LoRA, x_api_key: Optional[str] = Header(None)):
    # Reject the request up front instead of adding the LoRA and then
    # reporting success anyway.
    if engine_args.enable_lora is False:
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag.")
        return JSONResponse(content={"error": "LoRA is not enabled"},
                            status_code=400)
    openai_serving_chat.add_lora(lora)
    openai_serving_completion.add_lora(lora)
    return JSONResponse(content={"result": "success"})


@router.delete("/v1/lora/unload")
async def unload_lora(lora_name: str, x_api_key: Optional[str] = Header(None)):
    openai_serving_chat.remove_lora(lora_name)
    openai_serving_completion.remove_lora(lora_name)
    return JSONResponse(content={"result": "success"})


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request,
                                 x_api_key: Optional[str] = Header(None)):
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest,
                            raw_request: Request,
                            x_api_key: Optional[str] = Header(None)):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


# ============ KoboldAI API ============ #

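# Populate the global badwordsids list: use the model config's bad_words_ids
# when provided, otherwise ban every vocab token containing square brackets
# (minus the pad token, plus the EOS token). These ids are applied when a
# request sets use_default_badwordsids.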
def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    # if kai_payload.max_context_length > engine_args.max_model_len:
    #     raise ValueError(
    #         f"max_context_length ({kai_payload.max_context_length}) "
    #         "must be less than or equal to "
    #         f"max_model_len ({engine_args.max_model_len})")

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        # Effectively greedy sampling: neutralize the other samplers.
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    # Guard against dynatemp_range being None before comparing it, so the
    # bounds are only computed (and applied) for a positive range.
    dynatemp_min = dynatemp_max = 0.0
    if kai_payload.dynatemp_range is not None and kai_payload.dynatemp_range > 0:
        dynatemp_min = kai_payload.temperature - kai_payload.dynatemp_range
        dynatemp_max = kai_payload.temperature + kai_payload.dynatemp_range

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        dynatemp_min=dynatemp_min,
        dynatemp_max=dynatemp_max,
        dynatemp_exponent=kai_payload.dynatemp_exponent,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        mirostat_mode=kai_payload.mirostat,
        mirostat_tau=kai_payload.mirostat_tau,
        mirostat_eta=kai_payload.mirostat_eta,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens


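# Non-streaming KoboldAI generate: partial output is mirrored into gen_cache
# under the request's genkey so the /api/extra/generate/check endpoint can
# poll progress while generation is still running.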
@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine.generate(None, sampling_params,
                                       kai_payload.genkey, input_tokens)

    final_res: RequestOutput = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


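# Streaming variant: emit each newly generated chunk as a server-sent event
# in the format KoboldCpp clients expect.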
@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine.generate(None, sampling_params,
                                        kai_payload.genkey, input_tokens)

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")


@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await engine.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: Request):
    """Tokenize string and return token count"""
    request_dict = await request.json()
    tokenizer_result = await openai_serving_chat.tokenize(
        Prompt(**request_dict))
    return JSONResponse({"value": tokenizer_result["value"]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    # NOTE: relies on the module-level `args` set in the __main__ block.
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})


@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed."""
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)


# ============ KoboldAI API ============ #


def build_app(args):
    app = fastapi.FastAPI(lifespan=lifespan)
    app.include_router(router)
    # Add prometheus asgi middleware to route /metrics requests
    route = Mount("/metrics", make_asgi_app())
    route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(route)
    app.root_path = args.root_path
    if args.launch_kobold_api:
        logger.warning("Launching Kobold API server in addition to OpenAI. "
                       "Keep in mind that the Kobold API routes are NOT "
                       "protected via the API key.")
        app.include_router(kai_api, prefix="/api/v1")
        app.include_router(kai_api,
                           prefix="/api/latest",
                           include_in_schema=False)
        app.include_router(extra_api, prefix="/api/extra")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_completion.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")
        @app.middleware("http")
        async def authentication(request: Request, call_next):
            excluded_paths = ["/api"]
            if any(
                    request.url.path.startswith(path)
                    for path in excluded_paths):
                return await call_next(request)
            if not request.url.path.startswith("/v1"):
                return await call_next(request)
            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)
            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith("/v1/lora"):
                if admin_key is not None and api_key_header == admin_key:
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if auth_header != "Bearer " + token and api_key_header != token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            return await call_next(request)
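
    # Load user-supplied middleware given as dotted import paths: classes are
    # added as ASGI middleware, coroutine functions as HTTP middleware.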
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


def run_server(args):
    app = build_app(args)
    logger.debug(f"args: {args}")

    global engine, engine_args, openai_serving_chat, openai_serving_completion,\
        tokenizer, served_model_names, openai_serving_embedding

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)
    tokenizer = get_tokenizer(
        engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
        revision=engine_args.revision,
    )

    # Fall back to the tokenizer's built-in chat template when none is given
    # on the command line, and pass the resolved template through (previously
    # the fallback was computed but args.chat_template was passed instead).
    chat_template = args.chat_template
    if chat_template is None and tokenizer.chat_template is not None:
        chat_template = tokenizer.chat_template

    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model_names, args.lora_modules)
    openai_serving_embedding = OpenAIServingEmbedding(engine,
                                                      served_model_names)
    engine_model_config = asyncio.run(engine.get_model_config())

    if args.launch_kobold_api:
        _set_badwords(tokenizer, engine_model_config.hf_config)

    available_port = find_available_port(args.port)

    try:
        uvicorn.run(app,
                    host=args.host,
                    port=available_port,
                    log_level="info",
                    timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                    ssl_keyfile=args.ssl_keyfile,
                    ssl_certfile=args.ssl_certfile,
                    log_config=UVICORN_LOG_CONFIG)
    except KeyboardInterrupt:
        logger.info("API server stopped by user. Exiting.")
    except asyncio.exceptions.CancelledError:
        logger.info("API server stopped due to a cancelled request. Exiting.")


if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = make_arg_parser()
    args = parser.parse_args()
    run_server(args)
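
# Example launch, assuming this file lives at
# aphrodite/endpoints/openai/api_server.py as its imports suggest:
#   python -m aphrodite.endpoints.openai.api_server --model <model-name> --port 2242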