"""Async request functions for benchmarking several LLM serving backends.

Each ``async_request_*`` coroutine sends one request described by a
``RequestFuncInput`` and returns a ``RequestFuncOutput`` carrying the
generated text together with latency measurements. ``ASYNC_REQUEST_FUNCS``
maps backend names to the coroutine that talks to them.
"""
import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Union

import aiohttp
import huggingface_hub.constants
from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

# Total timeout applied to every client session created below (6 hours).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """Description of a single benchmark request."""
    prompt: str  # Prompt text sent to the backend
    api_url: str  # Endpoint the request is POSTed to
    prompt_len: int  # Copied verbatim into RequestFuncOutput.prompt_len
    output_len: int  # Sent as the backend's max-new-tokens limit
    model: str  # Sent as "model" by the OpenAI-style request functions
    best_of: int = 1
    use_beam_search: bool = False  # Every request function asserts False


@dataclass
class RequestFuncOutput:
    """Result and timing measurements of a single benchmark request."""
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0  # Seconds from sending the request to completion
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(
        default_factory=list)  # List of inter-token latencies
    prompt_len: int = 0
    error: str = ""  # HTTP reason or formatted traceback on failure


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TGI ``generate_stream`` endpoint.

    Args:
        request_func_input: The request to send. ``api_url`` must end with
            ``generate_stream`` and ``use_beam_search`` must be False.
        pbar: Optional progress bar, advanced by one when the request ends
            (successfully or not).

    Returns:
        A ``RequestFuncOutput``. On a non-200 status ``error`` holds the
        HTTP reason; on an exception it holds the formatted traceback. In
        both cases ``success`` is False.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = remove_prefix(chunk_bytes, "data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                    # The text is read from the last chunk that was parsed.
                    output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Any failure is recorded on the output instead of propagating.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM ``generate_stream`` URL.

    Args:
        request_func_input: The request to send. ``api_url`` must end with
            ``generate_stream``, ``best_of`` must be 1 and
            ``use_beam_search`` must be False.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``; ``generated_text`` is the concatenation of
        the ``text_output`` field of every streamed chunk. Failures are
        reported through ``success`` and ``error``.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Any failure is recorded on the output instead of propagating.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming request to a DeepSpeed-MII endpoint.

    Args:
        request_func_input: The request to send. ``best_of`` must be 1 and
            ``use_beam_search`` must be False.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. ``ttft`` is always 0 and ``itl`` stays
        empty because the response is read in one piece; ``latency`` is the
        time until the JSON body has been parsed.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    output.generated_text = parsed_resp["text"][0]
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Any failure is recorded on the output instead of propagating.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Completions endpoint.

    Args:
        request_func_input: The request to send. ``api_url`` must end with
            ``completions`` and ``use_beam_search`` must be False.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. ``ttft`` is the delay until the first chunk
        with non-empty text, ``itl`` holds the gaps between subsequent such
        chunks, and ``latency`` is measured at the ``[DONE]`` chunk (or at
        the last received token if the stream ends without one). Failures
        are reported through ``success`` and ``error``.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Stays None until a "[DONE]" chunk is seen, so a stream that ends
        # without that sentinel does not hit an unbound variable below.
        latency = None
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase: only tokens after the
                                # first count as inter-token latency,
                                # matching the other request functions.
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]

                    if latency is None:
                        # No "[DONE]" received: fall back to the arrival
                        # time of the last token.
                        latency = most_recent_timestamp - st

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Any failure is recorded on the output instead of propagating.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Chat Completions URL.

    The prompt is sent as a single ``user`` message.

    Args:
        request_func_input: The request to send. ``api_url`` must end with
            ``chat/completions`` and ``use_beam_search`` must be False.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. ``ttft`` is the delay until the first chunk
        whose delta has non-empty ``content``, ``itl`` holds the gaps for
        subsequent content chunks, and ``latency`` is measured at the
        ``[DONE]`` chunk (or at the last received chunk if the stream ends
        without one). Failures are reported through ``success`` and
        ``error``.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "messages": [
                {
                    "role": "user",
                    "content": request_func_input.prompt,
                },
            ],
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Stays None until a "[DONE]" chunk is seen, so a stream that ends
        # without that sentinel does not hit an unbound variable below.
        latency = None
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            delta = data["choices"][0]["delta"]
                            if delta.get("content", None):
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                generated_text += delta["content"]

                            most_recent_timestamp = timestamp

                    if latency is None:
                        # No "[DONE]" received: fall back to the arrival
                        # time of the last chunk.
                        latency = most_recent_timestamp - st

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Any failure is recorded on the output instead of propagating.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


# Since aphrodite must support Python 3.8, we can't use str.removeprefix(prefix)
# introduced in Python 3.9
def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` without a leading ``prefix``, or unchanged if absent."""
    if text.startswith(prefix):
        return text[len(prefix):]
    return text


def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model identifier to the path the tokenizer is loaded from.

    When the ``APHRODITE_USE_MODELSCOPE`` environment variable equals
    ``true`` (case-insensitive), the model is fetched through ModelScope's
    ``snapshot_download`` with weight files (``.pt``, ``.safetensors``,
    ``.bin``) excluded, and the returned path is used. Otherwise the
    argument is returned unchanged.
    """
    if os.getenv('APHRODITE_USE_MODELSCOPE', 'False').lower() == 'true':
        # Imported lazily so modelscope is only required when enabled.
        from modelscope import snapshot_download

        model_path = snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

        return model_path
    return pretrained_model_name_or_path


def get_tokenizer(
    pretrained_model_name_or_path: str, trust_remote_code: bool
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load a tokenizer via ``AutoTokenizer.from_pretrained``.

    Args:
        pretrained_model_name_or_path: Local path or model identifier. If it
            is not an existing local path it is first passed through
            ``get_model``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.
    """
    if pretrained_model_name_or_path is not None and not os.path.exists(
            pretrained_model_name_or_path):
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)
    return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
                                         trust_remote_code=trust_remote_code)


# Maps a backend name to the coroutine used to send requests to it.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "aphrodite": async_request_openai_completions,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
}