api: add endpoint for loading and unloading the model (#926)

* api: add endpoint for loading and unloading the model

* fix admin key issue
AlpinDale, 2 months ago
parent commit 53d0ba7c7c

+ 221 - 4
aphrodite/endpoints/openai/api_server.py

@@ -10,9 +10,11 @@ from argparse import Namespace
 from contextlib import asynccontextmanager
 from distutils.util import strtobool
 from http import HTTPStatus
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple
+from typing import (Any, AsyncGenerator, AsyncIterator, Dict, List, Optional,
+                    Set, Tuple)
 
-from fastapi import APIRouter, FastAPI, Request
+import yaml
+from fastapi import APIRouter, FastAPI, Request, UploadFile
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import (HTMLResponse, JSONResponse, Response,
@@ -79,6 +81,7 @@ kobold_lite_ui = ""
 sampler_json = ""
 sampler_json = ""
 gen_cache: dict = {}
 gen_cache: dict = {}
 prometheus_multiproc_dir: tempfile.TemporaryDirectory
 prometheus_multiproc_dir: tempfile.TemporaryDirectory
+model_is_loaded = True
 
 _running_tasks: Set[asyncio.Task] = set()
 
@@ -225,6 +228,143 @@ def mount_metrics(app: FastAPI):
     app.routes.append(metrics_route)
 
 
+@router.delete("/v1/model/unload")
+async def unload_model(request: Request):
+    """Unload the current model and shut down the server."""
+    logger.info("Received request to unload model.")
+
+    try:
+        args = request.app.state.args
+        if not args.disable_frontend_multiprocessing:
+            await async_engine_client.kill()
+        else:
+            await async_engine_client.shutdown_background_loop()
+
+        global model_is_loaded
+        model_is_loaded = False
+        return JSONResponse(
+            content={
+                "status": "success",
+                "message": "Model unloaded successfully"
+            }
+        )
+
+    except Exception as e:
+        error_msg = f"Error while unloading model: {str(e)}"
+        logger.error(error_msg)
+        return JSONResponse(
+            content={"status": "error", "message": error_msg},
+            status_code=500
+        )
+
+
+@router.post("/v1/model/load")
+async def load_model(config_file: UploadFile):
+    """Load a model using a YAML configuration file."""
+    global model_is_loaded, async_engine_client, engine_args
+
+    if model_is_loaded:
+        return JSONResponse(
+            content={
+                "error": {
+                    "message": "A model is already loaded. "
+                    "Please unload it first.",
+                    "type": "invalid_request_error",
+                    "code": "model_already_loaded"
+                }
+            },
+            status_code=400
+        )
+
+    try:
+        # basically the same logic as the one in aphrodite.endpoints.cli
+        config_text = await config_file.read()
+        config: Dict[Any, Any] = yaml.safe_load(config_text)
+
+        args = []
+        for key, value in config.items():
+            key = key.replace('_', '-')
+
+            if isinstance(value, bool):
+                if value:
+                    args.append(f"--{key}")
+            elif isinstance(value, (list, tuple)):
+                if key in ['lora-modules', 'prompt-adapters']:
+                    for item in value:
+                        args.append(f"--{key}")
+                        args.append(f"{item['name']}={item['path']}")
+                else:
+                    for item in value:
+                        args.append(f"--{key}")
+                        args.append(str(item))
+            else:
+                args.append(f"--{key}")
+                args.append(str(value))
+
+        parser = FlexibleArgumentParser()
+        parser = make_arg_parser(parser)
+        parsed_args = parser.parse_args(args)
+
+        if (model_is_embedding(parsed_args.model, parsed_args.trust_remote_code)
+                or parsed_args.disable_frontend_multiprocessing):
+            async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
+            await async_engine_client.setup()
+        else:
+            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+                global prometheus_multiproc_dir
+                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+                os.environ[
+                    "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+
+            rpc_path = get_open_zmq_ipc_path()
+            logger.info(
+                f"Multiprocessing frontend to use {rpc_path} for RPC Path.")
+
+            rpc_client = AsyncEngineRPCClient(rpc_path)
+            async_engine_client = rpc_client
+
+            context = multiprocessing.get_context("spawn")
+            rpc_server_process = context.Process(
+                target=run_rpc_server,
+                args=(engine_args, rpc_path))
+            rpc_server_process.start()
+            logger.info(
+                f"Started engine process with PID {rpc_server_process.pid}")
+
+            while True:
+                try:
+                    await async_engine_client.setup()
+                    break
+                except TimeoutError as e:
+                    if not rpc_server_process.is_alive():
+                        raise RuntimeError(
+                            "RPC Server died before responding to readiness "
+                            "probe") from e
+
+        app = await init_app(async_engine_client, parsed_args)  # noqa: F841
+        model_is_loaded = True
+
+        return JSONResponse(
+            content={
+                "status": "success",
+                "message": "Model loaded successfully"
+            }
+        )
+
+    except Exception as e:
+        error_msg = f"Error while loading model: {str(e)}"
+        logger.error(error_msg)
+        return JSONResponse(
+            content={
+                "error": {
+                    "message": error_msg,
+                    "type": "invalid_request_error",
+                    "code": "model_load_error"
+                }
+            },
+            status_code=500
+        )
+
 @router.get("/health")
 @router.get("/health")
 async def health() -> Response:
 async def health() -> Response:
     """Health check."""
     """Health check."""
@@ -234,6 +374,14 @@ async def health() -> Response:
 
 @router.post("/v1/tokenize")
 async def tokenize(request: TokenizeRequest):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_tokenization.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -245,6 +393,14 @@ async def tokenize(request: TokenizeRequest):
 
 @router.post("/v1/detokenize")
 async def detokenize(request: DetokenizeRequest):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_tokenization.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -300,6 +456,14 @@ async def serviceinfo():
 @router.post("/v1/chat/completions")
 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
                                  raw_request: Request):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_chat.create_chat_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -315,6 +479,14 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -329,6 +501,14 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -340,6 +520,14 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
 
 @router.post("/v1/lora/load")
 async def load_lora(lora: LoRAModulePath):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.add_lora(lora)
     if engine_args.enable_lora is False:
         logger.error("LoRA is not enabled in the engine. "
@@ -350,12 +538,28 @@ async def load_lora(lora: LoRAModulePath):
 
 @router.delete("/v1/lora/unload")
 async def unload_lora(lora_name: str):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.remove_lora(lora_name)
     return JSONResponse(content={"status": "success"})
 
 
 @router.post("/v1/soft_prompt/load")
 async def load_soft_prompt(soft_prompt: PromptAdapterPath):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.add_prompt_adapter(soft_prompt)
     if engine_args.enable_prompt_adapter is False:
         logger.error("Prompt Adapter is not enabled in the engine. "
@@ -365,6 +569,14 @@ async def load_soft_prompt(soft_prompt: PromptAdapterPath):
 
 @router.delete("/v1/soft_prompt/unload")
 async def unload_soft_prompt(soft_prompt_name: str):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
     return JSONResponse(content={"status": "success"})
 
@@ -611,6 +823,7 @@ def build_app(args: Namespace) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app.include_router(router)
     app.root_path = args.root_path
+    app.state.args = args
     if args.launch_kobold_api:
         logger.warning("Kobold API is now enabled by default. "
                        "This flag will be removed in the future.")
@@ -659,8 +872,12 @@ def build_app(args: Namespace) -> FastAPI:
             auth_header = request.headers.get("Authorization")
             auth_header = request.headers.get("Authorization")
             api_key_header = request.headers.get("x-api-key")
             api_key_header = request.headers.get("x-api-key")
 
 
-            if request.url.path.startswith(("/v1/lora", "/v1/soft_prompt")):
-                if admin_key is not None and api_key_header == admin_key:
+            if request.url.path.startswith(
+                ("/v1/lora", "/v1/soft_prompt", "/v1/model")):
+                if admin_key is not None and (
+                    api_key_header == admin_key or 
+                    auth_header == "Bearer " + admin_key
+                ):
                     return await call_next(request)
                 return JSONResponse(content={"error": "Unauthorized"},
                                     status_code=401)
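
Usage sketch for the two new endpoints (not part of the commit): it assumes a server reachable at localhost:2242 that was started with an admin key, and the model/config values below are purely illustrative. Either admin header form shown is accepted by the updated middleware.

```python
import requests

ADMIN_KEY = "example-admin-key"      # assumed: whatever admin key the server was started with
BASE = "http://localhost:2242"       # assumed host/port

# The middleware accepts the admin key via x-api-key or an Authorization bearer header.
headers = {"x-api-key": ADMIN_KEY}   # or {"Authorization": f"Bearer {ADMIN_KEY}"}

# Unload the currently loaded model (this also shuts the engine down).
print(requests.delete(f"{BASE}/v1/model/unload", headers=headers).json())

# Load a model from a YAML config whose keys mirror the CLI flags: underscores
# become dashes, booleans become bare flags, and lists repeat the flag per item
# (see the parsing loop in load_model above). Values here are examples only.
config_yaml = """\
model: mistralai/Mistral-7B-Instruct-v0.2
tensor_parallel_size: 1
gpu_memory_utilization: 0.9
"""
resp = requests.post(
    f"{BASE}/v1/model/load",
    headers=headers,
    files={"config_file": ("config.yaml", config_yaml, "application/x-yaml")},
)
print(resp.json())
```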

+ 1 - 1
aphrodite/endpoints/openai/rpc/__init__.py

@@ -38,7 +38,7 @@ class RPCUtilityRequest(Enum):
     GET_LORA_CONFIG = 6
     DO_LOG_STATS = 7
     IS_SERVER_HEALTHY = 8
-
+    SHUTDOWN_SERVER = 9
 
 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
                          RPCUtilityRequest]
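
As a small illustration (not from the commit), the new member is the value the client sends and the server matches on in the two files that follow; the import path mirrors this package layout:

```python
from aphrodite.endpoints.openai.rpc import RPCUtilityRequest

# The enum member itself is the utility-request message exchanged over RPC.
request = RPCUtilityRequest.SHUTDOWN_SERVER
assert request.value == 9
```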

+ 14 - 0
aphrodite/endpoints/openai/rpc/client.py

@@ -407,3 +407,17 @@ class AsyncEngineRPCClient:
                      **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         raise NotImplementedError(
             "Embeddings not supported with multiprocessing backend")
+
+    async def kill(self):
+        """Cleanly shut down the RPC client and engine."""
+        try:
+            # Send shutdown signal to RPC server
+            await self._send_one_way_rpc_request(
+                request=RPCUtilityRequest.SHUTDOWN_SERVER,
+                error_message="Failed to send shutdown signal to RPC server"
+            )
+        except Exception as e:
+            logger.error(f"Error while shutting down RPC server: {str(e)}")
+        finally:
+            # Close local resources
+            self.close()
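
Design note: kill() reuses the existing _send_one_way_rpc_request helper to deliver the new SHUTDOWN_SERVER request, logs rather than raises if that fails, and always calls close() in the finally block so the client's local resources are released even when the engine process has already exited.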

+ 25 - 0
aphrodite/endpoints/openai/rpc/server.py

@@ -1,4 +1,5 @@
 import asyncio
+import os
 import signal
 from typing import Any, Coroutine, Union
 
@@ -146,6 +147,8 @@ class AsyncEngineRPCServer:
                 return self.is_server_ready(identity)
             elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
                 return self.check_health(identity)
+            elif request == RPCUtilityRequest.SHUTDOWN_SERVER:
+                return self.shutdown(identity)
             else:
                 raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
 
@@ -171,6 +174,28 @@ class AsyncEngineRPCServer:
             running_tasks.add(task)
             task.add_done_callback(running_tasks.discard)
 
+    async def shutdown(self, identity):
+        """Handle shutdown request from client."""
+        try:
+            # Clean shutdown of engine
+            self.engine.shutdown_background_loop()
+            await self.socket.send_multipart(
+                [identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)]
+            )
+        except Exception as e:
+            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+        finally:
+            # Schedule server shutdown
+            asyncio.create_task(self._delayed_shutdown())
+    
+    async def _delayed_shutdown(self):
+        """Helper to shut down server after response is sent"""
+        await asyncio.sleep(1)
+        self.cleanup()
+        # Force exit the process
+        os._exit(0)
+
+
 
 async def run_server(server: AsyncEngineRPCServer):
     # Put the server task into the asyncio loop.
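
The shutdown handler above acknowledges the request before tearing the process down. A standalone sketch of that "acknowledge first, exit later" pattern (plain asyncio, not Aphrodite code; the one-second grace period mirrors _delayed_shutdown):

```python
import asyncio
import os

async def handle_shutdown(send_ack) -> None:
    """Reply to the caller first, then schedule the process exit."""
    await send_ack()                      # let the success reply go out
    asyncio.create_task(delayed_exit())   # don't block the handler on the exit

async def delayed_exit() -> None:
    await asyncio.sleep(1)                # grace period for the reply to flush
    os._exit(0)                           # hard exit, skipping normal interpreter teardown
```

In the commit itself the task is created in the finally block, so the engine process exits whether or not the acknowledgement was delivered.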

+ 1 - 0
requirements-common.txt

@@ -32,3 +32,4 @@ mistral_common >= 1.5.0
 protobuf
 pandas
 msgspec
+python-multipart
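
The python-multipart dependency supports the new /v1/model/load endpoint: FastAPI needs it to parse the multipart/form-data body behind the UploadFile parameter.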