import json
import os
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp
from tqdm.asyncio import tqdm

# Cap every benchmark request at 6 hours so a hung backend cannot stall
# the whole benchmark run.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass
class RequestFuncInput:
    """Everything needed to fire one benchmark request at a backend."""

    prompt: str        # raw prompt text sent to the server
    api_url: str       # full endpoint URL of the target backend
    prompt_len: int    # precomputed token count of the prompt
    output_len: int    # number of tokens to request from the server
    model: str         # model name (used by OpenAI-compatible endpoints)
    best_of: int = 1
    use_beam_search: bool = False
@dataclass
class RequestFuncOutput:
    """Measurements collected for a single benchmark request."""

    generated_text: str = ""
    success: bool = False
    latency: float = 0   # total request latency, in seconds
    ttft: float = 0      # time to first token, in seconds
    prompt_len: int = 0  # copied from the matching RequestFuncInput
async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark one streaming request against a TGI ``generate_stream``
    endpoint.

    Measures time-to-first-token and total latency while consuming the
    SSE stream; the last frame's ``generated_text`` is kept as the output.
    Returns a populated :class:`RequestFuncOutput` (``success`` is False on
    a non-200 status or a connection error).
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for data in response.content.iter_any():
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                        output.latency = time.perf_counter() - st

                        # SSE frames look like "data:{json}".  Strip the
                        # exact prefix: str.lstrip("data:") removes any run
                        # of the characters d/a/t/: and can corrupt the
                        # JSON payload.
                        body = data.decode("utf-8")
                        if body.startswith("data:"):
                            body = body[len("data:"):]
                        output.generated_text = json.loads(
                            body)["generated_text"]
                        output.success = True
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output
async def async_request_aphrodite(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark one streaming request against an Aphrodite ``generate``
    endpoint, recording TTFT, total latency, and the generated text
    (with the echoed prompt removed).
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "prompt": request_func_input.prompt,
            "n": 1,
            "best_of": request_func_input.best_of,
            "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
            "top_p": 1.0,
            "min_p": 0.06,
            "seed": 42,
            "max_tokens": request_func_input.output_len,
            "ignore_eos": True,
            "stream": True,
        }
        result = RequestFuncOutput()
        result.prompt_len = request_func_input.prompt_len

        first_token_time = 0
        start = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status != 200:
                    result.success = False
                else:
                    async for raw in response.content.iter_any():
                        if first_token_time == 0:
                            first_token_time = time.perf_counter() - start
                            result.ttft = first_token_time
                        result.latency = time.perf_counter() - start

                        # When streaming, '\0' is appended to the end of
                        # the response.
                        text = raw.decode("utf-8").strip("\0")
                        completion = json.loads(text)["text"][0]
                        # The server echoes the prompt; keep new text only.
                        result.generated_text = completion[
                            len(request_func_input.prompt):]
                        result.success = True
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            result.success = False

        if pbar:
            pbar.update(1)
        return result
async def async_request_vllm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark one streaming request against a vLLM ``generate``
    endpoint, recording TTFT, total latency, and the generated text
    (with the echoed prompt removed).
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "prompt": request_func_input.prompt,
            "n": 1,
            "best_of": request_func_input.best_of,
            "use_beam_search": request_func_input.use_beam_search,
            "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "ignore_eos": True,
            "stream": True,
        }
        result = RequestFuncOutput()
        result.prompt_len = request_func_input.prompt_len

        first_token_time = 0
        start = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status != 200:
                    result.success = False
                else:
                    async for raw in response.content.iter_any():
                        if first_token_time == 0:
                            first_token_time = time.perf_counter() - start
                            result.ttft = first_token_time
                        result.latency = time.perf_counter() - start

                        # When streaming, '\0' is appended to the end of
                        # the response.
                        text = raw.decode("utf-8").strip("\0")
                        completion = json.loads(text)["text"][0]
                        # The server echoes the prompt; keep new text only.
                        result.generated_text = completion[
                            len(request_func_input.prompt):]
                        result.success = True
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            result.success = False

        if pbar:
            pbar.update(1)
        return result
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark one streaming request against a TensorRT-LLM (Triton)
    ``generate_stream`` endpoint.

    Measures time-to-first-token and total latency while consuming the
    SSE stream; the last frame's ``text_output`` is kept as the output.
    Returns a populated :class:`RequestFuncOutput` (``success`` is False on
    a non-200 status or a connection error).
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload) as resp:
                if resp.status == 200:
                    async for data in resp.content.iter_any():
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                        output.latency = time.perf_counter() - st

                        # SSE frames look like "data:{json}".  Strip the
                        # exact prefix: str.lstrip("data:") removes any run
                        # of the characters d/a/t/: and can corrupt the
                        # JSON payload.
                        body = data.decode("utf-8")
                        if body.startswith("data:"):
                            body = body[len("data:"):]
                        output.generated_text = json.loads(body)["text_output"]
                        output.success = True
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output
async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark one (non-streaming) request against a DeepSpeed-MII
    server and record its total latency and generated text.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        payload = {
            "prompts": request_func_input.prompt,
            "max_new_tokens": request_func_input.output_len,
            "ignore_eos": True,
            "do_sample": True,
            "temperature":
            0.01,  # deepspeed-mii does not accept 0.0 temperature.
            "top_p": 1.0,
        }
        result = RequestFuncOutput()
        result.prompt_len = request_func_input.prompt_len

        # DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use
        # 0 as placeholder.
        # https://github.com/microsoft/DeepSpeed-MII/pull/311
        result.ttft = 0

        start = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status != 200:
                    result.success = False
                else:
                    parsed = await response.json()
                    result.latency = time.perf_counter() - start
                    result.generated_text = parsed[0]["generated_text"]
                    result.success = True
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            result.success = False

        if pbar:
            pbar.update(1)
        return result
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark one streaming request against an OpenAI-compatible
    ``v1/completions`` endpoint.

    Accumulates the streamed completion text, records TTFT and (on a
    ``[DONE]`` frame) total latency.  Auth comes from the OPENAI_API_KEY
    environment variable.  Returns a populated :class:`RequestFuncOutput`.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("v1/completions")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0
        # Fixed: previously `latency` was only assigned on a "[DONE]" frame,
        # so a truncated stream raised an uncaught UnboundLocalError.
        latency = 0.0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        # SSE frames look like "data: {json}".  Strip the
                        # exact prefix: str.lstrip("data: ") removes any run
                        # of the characters d/a/t/:/space and can corrupt
                        # the payload.
                        chunk = chunk_bytes.decode("utf-8")
                        if chunk.startswith("data: "):
                            chunk = chunk[len("data: "):]
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            body = json.loads(chunk)
                            generated_text += body["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output
# Registry mapping a backend name (as selected on the benchmark command
# line) to the coroutine that drives a single request against it.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "aphrodite": async_request_aphrodite,
    "vllm": async_request_vllm,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "tensorrt-llm": async_request_trt_llm,
}