api: add inline model loading (#928)

* api: add inline model loading

* add `--allow-inline-model-loading` flag

* guard with admin key
AlpinDale · 2 months ago · commit d46e70ac98
3 changed files with 158 additions and 20 deletions:
  1. aphrodite/common/envs.py (+5 -0)
  2. aphrodite/endpoints/openai/api_server.py (+148 -20)
  3. aphrodite/endpoints/openai/args.py (+5 -0)
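
Taken together, the three changes let a client switch the served model just by naming a different model in an ordinary completion request: if `--allow-inline-model-loading` is set and the request carries the admin key, the server tears down the running engine and starts a new one for the requested model before answering. A minimal client-side sketch of that flow; the endpoint and headers come from this diff, while the base URL, key, and model name are placeholders:

import requests

BASE_URL = "http://localhost:2242"   # placeholder host/port
ADMIN_KEY = "my-admin-key"           # placeholder APHRODITE_ADMIN_KEY value

# Naming a model other than the one currently served triggers the inline
# switch, provided the server was started with --allow-inline-model-loading
# and this request authenticates with the admin key.
resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    headers={"Authorization": f"Bearer {ADMIN_KEY}"},
    json={
        "model": "placeholder-org/other-model",   # not currently loaded
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(resp.status_code, resp.json())

Note that the switch is global: once it completes, `served_model_names` is replaced, so subsequent requests from every client are served by the new model.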

aphrodite/common/envs.py (+5 -0)

@@ -16,6 +16,7 @@ if TYPE_CHECKING:
     CUDA_VISIBLE_DEVICES: Optional[str] = None
     APHRODITE_ENGINE_ITERATION_TIMEOUT_S: int = 60
     APHRODITE_API_KEY: Optional[str] = None
+    APHRODITE_ADMIN_KEY: Optional[str] = None
     S3_ACCESS_KEY_ID: Optional[str] = None
     S3_SECRET_ACCESS_KEY: Optional[str] = None
     S3_ENDPOINT_URL: Optional[str] = None
@@ -211,6 +212,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "APHRODITE_API_KEY":
     lambda: os.environ.get("APHRODITE_API_KEY", None),
 
+    # Admin API key for APHRODITE API server
+    "APHRODITE_ADMIN_KEY":
+    lambda: os.environ.get("APHRODITE_ADMIN_KEY", None),
+
     # S3 access information, used for tensorizer to load model from S3
     "S3_ACCESS_KEY_ID":
     lambda: os.environ.get("S3_ACCESS_KEY_ID", None),

aphrodite/endpoints/openai/api_server.py (+148 -20)

@@ -1,4 +1,5 @@
 import asyncio
+import copy
 import importlib
 import inspect
 import json
@@ -205,6 +206,137 @@ async def build_async_engine_client(
             multiprocess.mark_process_dead(rpc_server_process.pid)
 
 
+async def _maybe_switch_model(
+        request_model: str, app_state,
+        raw_request: Request) -> Optional[ErrorResponse]:
+    """Switch to requested model if different from currently loaded one."""
+    global model_is_loaded, async_engine_client, engine_args, served_model_names
+
+    if not model_is_loaded:
+        return None
+
+    if request_model in served_model_names:
+        return None
+
+    if not app_state.args.allow_inline_model_loading:
+        return JSONResponse(
+            content={
+                "error": {
+                    "message": "Requested model is not currently loaded. "
+                    "Inline model loading is disabled. Enable it with "
+                    "--allow-inline-model-loading.",
+                    "type": "invalid_request_error",
+                    "code": "model_not_loaded"
+                }
+            },
+            status_code=400
+        )  # type: ignore
+
+    api_key = envs.APHRODITE_API_KEY or app_state.args.api_keys
+    admin_key = envs.APHRODITE_ADMIN_KEY or app_state.args.admin_key
+
+    if api_key:
+        api_key_header = raw_request.headers.get("x-api-key")
+        auth_header = raw_request.headers.get("Authorization")
+
+        if not admin_key:
+            return JSONResponse(
+                content={
+                    "error": {
+                        "message": "Admin key not configured. "
+                        "Inline model loading is disabled.",
+                        "type": "invalid_request_error",
+                        "code": "admin_key_required"
+                    }
+                },
+                status_code=401
+            )  # type: ignore
+
+        if not (api_key_header == admin_key or
+                auth_header == f"Bearer {admin_key}"):
+            return JSONResponse(
+                content={
+                    "error": {
+                        "message": "Admin privileges required for inline "
+                        "model loading.",
+                        "type": "invalid_request_error",
+                        "code": "unauthorized"
+                    }
+                },
+                status_code=401
+            )  # type: ignore
+
+    # Need to switch models
+    logger.info(f"Switching from {served_model_names[0]} to {request_model}")
+
+    try:
+        args = app_state.args
+        if not args.disable_frontend_multiprocessing:
+            await async_engine_client.kill()
+        else:
+            await async_engine_client.shutdown_background_loop()
+
+        model_is_loaded = False
+
+        engine_args = AsyncEngineArgs(model=request_model)
+
+        if args.disable_frontend_multiprocessing:
+            async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
+            await async_engine_client.setup()
+        else:
+            if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+                global prometheus_multiproc_dir
+                prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+                os.environ[
+                    "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+
+            rpc_path = get_open_zmq_ipc_path()
+            logger.info(
+                f"Multiprocessing frontend to use {rpc_path} for RPC Path.")
+
+            rpc_client = AsyncEngineRPCClient(rpc_path)
+            async_engine_client = rpc_client
+
+            context = multiprocessing.get_context("spawn")
+            rpc_server_process = context.Process(
+                target=run_rpc_server,
+                args=(engine_args, rpc_path))
+            rpc_server_process.start()
+            logger.info(
+                f"Started engine process with PID {rpc_server_process.pid}")
+
+            while True:
+                try:
+                    await async_engine_client.setup()
+                    break
+                except TimeoutError as e:
+                    if not rpc_server_process.is_alive():
+                        raise RuntimeError(
+                            "RPC Server died before responding to "
+                            "readiness probe") from e
+
+        new_args = copy.deepcopy(args)
+        new_args.model = request_model
+
+        app = await init_app(async_engine_client, new_args)  # noqa: F841
+        served_model_names = [request_model]
+        model_is_loaded = True
+        return None
+
+    except Exception as e:
+        error_msg = f"Error while switching models: {str(e)}"
+        logger.error(error_msg)
+        return JSONResponse(
+            content={
+                "error": {
+                    "message": error_msg,
+                    "type": "invalid_request_error",
+                    "code": "model_load_error"
+                }
+            },
+            status_code=500
+        )  # type: ignore
+
 def mount_metrics(app: FastAPI):
     # Lazy import for prometheus multiprocessing.
     # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
@@ -456,14 +588,10 @@ async def serviceinfo():
 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
-    if not model_is_loaded:
-        return JSONResponse(
-            content={
-                "status": "error",
-                "message": "No model loaded."
-            },
-            status_code=500
-        )
+    error_check_ret = await _maybe_switch_model(
+        request.model, raw_request.app.state, raw_request)
+    if error_check_ret is not None:
+        return error_check_ret
     generator = await openai_serving_chat.create_chat_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -479,14 +607,10 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
-    if not model_is_loaded:
-        return JSONResponse(
-            content={
-                "status": "error",
-                "message": "No model loaded."
-            },
-            status_code=500
-        )
+    error_check_ret = await _maybe_switch_model(
+        request.model, raw_request.app.state, raw_request)
+    if error_check_ret is not None:
+        return error_check_ret
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -882,10 +1006,14 @@ def build_app(args: Namespace) -> FastAPI:
                 return JSONResponse(content={"error": "Unauthorized"},
                                     status_code=401)
 
-            if auth_header != "Bearer " + token and api_key_header != token:
-                return JSONResponse(content={"error": "Unauthorized"},
-                                    status_code=401)
-            return await call_next(request)
+            if (auth_header == f"Bearer {token}" or api_key_header == token or
+                (admin_key is not None and
+                 (api_key_header == admin_key or
+                  auth_header == f"Bearer {admin_key}"))):
+                return await call_next(request)
+
+            return JSONResponse(
+                content={"error": "Unauthorized"}, status_code=401)
 
     for middleware in args.middleware:
         module_path, object_name = middleware.rsplit(".", 1)
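
Two behaviors fall out of the api_server.py changes. First, `_maybe_switch_model` replaces the old blanket 500 "No model loaded." response on the chat and completion routes. Second, the rewritten auth middleware accepts the admin key anywhere the regular API key is accepted, while a regular key alone can never trigger a switch. A hedged sketch of what a client observes, assuming the server runs with `--allow-inline-model-loading` and both keys configured; base URL, keys, and model names are placeholders, and the status codes follow the handler above:

import requests

BASE_URL = "http://localhost:2242"   # placeholder host/port
API_KEY = "user-key"                 # placeholder APHRODITE_API_KEY value
ADMIN_KEY = "admin-key"              # placeholder APHRODITE_ADMIN_KEY value

def complete(model: str, key: str) -> requests.Response:
    # /v1/completions and the x-api-key header both appear in this diff.
    return requests.post(
        f"{BASE_URL}/v1/completions",
        headers={"x-api-key": key},
        json={"model": model, "prompt": "Hi", "max_tokens": 8},
    )

# Currently served model + user key: passes the middleware, no switch needed.
print(complete("currently-served-model", API_KEY).status_code)        # 200
# Unloaded model + user key: 401 from _maybe_switch_model (admin required).
print(complete("placeholder-org/other-model", API_KEY).status_code)   # 401
# Unloaded model + admin key: the switch proceeds (500 model_load_error
# if the new engine fails to come up).
print(complete("placeholder-org/other-model", ADMIN_KEY).status_code)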

aphrodite/endpoints/openai/args.py (+5 -0)

@@ -151,6 +151,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         action="store_true",
         help="If specified, will run the OpenAI frontend server in the same "
         "process as the model serving engine.")
+    parser.add_argument(
+        "--allow-inline-model-loading",
+        action="store_true",
+        help="If specified, will allow the model to be switched inline "
+        "in the same process as the OpenAI frontend server.")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser
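
The flag defaults to off, so existing deployments keep rejecting requests for models that are not loaded. To opt in, the server would be launched with something like `python -m aphrodite.endpoints.openai.api_server --model <initial-model> --allow-inline-model-loading` (module path from this repo; the exact CLI entry point is assumed), with `APHRODITE_ADMIN_KEY` exported or whatever supplies `args.admin_key`, which api_server.py reads but this diff does not add.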