# Adapted from openai/api_server.py and tgi-kai-bridge

import argparse
import asyncio
import json
import os
from http import HTTPStatus
from typing import AsyncGenerator, List, Tuple

import uvicorn
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse

from aphrodite.common.logger import init_logger
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.protocol import KAIGenerationInputSchema
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.transformers_utils.tokenizer import get_tokenizer

TIMEOUT_KEEP_ALIVE = 5  # seconds

logger = init_logger(__name__)
served_model: str = "Read Only"
engine: AsyncAphrodite = None
app = FastAPI()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
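
# Allow all cross-origin requests so browser-based frontends (including the
# bundled Kobold Lite UI served at "/") can call the API directly.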
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    return JSONResponse({"msg": message, "type": "invalid_request_error"},
                        status_code=status_code.value)


@app.exception_handler(ValueError)
def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
    return create_error_response(HTTPStatus.UNPROCESSABLE_ENTITY, str(exc))


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine."""
    if kai_payload.max_context_length > max_model_len:
        raise ValueError(
            f"max_context_length ({kai_payload.max_context_length}) must be "
            f"less than or equal to max_model_len ({max_model_len})")

    # KAI spec: top_k == 0 means disabled; aphrodite: top_k == -1 means
    # disabled. https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings
    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        # temperature < _SAMPLING_EPS: fall back to greedy sampling
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        # ignore_eos=kai_payload.use_default_badwordsids,  # TODO: ban instead
        max_tokens=kai_payload.max_length,
    )
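
    # Leave room for the requested completion: keep only the last
    # (max_context_length - max_length) prompt tokens, truncating from the
    # left so the most recent context is preserved.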
    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens


@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    """Generate text."""
    req_id = f"kai-{random_uuid()}"
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine.generate(None, sampling_params, req_id,
                                       input_tokens)

    final_res: RequestOutput = None
    async for res in result_generator:
        final_res = res
    assert final_res is not None

    return JSONResponse(
        {"results": [{"text": output.text} for output in final_res.outputs]})
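
# The non-streaming response above follows the KAI result shape, one entry
# per sampled sequence: {"results": [{"text": "..."}, ...]}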


@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
    """Generate text, streaming the results via SSE."""
    req_id = f"kai-{random_uuid()}"
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine.generate(None, sampling_params, req_id,
                                        input_tokens)

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        # Emit only the text generated since the previous iteration.
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive"
                             },
                             media_type="text/event-stream")
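
# Each incremental chunk is framed as one SSE message on the wire:
#
#   event: message
#   data: {"token": "<text generated since the previous event>"}
#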


@extra_api.post("/generate/check")
async def check_generation():
    """Stub for compatibility."""
    return JSONResponse({"results": [{"text": ""}]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI."""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    """Get the current model."""
    return JSONResponse({"result": f"aphrodite/{served_model}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility."""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility."""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility."""
    return JSONResponse({})


@app.get("/api/latest/config/max_context_length")
async def get_max_context_length() -> JSONResponse:
    """Return the max context length based on the EngineArgs configuration."""
    max_context_length = engine_model_config.max_model_len
    return JSONResponse({"value": max_context_length})


@app.get("/api/latest/config/max_length")
async def get_max_length() -> JSONResponse:
    """Return the default max generation length (the --max-length flag)."""
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@extra_api.post("/abort")
async def abort_generation():
    """Stub for compatibility."""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp with streaming support."""
    return JSONResponse({"result": "KoboldCpp", "version": "1.30"})


@app.get("/")
async def get_kobold_lite_ui():
    """Serve a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed."""
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "klite.embd")
        if os.path.exists(klitepath):
            with open(klitepath, "r") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Embedded Kobold Lite not found")
    return HTMLResponse(content=kobold_lite_ui)


app.include_router(kai_api, prefix="/api/v1")
app.include_router(kai_api, prefix="/api/latest", include_in_schema=False)
app.include_router(extra_api, prefix="/api/extra")
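
# Resulting route layout: the KAI endpoints are mounted under both /api/v1
# and /api/latest (e.g. /api/v1/generate), and the KoboldCpp extensions
# under /api/extra (e.g. /api/extra/generate/stream).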


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Aphrodite KoboldAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default="localhost",
                        help="host name")
    parser.add_argument("--port", type=int, default=2242, help="port number")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. If not "
                        "specified, the model name will be the same as "
                        "the huggingface name.")
    parser.add_argument("--max-length",
                        type=int,
                        default=256,
                        help="The maximum length of the generated text. "
                        "For use with Kobold Horde.")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.served_model_name is not None:
        served_model = args.served_model_name
    else:
        served_model = args.model

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)
    engine_model_config = asyncio.run(engine.get_model_config())
    max_model_len = engine_model_config.get_max_model_len()

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(engine_args.tokenizer,
                              tokenizer_mode=engine_args.tokenizer_mode,
                              trust_remote_code=engine_args.trust_remote_code)

    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
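
    # Once the server is up, a minimal smoke test (default host/port shown;
    # KAIGenerationInputSchema may require additional fields):
    #
    #   curl http://localhost:2242/api/v1/generate \
    #        -H "Content-Type: application/json" \
    #        -d '{"prompt": "Hello", "max_length": 32,
    #             "max_context_length": 1024}'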