소스 검색

api: support aphrodite_config.yaml with inline loading (#929)

* api: support aphrodite_config.yaml with inline loading

* support .yml too
AlpinDale 2 달 전
부모
커밋
59d1d59028
2개의 파일 변경: 68줄 추가, 5줄 삭제
  1. 21 5
      aphrodite/endpoints/openai/api_server.py
  2. 47 0
      aphrodite/modeling/model_loader/weight_utils.py

+ 21 - 5
aphrodite/endpoints/openai/api_server.py

@@ -57,6 +57,7 @@ from aphrodite.endpoints.openai.serving_tokenization import (
 from aphrodite.engine.args_tools import AsyncEngineArgs
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
 from aphrodite.engine.protocol import AsyncEngineClient
+from aphrodite.modeling.model_loader.weight_utils import get_model_config_yaml
 from aphrodite.server import serve_http
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
 from aphrodite.version import __version__ as APHRODITE_VERSION
@@ -266,21 +267,36 @@ async def _maybe_switch_model(
                 status_code=401
             )  # type: ignore
     
-    # Need to switch models
     logger.info(f"Switching from {served_model_names[0]} to {request_model}")
-
+    
     try:
         args = app_state.args
         if not args.disable_frontend_multiprocessing:
             await async_engine_client.kill()
         else:
             await async_engine_client.shutdown_background_loop()
-
+            
         model_is_loaded = False
 
-        engine_args = AsyncEngineArgs(model=request_model)
+        yaml_config = get_model_config_yaml(request_model, args.download_dir)
+
+        if yaml_config:
+            parser = FlexibleArgumentParser()
+            parser = make_arg_parser(parser)
+            engine_args = parser.parse_args([])  # empty args
 
-        if args.disable_frontend_multiprocessing:
+            for key, value in yaml_config.items():
+                if hasattr(engine_args, key):
+                    setattr(engine_args, key, value)
+
+            engine_args.model = request_model
+            engine_args = AsyncEngineArgs.from_cli_args(engine_args)
+        else:
+            # Fallback to minimal config
+            engine_args = AsyncEngineArgs(model=request_model)
+        
+        if (model_is_embedding(engine_args.model, engine_args.trust_remote_code)
+                or args.disable_frontend_multiprocessing):
             async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
             await async_engine_client.setup()
         else:

+ 47 - 0
aphrodite/modeling/model_loader/weight_utils.py

@@ -406,6 +406,53 @@ def pt_weights_iterator(
         torch.cuda.empty_cache()
 
 
+def get_model_config_yaml(
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None) -> Optional[dict]:
+    """Look for an aphrodite_config file in a model directory or HF repo.
+
+    Both ``aphrodite_config.yaml`` and ``aphrodite_config.yml`` are accepted,
+    in that order of preference, for local directories as well as HF repos.
+
+    Args:
+        model_name_or_path: Local path or HF model name
+        cache_dir: Optional cache directory for HF downloads
+
+    Returns:
+        Dict containing the config if found. None if no config file exists,
+        if it is empty, if it cannot be read or parsed, or if its top level
+        is not a mapping (callers iterate over ``.items()``).
+    """
+    valid_names = ("aphrodite_config.yaml", "aphrodite_config.yml")
+    config_path = None
+
+    if os.path.isdir(model_name_or_path):
+        for name in valid_names:
+            candidate = os.path.join(model_name_or_path, name)
+            if os.path.exists(candidate):
+                config_path = candidate
+                break
+    else:
+        with get_lock(model_name_or_path, cache_dir):
+            for name in valid_names:
+                # hf_hub_download raises when the file is absent, so each
+                # name is guarded on its own; a single try around the whole
+                # loop would abort before the ".yml" fallback is attempted.
+                try:
+                    candidate = hf_hub_download(
+                        model_name_or_path,
+                        filename=name,
+                        cache_dir=cache_dir,
+                        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+                    )
+                except (huggingface_hub.utils.EntryNotFoundError,
+                        huggingface_hub.utils.LocalEntryNotFoundError):
+                    continue
+                if os.path.exists(candidate):
+                    config_path = candidate
+                    break
+
+    if config_path is None:
+        return None
+
+    # Loading is best-effort: a broken config file must not prevent the
+    # model from being served with default settings.
+    try:
+        import yaml
+        with open(config_path, 'r', encoding="utf-8") as f:
+            config = yaml.safe_load(f)
+    except Exception as e:
+        logger.warning(f"Failed to load aphrodite_config.yaml: {e}")
+        return None
+
+    # safe_load yields None for an empty file, which is passed through as
+    # "no config". Any other non-mapping value cannot be applied as
+    # key/value engine arguments, so it is rejected.
+    if config is not None and not isinstance(config, dict):
+        logger.warning(
+            f"Ignoring {config_path}: expected a mapping at the top level, "
+            f"got {type(config).__name__}")
+        return None
+    return config
+
+
 def get_gguf_extra_tensor_names(
         gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]:
     reader = gguf.GGUFReader(gguf_file)