Browse Source

fix: kobold api generation

AlpinDale 6 months ago
parent
commit
2ed49cdc4c
1 changed files with 16 additions and 4 deletions
  1. 16 4
      aphrodite/endpoints/openai/api_server.py

+ 16 - 4
aphrodite/endpoints/openai/api_server.py

@@ -285,8 +285,14 @@ def prepare_engine_payload(
 @kai_api.post("/generate")
 async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
     sampling_params, input_tokens = prepare_engine_payload(kai_payload)
-    result_generator = engine.generate(None, sampling_params,
-                                       kai_payload.genkey, input_tokens)
+    result_generator = engine.generate(
+        {
+            "prompt": kai_payload.prompt,
+            "prompt_token_ids": input_tokens,
+        },
+        sampling_params,
+        kai_payload.genkey,
+    )
 
     final_res: RequestOutput = None
     previous_output = ""
@@ -310,8 +316,14 @@ async def generate_stream(
         kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
 
     sampling_params, input_tokens = prepare_engine_payload(kai_payload)
-    results_generator = engine.generate(None, sampling_params,
-                                        kai_payload.genkey, input_tokens)
+    results_generator = engine.generate(
+        {
+            "prompt": kai_payload.prompt,
+            "prompt_token_ids": input_tokens,
+        },
+        sampling_params,
+        kai_payload.genkey,
+    )
 
     async def stream_kobold() -> AsyncGenerator[bytes, None]:
         previous_output = ""