@@ -10,9 +10,11 @@ from argparse import Namespace
 from contextlib import asynccontextmanager
 from distutils.util import strtobool
 from http import HTTPStatus
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple
+from typing import (Any, AsyncGenerator, AsyncIterator, Dict, List, Optional,
+                    Set, Tuple)
 
-from fastapi import APIRouter, FastAPI, Request
+import yaml
+from fastapi import APIRouter, FastAPI, Request, UploadFile
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import (HTMLResponse, JSONResponse, Response,
@@ -79,6 +81,7 @@ kobold_lite_ui = ""
 sampler_json = ""
 gen_cache: dict = {}
 prometheus_multiproc_dir: tempfile.TemporaryDirectory
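+# Toggled by the /v1/model/load and /v1/model/unload endpoints below.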
+model_is_loaded = True
 
 _running_tasks: Set[asyncio.Task] = set()
 
@@ -225,6 +228,143 @@ def mount_metrics(app: FastAPI):
     app.routes.append(metrics_route)
 
 
+@router.delete("/v1/model/unload")
+async def unload_model(request: Request):
+    """Unload the current model and shut down the server."""
+    logger.info("Received request to unload model.")
+
+    try:
+        args = request.app.state.args
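+        # With the multiprocessing frontend the engine runs in a separate
+        # process and has to be killed; otherwise only the in-process
+        # background loop needs to be stopped.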
+        if not args.disable_frontend_multiprocessing:
+            await async_engine_client.kill()
+        else:
+            await async_engine_client.shutdown_background_loop()
+
+        global model_is_loaded
+        model_is_loaded = False
+        return JSONResponse(
+            content={
+                "status": "success",
+                "message": "Model unloaded successfully"
+            }
+        )
+
+    except Exception as e:
+        error_msg = f"Error while unloading model: {str(e)}"
+        logger.error(error_msg)
+        return JSONResponse(
+            content={"status": "error", "message": error_msg},
+            status_code=500
+        )
+
+
+@router.post("/v1/model/load")
+async def load_model(config_file: UploadFile):
+    """Load a model using a YAML configuration file."""
+    global model_is_loaded, async_engine_client, engine_args
+
+    if model_is_loaded:
+        return JSONResponse(
+            content={
+                "error": {
+                    "message": "A model is already loaded. "
+                               "Please unload it first.",
+                    "type": "invalid_request_error",
+                    "code": "model_already_loaded"
+                }
+            },
+            status_code=400
+        )
+
+    try:
+        # basically the same logic as the one in aphrodite.endpoints.cli
+        config_text = await config_file.read()
+        config: Dict[Any, Any] = yaml.safe_load(config_text)
+
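+        # Turn the YAML keys into CLI flags. For example (illustrative
+        # values only), a config of
+        #   model: mistralai/Mistral-7B-Instruct-v0.2
+        #   tensor_parallel_size: 2
+        #   enforce_eager: true
+        # becomes ["--model", "mistralai/Mistral-7B-Instruct-v0.2",
+        #          "--tensor-parallel-size", "2", "--enforce-eager"].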
+        args = []
+        for key, value in config.items():
+            key = key.replace('_', '-')
+
+            if isinstance(value, bool):
+                if value:
+                    args.append(f"--{key}")
+            elif isinstance(value, (list, tuple)):
+                if key in ['lora-modules', 'prompt-adapters']:
+                    for item in value:
+                        args.append(f"--{key}")
+                        args.append(f"{item['name']}={item['path']}")
+                else:
+                    for item in value:
+                        args.append(f"--{key}")
+                        args.append(str(item))
+            else:
+                args.append(f"--{key}")
+                args.append(str(value))
+
+        parser = FlexibleArgumentParser()
+        parser = make_arg_parser(parser)
+        parsed_args = parser.parse_args(args)
+        # Rebuild the engine args from the freshly parsed config so the new
+        # model, rather than the previously loaded one, is instantiated.
+        engine_args = AsyncEngineArgs.from_cli_args(parsed_args)
+
+        if (model_is_embedding(parsed_args.model, parsed_args.trust_remote_code)
+                or parsed_args.disable_frontend_multiprocessing):
+            async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
+            await async_engine_client.setup()
+        else:
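+            # prometheus_client needs a shared multiprocess directory to
+            # aggregate metrics from the frontend and the engine process.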
+            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+                global prometheus_multiproc_dir
+                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+                os.environ[
+                    "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+
+            rpc_path = get_open_zmq_ipc_path()
+            logger.info(
+                f"Multiprocessing frontend to use {rpc_path} for RPC Path.")
+
+            rpc_client = AsyncEngineRPCClient(rpc_path)
+            async_engine_client = rpc_client
+
+            context = multiprocessing.get_context("spawn")
+            rpc_server_process = context.Process(
+                target=run_rpc_server,
+                args=(engine_args, rpc_path))
+            rpc_server_process.start()
+            logger.info(
+                f"Started engine process with PID {rpc_server_process.pid}")
+
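+            # Poll the engine until it answers the setup handshake, bailing
+            # out if the engine process died before becoming ready.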
+            while True:
+                try:
+                    await async_engine_client.setup()
+                    break
+                except TimeoutError as e:
+                    if not rpc_server_process.is_alive():
+                        raise RuntimeError(
+                            "RPC Server died before responding to readiness "
+                            "probe") from e
+
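+        # init_app is called for its side effects: it rebuilds the serving
+        # state around the new engine client; the returned app is unused.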
+        app = await init_app(async_engine_client, parsed_args)  # noqa: F841
+        model_is_loaded = True
+
+        return JSONResponse(
+            content={
+                "status": "success",
+                "message": "Model loaded successfully"
+            }
+        )
+
+    except Exception as e:
+        error_msg = f"Error while loading model: {str(e)}"
+        logger.error(error_msg)
+        return JSONResponse(
+            content={
+                "error": {
+                    "message": error_msg,
+                    "type": "invalid_request_error",
+                    "code": "model_load_error"
+                }
+            },
+            status_code=500
+        )
+
 @router.get("/health")
 async def health() -> Response:
     """Health check."""
@@ -234,6 +374,14 @@ async def health() -> Response:
 
 @router.post("/v1/tokenize")
 async def tokenize(request: TokenizeRequest):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_tokenization.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -245,6 +393,14 @@ async def tokenize(request: TokenizeRequest):
 
 @router.post("/v1/detokenize")
 async def detokenize(request: DetokenizeRequest):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_tokenization.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -300,6 +456,14 @@ async def serviceinfo():
 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_chat.create_chat_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -315,6 +479,14 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -329,6 +501,14 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -340,6 +520,14 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
 
 @router.post("/v1/lora/load")
 async def load_lora(lora: LoRAModulePath):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.add_lora(lora)
     if engine_args.enable_lora is False:
         logger.error("LoRA is not enabled in the engine. "
@@ -350,12 +538,28 @@ async def load_lora(lora: LoRAModulePath):
 
 @router.delete("/v1/lora/unload")
 async def unload_lora(lora_name: str):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.remove_lora(lora_name)
     return JSONResponse(content={"status": "success"})
 
 
 @router.post("/v1/soft_prompt/load")
 async def load_soft_prompt(soft_prompt: PromptAdapterPath):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.add_prompt_adapter(soft_prompt)
     if engine_args.enable_prompt_adapter is False:
         logger.error("Prompt Adapter is not enabled in the engine. "
@@ -365,6 +569,14 @@ async def load_soft_prompt(soft_prompt: PromptAdapterPath):
 
 @router.delete("/v1/soft_prompt/unload")
 async def unload_soft_prompt(soft_prompt_name: str):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
     return JSONResponse(content={"status": "success"})
 
@@ -611,6 +823,7 @@ def build_app(args: Namespace) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app.include_router(router)
     app.root_path = args.root_path
+    app.state.args = args
     if args.launch_kobold_api:
         logger.warning("Kobold API is now enabled by default. "
                        "This flag will be removed in the future.")
@@ -659,8 +872,12 @@ def build_app(args: Namespace) -> FastAPI:
         auth_header = request.headers.get("Authorization")
         api_key_header = request.headers.get("x-api-key")
 
-        if request.url.path.startswith(("/v1/lora", "/v1/soft_prompt")):
-            if admin_key is not None and api_key_header == admin_key:
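+        # Admin routes (LoRA, soft prompts, model load/unload) require the
+        # admin key, passed via x-api-key or an Authorization: Bearer header.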
+        if request.url.path.startswith(
+                ("/v1/lora", "/v1/soft_prompt", "/v1/model")):
+            if admin_key is not None and (
+                    api_key_header == admin_key or
+                    auth_header == "Bearer " + admin_key
+            ):
                 return await call_next(request)
             return JSONResponse(content={"error": "Unauthorized"},
                                 status_code=401)