api_server_kobold.py

# Adapted from openai/api_server.py and tgi-kai-bridge
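#
# Example request (illustrative values; the field names follow
# KAIGenerationInputSchema as consumed by prepare_engine_payload below):
#
#   curl http://localhost:2242/api/v1/generate \
#       -H "Content-Type: application/json" \
#       -d '{"prompt": "Hello", "max_length": 32, "max_context_length": 2048}'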

import argparse
import asyncio
import json
import os
from http import HTTPStatus
from typing import AsyncGenerator, List, Tuple

import uvicorn
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse

from aphrodite.common.logger import init_logger
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import SamplingParams, _SAMPLING_EPS
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.protocol import KAIGenerationInputSchema
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.transformers_utils.tokenizer import get_tokenizer

TIMEOUT_KEEP_ALIVE = 5  # seconds

logger = init_logger(__name__)

served_model: str = "Read Only"
engine: AsyncAphrodite = None
app = FastAPI()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    return JSONResponse({
        "msg": message,
        "type": "invalid_request_error"
    }, status_code=status_code.value)


@app.exception_handler(ValueError)
def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
    return create_error_response(HTTPStatus.UNPROCESSABLE_ENTITY, str(exc))
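

# prepare_engine_payload raises ValueError for invalid requests (e.g. a
# max_context_length above the model limit); the ValueError handler above
# surfaces those as HTTP 422 responses.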
def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for the AsyncEngine."""
    if kai_payload.max_context_length > max_model_len:
        raise ValueError(
            f"max_context_length ({kai_payload.max_context_length}) "
            f"must be less than or equal to max_model_len ({max_model_len})")

    # KAI spec: top_k == 0 means disabled; aphrodite: top_k == -1 means
    # disabled. https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings
    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)

    if kai_payload.temperature < _SAMPLING_EPS:
        # A temperature below _SAMPLING_EPS is treated as greedy sampling.
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        # ignore_eos=kai_payload.use_default_badwordsids,  # TODO: ban instead
        max_tokens=kai_payload.max_length,
    )
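
    # Truncate the prompt from the left so that the prompt tokens plus the
    # requested generation length fit within max_context_length.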
    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens


@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    """Generate text."""
    req_id = f"kai-{random_uuid()}"
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine.generate(None, sampling_params, req_id,
                                       input_tokens)

    final_res: RequestOutput = None
    async for res in result_generator:
        final_res = res
    assert final_res is not None

    return JSONResponse(
        {"results": [{"text": output.text} for output in final_res.outputs]})
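

# Stream results as server-sent events; each frame carries only the text
# generated since the previous frame:
#
#   event: message
#   data: {"token": "<new text>"}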
@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
    """Generate text with SSE streaming."""
    req_id = f"kai-{random_uuid()}"
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine.generate(None, sampling_params, req_id,
                                        input_tokens)

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive"
                             },
                             media_type="text/event-stream")


@extra_api.post("/generate/check")
async def check_generation():
    """Stub for compatibility."""
    return JSONResponse({"results": [{"text": ""}]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI."""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    """Get the current model."""
    return JSONResponse({"result": f"aphrodite/{served_model}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility."""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility."""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility."""
    return JSONResponse({})


@app.get("/api/latest/config/max_context_length")
async def get_max_context_length() -> JSONResponse:
    """Return the max context length based on the EngineArgs configuration."""
    return JSONResponse({"value": max_model_len})


@app.get("/api/latest/config/max_length")
async def get_max_length() -> JSONResponse:
    """Return the default maximum generation length (see --max-length)."""
    return JSONResponse({"value": args.max_length})


@extra_api.post("/abort")
async def abort_generation():
    """Stub for compatibility."""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp with streaming support."""
    return JSONResponse({"result": "KoboldCpp", "version": "1.30"})


@app.get("/")
async def get_kobold_lite_ui():
    """Serve a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed."""
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "klite.embd")
        if os.path.exists(klitepath):
            with open(klitepath, "r") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Embedded Kobold Lite not found")
    return HTMLResponse(content=kobold_lite_ui)


app.include_router(kai_api, prefix="/api/v1")
app.include_router(kai_api, prefix="/api/latest", include_in_schema=False)
app.include_router(extra_api, prefix="/api/extra")
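
# The KAI API is mounted at both /api/v1 and /api/latest; the KoboldCpp-style
# extensions (streaming, abort, version) live under /api/extra.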

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Aphrodite KoboldAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default="localhost",
                        help="host name")
    parser.add_argument("--port", type=int, default=2242, help="port number")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. If not "
                        "specified, the model name will be the same as "
                        "the huggingface name.")
    parser.add_argument("--max-length",
                        type=int,
                        default=256,
                        help="The maximum length of the generated text. "
                        "For use with Kobold Horde.")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.served_model_name is not None:
        served_model = args.served_model_name
    else:
        served_model = args.model

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)
    engine_model_config = asyncio.run(engine.get_model_config())
    max_model_len = engine_model_config.get_max_model_len()

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(engine_args.tokenizer,
                              tokenizer_mode=engine_args.tokenizer_mode,
                              trust_remote_code=engine_args.trust_remote_code)

    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
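
# Example launch (the model name is illustrative; --model and the other engine
# flags come from AsyncEngineArgs.add_cli_args):
#
#   python api_server_kobold.py --model EleutherAI/pythia-70m --port 2242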