# Adapted from openai/api_server.py and tgi-kai-bridge

import argparse
import asyncio
import json
import os
from http import HTTPStatus
from typing import AsyncGenerator, List, Tuple

from prometheus_client import make_asgi_app
import uvicorn
import fastapi
from fastapi import APIRouter, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger

from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import SamplingParams, _SAMPLING_EPS
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.kobold.protocol import KAIGenerationInputSchema

TIMEOUT_KEEP_ALIVE = 5  # seconds

served_model: str = "Read Only"
engine: AsyncAphrodite = None  # initialized in __main__
gen_cache: dict = {}  # maps genkey -> partial output for /generate/check
app = fastapi.FastAPI()
badwordsids: List[int] = []

# Add prometheus asgi middleware to route /metrics/ requests
metrics_app = make_asgi_app()
app.mount("/metrics/", metrics_app)


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    """Populate the default bad-words list: use the model's bad_words_ids
    if configured, otherwise ban every vocab token containing a square
    bracket (plus the EOS token)."""
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    return JSONResponse({
        "msg": message,
        "type": "invalid_request_error"
    },
                        status_code=status_code.value)


@app.exception_handler(ValueError)
def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
    return create_error_response(HTTPStatus.UNPROCESSABLE_ENTITY, str(exc))


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""
    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    if kai_payload.max_context_length > max_model_len:
        raise ValueError(
            f"max_context_length ({kai_payload.max_context_length}) "
            "must be less than or equal to "
            f"max_model_len ({max_model_len})")

    # KAI spec: top_k == 0 means disabled; aphrodite: top_k == -1 means
    # disabled. https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings
    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        # temp < _SAMPLING_EPS: greedy sampling
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    # dynatemp_range is optional; fall back to a fixed temperature when it
    # is unset or non-positive.
    dynatemp_min = 0.0
    dynatemp_max = 0.0
    if kai_payload.dynatemp_range is not None and kai_payload.dynatemp_range > 0:
        dynatemp_min = kai_payload.temperature - kai_payload.dynatemp_range
        dynatemp_max = kai_payload.temperature + kai_payload.dynatemp_range

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        dynatemp_min=dynatemp_min,
        dynatemp_max=dynatemp_max,
        dynatemp_exponent=kai_payload.dynatemp_exponent,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        mirostat_mode=kai_payload.mirostat,
        mirostat_tau=kai_payload.mirostat_tau,
        mirostat_eta=kai_payload.mirostat_eta,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
    )

    # Truncate the prompt from the left so prompt + generation fit within
    # max_context_length.
    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens


@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    """Generate text"""
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine.generate(None, sampling_params,
                                       kai_payload.genkey, input_tokens)

    final_res: RequestOutput = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        # Cache partial output so /generate/check can poll progress.
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
    """Generate text SSE streaming"""
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine.generate(None, sampling_params,
                                        kai_payload.genkey, input_tokens)

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive"
                             },
                             media_type="text/event-stream")


@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    """Check outputs in progress (poll streaming)"""
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    """Abort running generation"""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await engine.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: Request):
    """Tokenize string and return token count"""
    request_dict = await request.json()
    tokenizer_result = tokenizer(request_dict["prompt"])
    return JSONResponse({"value": len(tokenizer_result.input_ids)})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    """Get current model"""
    return JSONResponse({"result": f"aphrodite/{served_model}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    """Return the configured max output length"""
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    """Return the max context length based on the EngineArgs configuration."""
    max_context_length = engine_model_config.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.55.1"})


@app.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed."""
    # read and return embedded kobold lite
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "klite.embd")
        if os.path.exists(klitepath):
            with open(klitepath, "r") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Embedded Kobold Lite not found")
    return HTMLResponse(content=kobold_lite_ui)


@app.get("/health")
async def health() -> Response:
    """Health check route for K8s"""
    return Response(status_code=200)


app.include_router(kai_api, prefix="/api/v1")
app.include_router(kai_api, prefix="/api/latest", include_in_schema=False)
app.include_router(extra_api, prefix="/api/extra")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Aphrodite KoboldAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default="localhost",
                        help="host name")
    parser.add_argument("--port", type=int, default=2242, help="port number")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. If not "
                        "specified, the model name will be the same as "
                        "the huggingface name.")
    parser.add_argument("--max-length",
                        type=int,
                        default=256,
                        help="The maximum length of the generated text. "
                        "For use with Kobold Horde.")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    logger.debug(f"args: {args}")
    logger.warning("The standalone Kobold API is deprecated and will not "
                   "receive updates. Please use the OpenAI API with the "
                   "--launch-kobold-api flag instead.")

    if args.served_model_name is not None:
        served_model = args.served_model_name
    else:
        served_model = args.model

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)
    engine_model_config = asyncio.run(engine.get_model_config())
    max_model_len = engine_model_config.max_model_len

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(engine_args.tokenizer,
                              tokenizer_mode=engine_args.tokenizer_mode,
                              trust_remote_code=engine_args.trust_remote_code)
    _set_badwords(tokenizer, engine_model_config.hf_config)

    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
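
# ---------------------------------------------------------------------------
# Illustrative client usage (a minimal sketch, not part of the server; it
# assumes a server running with the defaults above, i.e. localhost:2242,
# and uses the third-party `requests` package):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:2242/api/v1/generate",
#       json={
#           "prompt": "Once upon a time",
#           "max_length": 80,            # tokens to generate
#           "max_context_length": 2048,  # prompt + generation budget
#           "temperature": 0.7,
#       },
#   )
#   print(resp.json()["results"][0]["text"])
#
# The streaming endpoint at /api/extra/generate/stream emits SSE frames of
# the form:
#
#   event: message
#   data: {"token": "<newly generated text chunk>"}
# ---------------------------------------------------------------------------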