
fix: another one missed

AlpinDale, 6 months ago
parent commit a3b56353fa

+ 3 - 2
aphrodite/endpoints/openai/api_server.py

@@ -33,7 +33,7 @@ from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
 from aphrodite.endpoints.openai.serving_completions import \
     OpenAIServingCompletion
 from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
-from aphrodite.endpoints.openai.serving_engine import LoRA
+from aphrodite.endpoints.openai.serving_engine import LoRAModulePath
 from aphrodite.engine.args_tools import AsyncEngineArgs
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
@@ -154,7 +154,8 @@ async def show_samplers(x_api_key: Optional[str] = Header(None)):
 
 
 @router.post("/v1/lora/load")
-async def load_lora(lora: LoRA, x_api_key: Optional[str] = Header(None)):
+async def load_lora(lora: LoRAModulePath,
+                    x_api_key: Optional[str] = Header(None)):
     openai_serving_chat.add_lora(lora)
     openai_serving_completion.add_lora(lora)
     if engine_args.enable_lora is False:
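
For context, the /v1/lora/load route shown above now takes a LoRAModulePath body instead of the old LoRA type. A minimal client-side sketch of calling it, assuming LoRAModulePath carries "name" and "path" fields and that the server listens on localhost:2242 with an x-api-key header (all assumptions, not confirmed by this diff):

import requests

# Hypothetical request; field names, port, and header value are placeholders,
# not taken from this commit.
resp = requests.post(
    "http://localhost:2242/v1/lora/load",
    headers={"x-api-key": "sk-example"},
    json={"name": "my-lora", "path": "/path/to/lora_adapter"},
)
print(resp.status_code, resp.text)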

+ 2 - 1
aphrodite/endpoints/openai/serving_embedding.py

@@ -12,7 +12,8 @@ from aphrodite.endpoints.openai.protocol import (EmbeddingRequest,
                                                  EmbeddingResponseData,
                                                  UsageInfo)
 from aphrodite.endpoints.openai.serving_completions import parse_prompt_format
-from aphrodite.endpoints.openai.serving_engine import LoRAModulePath, OpenAIServing
+from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
+                                                       OpenAIServing)
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
 
 TypeTokenIDs = List[int]

+ 13 - 13
examples/offline_inference/soft_prompt_inference.py

@@ -5,6 +5,7 @@ from aphrodite.prompt_adapter.request import PromptAdapterRequest
 MODEL_PATH = "bigscience/bloomz-560m"
 PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
 
+
 def do_sample(llm, pa_name: str, pa_id: int):
     # Sample prompts
     prompts = [
@@ -13,16 +14,15 @@ def do_sample(llm, pa_name: str, pa_id: int):
         "Tweet text : @nationalgridus Looks good thanks! Label : "
     ]
     # Define sampling parameters
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=3,
+    sampling_params = SamplingParams(temperature=0.0,
+                                     max_tokens=3,
                                      stop_token_ids=[3])
 
     # Generate outputs using the LLM
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        prompt_adapter_request=PromptAdapterRequest(pa_name, pa_id, PA_PATH,
-                                                    8) if pa_id else None
-    )
+    outputs = llm.generate(prompts,
+                           sampling_params,
+                           prompt_adapter_request=PromptAdapterRequest(
+                               pa_name, pa_id, PA_PATH, 8) if pa_id else None)
 
     # Print the outputs
     for output in outputs:
@@ -30,17 +30,17 @@ def do_sample(llm, pa_name: str, pa_id: int):
         generated_text = output.outputs[0].text.strip()
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
+
 def main():
     # Create an LLM with prompt adapter enabled
-    llm = LLM(
-        MODEL_PATH,
-        enforce_eager=True,
-        enable_prompt_adapter=True,
-        max_prompt_adapter_token=8
-    )
+    llm = LLM(MODEL_PATH,
+              enforce_eager=True,
+              enable_prompt_adapter=True,
+              max_prompt_adapter_token=8)
 
     # Run the sampling function
     do_sample(llm, "twitter_pa", pa_id=1)
 
+
 if __name__ == "__main__":
     main()
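
The reflowed generate() call above keeps the same four positional arguments to PromptAdapterRequest. As a reading aid, here is a keyword-argument sketch of the same construction, assuming vLLM-style field names (prompt_adapter_name, prompt_adapter_id, prompt_adapter_local_path, prompt_adapter_num_virtual_tokens); the names are an assumption, not taken from this diff:

from aphrodite.prompt_adapter.request import PromptAdapterRequest

# Keyword reading of PromptAdapterRequest(pa_name, pa_id, PA_PATH, 8) from the
# example above; field names are assumed and may differ in this tree.
pa_request = PromptAdapterRequest(
    prompt_adapter_name="twitter_pa",     # pa_name passed to do_sample()
    prompt_adapter_id=1,                  # pa_id; a falsy id skips the adapter
    prompt_adapter_local_path=(
        "stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM"),  # PA_PATH
    prompt_adapter_num_virtual_tokens=8,  # matches max_prompt_adapter_token=8
)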