|
@@ -19,27 +19,8 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
|
|
|
app = FastAPI()
|
|
|
engine = None
|
|
|
|
|
|
-# user_tokens: Dict[str, str] = {}
|
|
|
-
|
|
|
-# def get_token(authorization: str = Header(None)):
|
|
|
-# if authorization is None or not authorization.startswith("Bearer "):
|
|
|
-# raise HTTPException(status_code=401, detail="Unauthorized access.")
|
|
|
-# token = authorization.replace("Bearer ", "")
|
|
|
-
|
|
|
-# # Check if the token exists in the user_tokens dictionary
|
|
|
-# if token not in user_tokens:
|
|
|
-# raise HTTPException(status_code=401, detail="Unauthorized access.")
|
|
|
-
|
|
|
-# return True
|
|
|
-
|
|
|
-
|
|
|
-# def generate_user_token(user_id: str) -> str:
|
|
|
-# token = random_uuid()
|
|
|
-# user_tokens[token] = user_id
|
|
|
-# return token
|
|
|
|
|
|
@app.post("/api/v1/generate")
|
|
|
-# async def generate(request: Request, token: bool = Depends(get_token), params: SamplingParams) -> Response:
|
|
|
async def generate(request: Request) -> Response:
|
|
|
"""Generate completion for the request.
|
|
|
|
|
@@ -62,36 +43,14 @@ async def generate(request: Request) -> Response:
|
|
|
request_dict['frequency_penalty'] = request_dict.pop('repetition_penalty')
|
|
|
if 'ban_eos_token' in request_dict:
|
|
|
request_dict['ignore_eos'] = request_dict.pop('ban_eos_token')
|
|
|
+ if 'top_k' in request_dict and request_dict['top_k'] == 0:
|
|
|
+ request_dict['top_k'] = -1
|
|
|
+
|
|
|
|
|
|
for key, value in request_dict.items():
|
|
|
if hasattr(sampling_params, key):
|
|
|
setattr(sampling_params, key, value)
|
|
|
|
|
|
- # sampling_params = SamplingParams(**sampling_params_data)
|
|
|
-
|
|
|
- # param_aliases = {
|
|
|
- # 'stop_sequence': 'stop',
|
|
|
- # 'max_length': 'max_tokens',
|
|
|
- # 'rep_pen': 'frequency_penalty',
|
|
|
- # 'use_story': None,
|
|
|
- # 'use_memory': None,
|
|
|
- # 'use_authors_note': None,
|
|
|
- # 'use_world_info': None,
|
|
|
- # 'max_context_length': None,
|
|
|
- # 'rep_pen_range': None,
|
|
|
- # 'rep_pen_slope': None,
|
|
|
- # 'tfs': None,
|
|
|
- # 'top_a': None,
|
|
|
- # 'typical': None,
|
|
|
- # 'sampler_order': None,
|
|
|
- # 'singleline': None,
|
|
|
- # 'use_default_badwordsids': None,
|
|
|
- # 'mirostat': None,
|
|
|
- # 'mirostat_eta': None,
|
|
|
- # 'mirostat_tau': None,
|
|
|
- # }
|
|
|
-
|
|
|
- # sampling_params = SamplingParams(**request_dict)
|
|
|
request_id = random_uuid()
|
|
|
|
|
|
results_generator = engine.generate(prompt, sampling_params, request_id)
|
|
@@ -111,7 +70,6 @@ async def generate(request: Request) -> Response:
|
|
|
|
|
|
if stream:
|
|
|
background_tasks = BackgroundTasks()
|
|
|
- # Abort the request if the client disconnects.
|
|
|
background_tasks.add_task(abort_request)
|
|
|
return StreamingResponse(stream_results(), background=background_tasks)
|
|
|
|
|
@@ -132,7 +90,6 @@ async def generate(request: Request) -> Response:
|
|
|
|
|
|
|
|
|
@app.get("/api/v1/model")
|
|
|
-# async def get_model_name(token: bool = Depends(get_token)) -> JSONResponse:
|
|
|
async def get_model_name() -> JSONResponse:
|
|
|
"""Return the model name based on the EngineArgs configuration."""
|
|
|
if engine is not None:
|
|
@@ -142,10 +99,6 @@ async def get_model_name() -> JSONResponse:
|
|
|
else:
|
|
|
return JSONResponse(content={"result": "Read Only"}, status_code=500)
|
|
|
|
|
|
-# @app.post("/api/v1/get-token")
|
|
|
-# async def get_user_token(user_id: str) -> JSONResponse:
|
|
|
-# token = generate_user_token(user_id)
|
|
|
-# return JSONResponse(content={"token": token})
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
parser = argparse.ArgumentParser()
|