
api: add endpoint for loading and unloading the model (#926)

* api: add endpoint for loading and unloading the model

* fix admin key issue
AlpinDale, 2 months ago
parent
commit 53d0ba7c7c
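
A minimal usage sketch for the new endpoints (not part of the commit). The host, port, admin key, model name, and config keys below are illustrative assumptions; the multipart field name config_file and the x-api-key header match the code added in api_server.py.

import requests
import yaml

BASE_URL = "http://localhost:2242"            # assumed host/port
HEADERS = {"x-api-key": "example-admin-key"}  # admin key configured on the server (assumed value)

# Unload the currently loaded model.
print(requests.delete(f"{BASE_URL}/v1/model/unload", headers=HEADERS).json())

# Build an illustrative YAML config; keys mirror the engine CLI flags
# (underscores are converted to dashes server-side).
config_yaml = yaml.safe_dump({
    "model": "org/model-name",
    "tensor_parallel_size": 1,
})

# Load a new model from the uploaded config (multipart field "config_file").
print(requests.post(f"{BASE_URL}/v1/model/load",
                    headers=HEADERS,
                    files={"config_file": ("config.yaml", config_yaml)}).json())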

+ 221 - 4
aphrodite/endpoints/openai/api_server.py

@@ -10,9 +10,11 @@ from argparse import Namespace
 from contextlib import asynccontextmanager
 from distutils.util import strtobool
 from http import HTTPStatus
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple
+from typing import (Any, AsyncGenerator, AsyncIterator, Dict, List, Optional,
+                    Set, Tuple)
 
-from fastapi import APIRouter, FastAPI, Request
+import yaml
+from fastapi import APIRouter, FastAPI, Request, UploadFile
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import (HTMLResponse, JSONResponse, Response,
@@ -79,6 +81,7 @@ kobold_lite_ui = ""
 sampler_json = ""
 gen_cache: dict = {}
 prometheus_multiproc_dir: tempfile.TemporaryDirectory
+model_is_loaded = True
 
 _running_tasks: Set[asyncio.Task] = set()
 
@@ -225,6 +228,143 @@ def mount_metrics(app: FastAPI):
     app.routes.append(metrics_route)
 
 
+@router.delete("/v1/model/unload")
+async def unload_model(request: Request):
+    """Unload the current model and shut down the server."""
+    logger.info("Received request to unload model.")
+
+    try:
+        args = request.app.state.args
+        if not args.disable_frontend_multiprocessing:
+            await async_engine_client.kill()
+        else:
+            await async_engine_client.shutdown_background_loop()
+
+        global model_is_loaded
+        model_is_loaded = False
+        return JSONResponse(
+            content={
+                "status": "success",
+                "message": "Model unloaded successfully"
+            }
+        )
+
+    except Exception as e:
+        error_msg = f"Error while unloading model: {str(e)}"
+        logger.error(error_msg)
+        return JSONResponse(
+            content={"status": "error", "message": error_msg},
+            status_code=500
+        )
+
+
+@router.post("/v1/model/load")
+async def load_model(config_file: UploadFile):
+    """Load a model using a YAML configuration file."""
+    global model_is_loaded, async_engine_client, engine_args
+
+    if model_is_loaded:
+        return JSONResponse(
+            content={
+                "error": {
+                    "message": "A model is already loaded. "
+                    "Please unload it first.",
+                    "type": "invalid_request_error",
+                    "code": "model_already_loaded"
+                }
+            },
+            status_code=400
+        )
+
+    try:
+        # basically the same logic as the one in aphrodite.endpoints.cli
+        config_text = await config_file.read()
+        config: Dict[Any, Any] = yaml.safe_load(config_text)
+
+        args = []
+        for key, value in config.items():
+            key = key.replace('_', '-')
+
+            if isinstance(value, bool):
+                if value:
+                    args.append(f"--{key}")
+            elif isinstance(value, (list, tuple)):
+                if key in ['lora-modules', 'prompt-adapters']:
+                    for item in value:
+                        args.append(f"--{key}")
+                        args.append(f"{item['name']}={item['path']}")
+                else:
+                    for item in value:
+                        args.append(f"--{key}")
+                        args.append(str(item))
+            else:
+                args.append(f"--{key}")
+                args.append(str(value))
+
+        parser = FlexibleArgumentParser()
+        parser = make_arg_parser(parser)
+        parsed_args = parser.parse_args(args)
+
+        # Rebuild the (global) engine args from the freshly parsed config so
+        # the engine created below reflects the uploaded file rather than the
+        # arguments the server was originally started with.
+        engine_args = AsyncEngineArgs.from_cli_args(parsed_args)
+
+        if (model_is_embedding(parsed_args.model, parsed_args.trust_remote_code)
+                or parsed_args.disable_frontend_multiprocessing):
+            async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
+            await async_engine_client.setup()
+        else:
+            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+                global prometheus_multiproc_dir
+                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+                os.environ[
+                    "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+
+            rpc_path = get_open_zmq_ipc_path()
+            logger.info(
+                f"Multiprocessing frontend to use {rpc_path} for RPC Path.")
+
+            rpc_client = AsyncEngineRPCClient(rpc_path)
+            async_engine_client = rpc_client
+
+            context = multiprocessing.get_context("spawn")
+            rpc_server_process = context.Process(
+                target=run_rpc_server,
+                args=(engine_args, rpc_path))
+            rpc_server_process.start()
+            logger.info(
+                f"Started engine process with PID {rpc_server_process.pid}")
+
+            while True:
+                try:
+                    await async_engine_client.setup()
+                    break
+                except TimeoutError as e:
+                    if not rpc_server_process.is_alive():
+                        raise RuntimeError(
+                            "RPC Server died before responding to readiness "
+                            "probe") from e
+
+        app = await init_app(async_engine_client, parsed_args)  # noqa: F841
+        model_is_loaded = True
+
+        return JSONResponse(
+            content={
+                "status": "success",
+                "message": "Model loaded successfully"
+            }
+        )
+
+    except Exception as e:
+        error_msg = f"Error while loading model: {str(e)}"
+        logger.error(error_msg)
+        return JSONResponse(
+            content={
+                "error": {
+                    "message": error_msg,
+                    "type": "invalid_request_error",
+                    "code": "model_load_error"
+                }
+            },
+            status_code=500
+        )
+
 @router.get("/health")
 async def health() -> Response:
     """Health check."""
@@ -234,6 +374,14 @@ async def health() -> Response:
 
 @router.post("/v1/tokenize")
 async def tokenize(request: TokenizeRequest):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_tokenization.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -245,6 +393,14 @@ async def tokenize(request: TokenizeRequest):
 
 @router.post("/v1/detokenize")
 async def detokenize(request: DetokenizeRequest):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_tokenization.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -300,6 +456,14 @@ async def serviceinfo():
 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_chat.create_chat_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -315,6 +479,14 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -329,6 +501,14 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -340,6 +520,14 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
 
 @router.post("/v1/lora/load")
 async def load_lora(lora: LoRAModulePath):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.add_lora(lora)
     if engine_args.enable_lora is False:
         logger.error("LoRA is not enabled in the engine. "
@@ -350,12 +538,28 @@ async def load_lora(lora: LoRAModulePath):
 
 @router.delete("/v1/lora/unload")
 async def unload_lora(lora_name: str):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.remove_lora(lora_name)
     return JSONResponse(content={"status": "success"})
 
 
 @router.post("/v1/soft_prompt/load")
 async def load_soft_prompt(soft_prompt: PromptAdapterPath):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.add_prompt_adapter(soft_prompt)
     if engine_args.enable_prompt_adapter is False:
         logger.error("Prompt Adapter is not enabled in the engine. "
@@ -365,6 +569,14 @@ async def load_soft_prompt(soft_prompt: PromptAdapterPath):
 
 @router.delete("/v1/soft_prompt/unload")
 async def unload_soft_prompt(soft_prompt_name: str):
+    if not model_is_loaded:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "message": "No model loaded."
+            },
+            status_code=500
+        )
     openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
     return JSONResponse(content={"status": "success"})
 
@@ -611,6 +823,7 @@ def build_app(args: Namespace) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app.include_router(router)
     app.root_path = args.root_path
+    app.state.args = args
     if args.launch_kobold_api:
         logger.warning("Kobold API is now enabled by default. "
                        "This flag will be removed in the future.")
@@ -659,8 +872,12 @@ def build_app(args: Namespace) -> FastAPI:
             auth_header = request.headers.get("Authorization")
             api_key_header = request.headers.get("x-api-key")
 
-            if request.url.path.startswith(("/v1/lora", "/v1/soft_prompt")):
-                if admin_key is not None and api_key_header == admin_key:
+            if request.url.path.startswith(
+                ("/v1/lora", "/v1/soft_prompt", "/v1/model")):
+                if admin_key is not None and (
+                    api_key_header == admin_key or 
+                    auth_header == "Bearer " + admin_key
+                ):
                     return await call_next(request)
                 return JSONResponse(content={"error": "Unauthorized"},
                                     status_code=401)
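
For reference, this is how load_model() above flattens an uploaded YAML config into CLI-style arguments before handing them to make_arg_parser(); the key names are illustrative:

# Input config (illustrative):
#   model: org/model-name
#   enforce_eager: true            # booleans become bare flags
#   tensor_parallel_size: 2        # underscores become dashes
#   lora_modules:                  # special-cased: each entry becomes "name=path"
#     - name: my-lora
#       path: /path/to/lora
#
# Resulting argv:
#   ["--model", "org/model-name", "--enforce-eager",
#    "--tensor-parallel-size", "2",
#    "--lora-modules", "my-lora=/path/to/lora"]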

+ 1 - 1
aphrodite/endpoints/openai/rpc/__init__.py

@@ -38,7 +38,7 @@ class RPCUtilityRequest(Enum):
     GET_LORA_CONFIG = 6
     DO_LOG_STATS = 7
     IS_SERVER_HEALTHY = 8
-
+    SHUTDOWN_SERVER = 9
 
 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
                          RPCUtilityRequest]

+ 14 - 0
aphrodite/endpoints/openai/rpc/client.py

@@ -407,3 +407,17 @@ class AsyncEngineRPCClient:
                      **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         raise NotImplementedError(
             "Embeddings not supported with multiprocessing backend")
+
+    async def kill(self):
+        """Cleanly shut down the RPC client and engine."""
+        try:
+            # Send shutdown signal to RPC server
+            await self._send_one_way_rpc_request(
+                request=RPCUtilityRequest.SHUTDOWN_SERVER,
+                error_message="Failed to send shutdown signal to RPC server"
+            )
+        except Exception as e:
+            logger.error(f"Error while shutting down RPC server: {str(e)}")
+        finally:
+            # Close local resources
+            self.close()
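
A note on the design of kill(): the shutdown signal is sent as a one-way RPC, any failure is logged rather than re-raised, and close() runs in the finally block so the client's local resources are released even if the engine process is already gone.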

+ 25 - 0
aphrodite/endpoints/openai/rpc/server.py

@@ -1,4 +1,5 @@
 import asyncio
+import os
 import signal
 from typing import Any, Coroutine, Union
 
@@ -146,6 +147,8 @@ class AsyncEngineRPCServer:
                 return self.is_server_ready(identity)
             elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
                 return self.check_health(identity)
+            elif request == RPCUtilityRequest.SHUTDOWN_SERVER:
+                return self.shutdown(identity)
             else:
                 raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
 
@@ -171,6 +174,28 @@ class AsyncEngineRPCServer:
             running_tasks.add(task)
             task.add_done_callback(running_tasks.discard)
 
+    async def shutdown(self, identity):
+        """Handle shutdown request from client."""
+        try:
+            # Clean shutdown of engine
+            self.engine.shutdown_background_loop()
+            await self.socket.send_multipart(
+                [identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)]
+            )
+        except Exception as e:
+            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+        finally:
+            # Schedule server shutdown
+            asyncio.create_task(self._delayed_shutdown())
+
+    async def _delayed_shutdown(self):
+        """Helper to shut down server after response is sent"""
+        await asyncio.sleep(1)
+        self.cleanup()
+        # Force exit the process
+        os._exit(0)
+
 
 async def run_server(server: AsyncEngineRPCServer):
     # Put the server task into the asyncio loop.
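
On the server side, shutdown() stops the engine's background loop and sends the success (or error) reply before exiting, and _delayed_shutdown() waits one second so that reply can reach the client, runs cleanup(), and then calls os._exit(0) to terminate the process immediately instead of waiting for a normal interpreter shutdown.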

+ 1 - 0
requirements-common.txt

@@ -32,3 +32,4 @@ mistral_common >= 1.5.0
 protobuf
 pandas
 msgspec
+python-multipart
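
python-multipart is added because FastAPI depends on it to parse multipart/form-data requests, which the new /v1/model/load endpoint relies on for its UploadFile config_file parameter.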