فهرست منبع

fix top_k in server

AlpinDale 1 سال پیش
والد
کامیت
ac61b31879
1 فایل تغییر یافته به همراه 3 افزوده شده و 50 حذف شده
  1. 3 50
      aphrodite/endpoints/api_server_ooba.py

+ 3 - 50
aphrodite/endpoints/api_server_ooba.py

@@ -19,27 +19,8 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()
 engine = None
 
-# user_tokens: Dict[str, str] = {}
-
-# def get_token(authorization: str = Header(None)):
-#     if authorization is None or not authorization.startswith("Bearer "):
-#         raise HTTPException(status_code=401, detail="Unauthorized access.")
-#     token = authorization.replace("Bearer ", "")
-    
-#     # Check if the token exists in the user_tokens dictionary
-#     if token not in user_tokens:
-#         raise HTTPException(status_code=401, detail="Unauthorized access.")
-    
-#     return True
-
-
-# def generate_user_token(user_id: str) -> str:
-#     token = random_uuid()
-#     user_tokens[token] = user_id
-#     return token
 
 @app.post("/api/v1/generate")
-# async def generate(request: Request, token: bool = Depends(get_token), params: SamplingParams) -> Response:
 async def generate(request: Request) -> Response:
     """Generate completion for the request.
 
@@ -62,36 +43,14 @@ async def generate(request: Request) -> Response:
         request_dict['frequency_penalty'] = request_dict.pop('repetition_penalty')
     if 'ban_eos_token' in request_dict:
         request_dict['ignore_eos'] = request_dict.pop('ban_eos_token')
+    if 'top_k' in request_dict and request_dict['top_k'] == 0:
+        request_dict['top_k'] = -1
+
 
     for key, value in request_dict.items():
         if hasattr(sampling_params, key):
             setattr(sampling_params, key, value)
 
-    # sampling_params = SamplingParams(**sampling_params_data)
-
-    # param_aliases = {
-    #     'stop_sequence': 'stop',
-    #     'max_length': 'max_tokens',
-    #     'rep_pen': 'frequency_penalty',
-    #     'use_story': None,
-    #     'use_memory': None,
-    #     'use_authors_note': None,
-    #     'use_world_info': None,
-    #     'max_context_length': None,
-    #     'rep_pen_range': None,
-    #     'rep_pen_slope': None,
-    #     'tfs': None,
-    #     'top_a': None,
-    #     'typical': None,
-    #     'sampler_order': None,
-    #     'singleline': None,
-    #     'use_default_badwordsids': None,
-    #     'mirostat': None,
-    #     'mirostat_eta': None,
-    #     'mirostat_tau': None,
-    # }
-
-    # sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
     results_generator = engine.generate(prompt, sampling_params, request_id)
@@ -111,7 +70,6 @@ async def generate(request: Request) -> Response:
 
     if stream:
         background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
         background_tasks.add_task(abort_request)
         return StreamingResponse(stream_results(), background=background_tasks)
 
@@ -132,7 +90,6 @@ async def generate(request: Request) -> Response:
 
 
 @app.get("/api/v1/model")
-# async def get_model_name(token: bool = Depends(get_token)) -> JSONResponse:
 async def get_model_name() -> JSONResponse:
     """Return the model name based on the EngineArgs configuration."""
     if engine is not None:
@@ -142,10 +99,6 @@ async def get_model_name() -> JSONResponse:
     else:
         return JSONResponse(content={"result": "Read Only"}, status_code=500)
 
-# @app.post("/api/v1/get-token")
-# async def get_user_token(user_id: str) -> JSONResponse:
-#     token = generate_user_token(user_id)
-#     return JSONResponse(content={"token": token})
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()