Răsfoiți Sursa

feat: add load/unload endpoints for soft-prompts (#694)

* feat: add load/unload endpoints for soft-prompts

* fix
AlpinDale 6 luni în urmă
părinte
comite
1d3a1fec47

+ 18 - 2
aphrodite/endpoints/openai/api_server.py

@@ -45,7 +45,8 @@ from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
 from aphrodite.endpoints.openai.serving_completions import (
     OpenAIServingCompletion)
 from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
-from aphrodite.endpoints.openai.serving_engine import LoRAModulePath
+from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
+                                                       PromptAdapterPath)
 from aphrodite.endpoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from aphrodite.engine.args_tools import AsyncEngineArgs
@@ -258,6 +259,21 @@ async def unload_lora(lora_name: str):
     return JSONResponse(content={"status": "success"})
 
 
+@router.post("/v1/soft_prompt/load")
+async def load_soft_prompt(soft_prompt: PromptAdapterPath):
+    """Register a soft prompt (prompt adapter) with the completion server.
+
+    Args:
+        soft_prompt: Name and local path of the adapter to register.
+
+    Returns:
+        ``{"status": "success"}`` once the adapter has been passed to
+        ``openai_serving_completion.add_prompt_adapter``, or a 400 response
+        with an ``error`` field when the engine was started without
+        prompt-adapter support.
+    """
+    # Check the engine flag *before* touching any state: registering the
+    # adapter and answering "success" on a server that cannot use it only
+    # hides the misconfiguration from the caller.
+    if not engine_args.enable_prompt_adapter:
+        logger.error("Prompt Adapter is not enabled in the engine. "
+                     "Please start the server with the "
+                     "--enable-prompt-adapter flag!")
+        return JSONResponse(
+            content={"error": "Prompt Adapter is not enabled in the engine."},
+            status_code=400)
+    openai_serving_completion.add_prompt_adapter(soft_prompt)
+    return JSONResponse(content={"status": "success"})
+
+@router.delete("/v1/soft_prompt/unload")
+async def unload_soft_prompt(soft_prompt_name: str):
+    """Ask the completion server to drop the soft prompt with this name.
+
+    The response is always ``{"status": "success"}``; the removal call's
+    outcome is not inspected.
+    """
+    openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
+    response_body = {"status": "success"}
+    return JSONResponse(content=response_body)
+
+
 # ============ KoboldAI API ============ #
 
 
@@ -556,7 +572,7 @@ def build_app(args: Namespace) -> FastAPI:
             auth_header = request.headers.get("Authorization")
             api_key_header = request.headers.get("x-api-key")
 
-            if request.url.path.startswith("/v1/lora"):
+            if request.url.path.startswith(("/v1/lora", "/v1/soft_prompt")):
                 if admin_key is not None and api_key_header == admin_key:
                     return await call_next(request)
                 return JSONResponse(content={"error": "Unauthorized"},

+ 25 - 0
aphrodite/endpoints/openai/serving_engine.py

@@ -217,6 +217,31 @@ class OpenAIServing:
             lora for lora in self.lora_requests if lora.lora_name != lora_name
         ]
 
+    def add_prompt_adapter(self, prompt_adapter: PromptAdapterPath):
+        """Register a prompt adapter stored in a local directory.
+
+        Reads ``num_virtual_tokens`` from ``adapter_config.json`` under
+        ``prompt_adapter.local_path`` and appends a ``PromptAdapterRequest``
+        for it to ``self.prompt_adapter_requests``. A name that is already
+        registered is logged as an error and otherwise ignored.
+
+        Args:
+            prompt_adapter: Name and local path of the adapter.
+
+        Raises:
+            OSError: ``adapter_config.json`` cannot be opened.
+            json.JSONDecodeError: The config file is not valid JSON.
+            KeyError: The config has no ``num_virtual_tokens`` entry.
+        """
+        registered = self.prompt_adapter_requests
+        if any(request.prompt_adapter_name == prompt_adapter.name
+               for request in registered):
+            logger.error(
+                f"Prompt adapter {prompt_adapter.name} already exists.")
+            return
+        config_path = pathlib.Path(prompt_adapter.local_path,
+                                   "adapter_config.json")
+        # JSON is UTF-8 by specification; do not rely on the locale default.
+        with config_path.open(encoding="utf-8") as f:
+            adapter_config = json.load(f)
+        num_virtual_tokens = adapter_config["num_virtual_tokens"]
+        # ``len(...) + 1`` hands out an id that is still in use once an
+        # earlier adapter has been removed (ids 1, 2 -> remove 1 -> the next
+        # id would be 2 again), so go one past the highest id still held.
+        next_id = max(
+            (request.prompt_adapter_id for request in registered),
+            default=0) + 1
+        registered.append(
+            PromptAdapterRequest(
+                prompt_adapter_name=prompt_adapter.name,
+                prompt_adapter_id=next_id,
+                prompt_adapter_local_path=prompt_adapter.local_path,
+                prompt_adapter_num_virtual_tokens=num_virtual_tokens))
+        
+    def remove_prompt_adapter(self, prompt_adapter_name: str):
+        """Drop every registered prompt adapter with the given name.
+
+        ``self.prompt_adapter_requests`` is rebound to a new list rather
+        than edited in place; an unknown name leaves its contents unchanged.
+        """
+        survivors = []
+        for request in self.prompt_adapter_requests:
+            if request.prompt_adapter_name == prompt_adapter_name:
+                continue
+            survivors.append(request)
+        self.prompt_adapter_requests = survivors
+
     def _normalize_prompt_text_to_input(
         self,
         request: AnyRequest,