|
@@ -1,4 +1,5 @@
|
|
|
import asyncio
|
|
|
+import copy
|
|
|
import importlib
|
|
|
import inspect
|
|
|
import json
|
|
@@ -205,6 +206,137 @@ async def build_async_engine_client(
|
|
|
multiprocess.mark_process_dead(rpc_server_process.pid)
|
|
|
|
|
|
|
|
|
+async def _maybe_switch_model(
|
|
|
+ request_model: str, app_state,
|
|
|
+ raw_request: Request) -> Optional[ErrorResponse]:
|
|
|
+ """Switch to requested model if different from currently loaded one."""
|
|
|
+ global model_is_loaded, async_engine_client, engine_args, served_model_names
|
|
|
+
|
|
|
+ if not model_is_loaded:
|
|
|
+ return None
|
|
|
+
|
|
|
+ if request_model in served_model_names:
|
|
|
+ return None
|
|
|
+
|
|
|
+ if not app_state.args.allow_inline_model_loading:
|
|
|
+ return JSONResponse(
|
|
|
+ content={
|
|
|
+ "error": {
|
|
|
+ "message": "Requested model is not currently loaded. "
|
|
|
+ "Inline model loading is disabled. Enable it with "
|
|
|
+ "--allow-inline-model-loading.",
|
|
|
+ "type": "invalid_request_error",
|
|
|
+ "code": "model_not_loaded"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ status_code=400
|
|
|
+ ) # type: ignore
|
|
|
+
|
|
|
+ api_key = envs.APHRODITE_API_KEY or app_state.args.api_keys
|
|
|
+ admin_key = envs.APHRODITE_ADMIN_KEY or app_state.args.admin_key
|
|
|
+
|
|
|
+ if api_key:
|
|
|
+ api_key_header = raw_request.headers.get("x-api-key")
|
|
|
+ auth_header = raw_request.headers.get("Authorization")
|
|
|
+
|
|
|
+ if not admin_key:
|
|
|
+ return JSONResponse(
|
|
|
+ content={
|
|
|
+ "error": {
|
|
|
+ "message": "Admin key not configured. "
|
|
|
+ "Inline model loading is disabled.",
|
|
|
+ "type": "invalid_request_error",
|
|
|
+ "code": "admin_key_required"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ status_code=401
|
|
|
+ ) # type: ignore
|
|
|
+
|
|
|
+ if not (api_key_header == admin_key or
|
|
|
+ auth_header == f"Bearer {admin_key}"):
|
|
|
+ return JSONResponse(
|
|
|
+ content={
|
|
|
+ "error": {
|
|
|
+ "message": "Admin privileges required for inline "
|
|
|
+ "model loading.",
|
|
|
+ "type": "invalid_request_error",
|
|
|
+ "code": "unauthorized"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ status_code=401
|
|
|
+ ) # type: ignore
|
|
|
+
|
|
|
+ # Need to switch models
|
|
|
+ logger.info(f"Switching from {served_model_names[0]} to {request_model}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ args = app_state.args
|
|
|
+ if not args.disable_frontend_multiprocessing:
|
|
|
+ await async_engine_client.kill()
|
|
|
+ else:
|
|
|
+ await async_engine_client.shutdown_background_loop()
|
|
|
+
|
|
|
+ model_is_loaded = False
|
|
|
+
|
|
|
+ engine_args = AsyncEngineArgs(model=request_model)
|
|
|
+
|
|
|
+ if args.disable_frontend_multiprocessing:
|
|
|
+ async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
|
|
|
+ await async_engine_client.setup()
|
|
|
+ else:
|
|
|
+ if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
|
|
|
+ global prometheus_multiproc_dir
|
|
|
+ prometheus_multiproc_dir = tempfile.TemporaryDirectory()
|
|
|
+ os.environ[
|
|
|
+ "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
|
|
|
+
|
|
|
+ rpc_path = get_open_zmq_ipc_path()
|
|
|
+ logger.info(
|
|
|
+ f"Multiprocessing frontend to use {rpc_path} for RPC Path.")
|
|
|
+
|
|
|
+ rpc_client = AsyncEngineRPCClient(rpc_path)
|
|
|
+ async_engine_client = rpc_client
|
|
|
+
|
|
|
+ context = multiprocessing.get_context("spawn")
|
|
|
+ rpc_server_process = context.Process(
|
|
|
+ target=run_rpc_server,
|
|
|
+ args=(engine_args, rpc_path))
|
|
|
+ rpc_server_process.start()
|
|
|
+ logger.info(
|
|
|
+ f"Started engine process with PID {rpc_server_process.pid}")
|
|
|
+
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ await async_engine_client.setup()
|
|
|
+ break
|
|
|
+ except TimeoutError as e:
|
|
|
+ if not rpc_server_process.is_alive():
|
|
|
+ raise RuntimeError(
|
|
|
+ "RPC Server died before responding to "
|
|
|
+ "readiness probe") from e
|
|
|
+
|
|
|
+ new_args = copy.deepcopy(args)
|
|
|
+ new_args.model = request_model
|
|
|
+
|
|
|
+ app = await init_app(async_engine_client, new_args) # noqa: F841
|
|
|
+ served_model_names = [request_model]
|
|
|
+ model_is_loaded = True
|
|
|
+ return None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ error_msg = f"Error while switching models: {str(e)}"
|
|
|
+ logger.error(error_msg)
|
|
|
+ return JSONResponse(
|
|
|
+ content={
|
|
|
+ "error": {
|
|
|
+ "message": error_msg,
|
|
|
+ "type": "invalid_request_error",
|
|
|
+ "code": "model_load_error"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ status_code=500
|
|
|
+ ) # type: ignore
|
|
|
+
|
|
|
def mount_metrics(app: FastAPI):
|
|
|
# Lazy import for prometheus multiprocessing.
|
|
|
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
|
@@ -456,14 +588,10 @@ async def serviceinfo():
|
|
|
@router.post("/v1/chat/completions")
|
|
|
async def create_chat_completion(request: ChatCompletionRequest,
|
|
|
raw_request: Request):
|
|
|
- if not model_is_loaded:
|
|
|
- return JSONResponse(
|
|
|
- content={
|
|
|
- "status": "error",
|
|
|
- "message": "No model loaded."
|
|
|
- },
|
|
|
- status_code=500
|
|
|
- )
|
|
|
+ error_check_ret = await _maybe_switch_model(
|
|
|
+ request.model, raw_request.app.state, raw_request)
|
|
|
+ if error_check_ret is not None:
|
|
|
+ return error_check_ret
|
|
|
generator = await openai_serving_chat.create_chat_completion(
|
|
|
request, raw_request)
|
|
|
if isinstance(generator, ErrorResponse):
|
|
@@ -479,14 +607,10 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|
|
|
|
|
@router.post("/v1/completions")
|
|
|
async def create_completion(request: CompletionRequest, raw_request: Request):
|
|
|
- if not model_is_loaded:
|
|
|
- return JSONResponse(
|
|
|
- content={
|
|
|
- "status": "error",
|
|
|
- "message": "No model loaded."
|
|
|
- },
|
|
|
- status_code=500
|
|
|
- )
|
|
|
+ error_check_ret = await _maybe_switch_model(
|
|
|
+ request.model, raw_request.app.state, raw_request)
|
|
|
+ if error_check_ret is not None:
|
|
|
+ return error_check_ret
|
|
|
generator = await openai_serving_completion.create_completion(
|
|
|
request, raw_request)
|
|
|
if isinstance(generator, ErrorResponse):
|
|
@@ -882,10 +1006,14 @@ def build_app(args: Namespace) -> FastAPI:
|
|
|
return JSONResponse(content={"error": "Unauthorized"},
|
|
|
status_code=401)
|
|
|
|
|
|
- if auth_header != "Bearer " + token and api_key_header != token:
|
|
|
- return JSONResponse(content={"error": "Unauthorized"},
|
|
|
- status_code=401)
|
|
|
- return await call_next(request)
|
|
|
+ if (auth_header == f"Bearer {token}" or api_key_header == token or
|
|
|
+ (admin_key is not None and
|
|
|
+ (api_key_header == admin_key or
|
|
|
+ auth_header == f"Bearer {admin_key}"))):
|
|
|
+ return await call_next(request)
|
|
|
+
|
|
|
+ return JSONResponse(
|
|
|
+ content={"error": "Unauthorized"}, status_code=401)
|
|
|
|
|
|
for middleware in args.middleware:
|
|
|
module_path, object_name = middleware.rsplit(".", 1)
|