@@ -2,8 +2,8 @@
 
 On the server side, run one of the following commands:
     Aphrodite OpenAI API server
-    python -m aphrodite.endpoints.openai.api_server \
-        --model <your_model> --swap-space 16 \
+    aphrodite run <your_model> \
+        --swap-space 16 \
         --disable-log-requests
 
     (TGI backend)
@@ -17,7 +17,7 @@ On the client side, run:
         --dataset-path <path to dataset> \
         --request-rate <request_rate> \ # By default <request_rate> is inf
         --num-prompts <num_prompts> # By default <num_prompts> is 1000
-
+
     when using tgi backend, add
         --endpoint /generate_stream
         to the end of the command above.
@@ -31,7 +31,7 @@ import time
 import warnings
 from dataclasses import dataclass
 from datetime import datetime
-from typing import AsyncGenerator, List, Optional, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 
 import numpy as np
 from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
@@ -39,7 +39,15 @@ from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
 from tqdm.asyncio import tqdm
 from transformers import PreTrainedTokenizerBase
 
-from aphrodite.transformers_utils.tokenizer import get_tokenizer
+try:
+    from aphrodite.transformers_utils.tokenizer import get_tokenizer
+except ImportError:
+    from backend_request_func import get_tokenizer
+
+try:
+    from aphrodite.common.utils import FlexibleArgumentParser
+except ImportError:
+    from argparse import ArgumentParser as FlexibleArgumentParser
 
 
 @dataclass
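The paired try/except imports above let the script run standalone: when Aphrodite is importable, its own helpers are used; otherwise the script degrades to equivalents from the benchmark directory and the standard library. A minimal sketch of the same pattern in isolation, assuming only that FlexibleArgumentParser is argparse-compatible (which the fallback itself requires):

    # Prefer the project's parser; degrade to stdlib if the package is absent.
    try:
        from aphrodite.common.utils import FlexibleArgumentParser
    except ImportError:
        from argparse import ArgumentParser as FlexibleArgumentParser

    parser = FlexibleArgumentParser(description="works with or without Aphrodite")
    parser.add_argument("--demo", type=int, default=1)
    args = parser.parse_args([])  # parse an empty argv, just for the sketch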
@@ -52,10 +60,16 @@ class BenchmarkMetrics:
     output_throughput: float
     mean_ttft_ms: float
     median_ttft_ms: float
+    std_ttft_ms: float
     p99_ttft_ms: float
     mean_tpot_ms: float
     median_tpot_ms: float
+    std_tpot_ms: float
     p99_tpot_ms: float
+    mean_itl_ms: float
+    median_itl_ms: float
+    std_itl_ms: float
+    p99_itl_ms: float
 
 
 def sample_sharegpt_requests(
@@ -66,7 +80,6 @@ def sample_sharegpt_requests(
 ) -> List[Tuple[str, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
-
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -174,6 +187,31 @@ def sample_sonnet_requests(
     return sampled_requests
 
 
+def sample_random_requests(
+        input_len: int, output_len: int, num_prompts: int, range_ratio: float,
+        tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:
+
+    input_lens = np.random.randint(
+        int(input_len * range_ratio),
+        input_len + 1,
+        size=num_prompts,
+    )
+    output_lens = np.random.randint(
+        int(output_len * range_ratio),
+        output_len + 1,
+        size=num_prompts,
+    )
+    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
+    input_requests = []
+    for i in range(num_prompts):
+        prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size
+                                   for j in range(input_lens[i])])
+        input_requests.append(
+            (prompt, int(input_lens[i]), int(output_lens[i])))
+
+    return input_requests
+
+
 async def get_request(
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
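sample_random_requests builds each prompt by decoding a run of consecutive token IDs that starts at a per-request random offset and wraps at the vocabulary size; input and output lengths are drawn per request (see --random-range-ratio further down). A small illustration of the prompt construction, with "gpt2" standing in for the benchmarked model's tokenizer (the tokenizer choice is an assumption of this sketch, not something the diff prescribes):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
    offset, input_len = 1234, 8
    # Consecutive IDs from a random offset, wrapped at vocab_size, as above.
    ids = [(offset + j) % tok.vocab_size for j in range(input_len)]
    print(tok.decode(ids))

One caveat: decoding arbitrary IDs and re-tokenizing the resulting text may not round-trip to exactly input_lens[i] tokens, and the function records the sampled length rather than re-measuring it.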
@@ -185,6 +223,7 @@ async def get_request(
         if request_rate == float("inf"):
             # If the request rate is infinity, then we don't need to wait.
             continue
+
         # Sample the request interval from the exponential distribution.
         interval = np.random.exponential(1.0 / request_rate)
         # The next request will be sent after the interval.
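The blank line added above is cosmetic; the surrounding logic is what matters: get_request yields requests with exponentially distributed gaps, which makes the arrival stream a Poisson process with mean rate request_rate, while a rate of inf skips the sleep and fires every request at once. A quick sanity check of that model (numbers are illustrative):

    import numpy as np

    request_rate = 4.0  # requests per second
    gaps = np.random.exponential(1.0 / request_rate, size=100_000)
    print(gaps.mean())  # ~0.25 s per gap, i.e. ~4 requests per second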
@@ -197,19 +236,27 @@ def calculate_metrics(
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
 ) -> Tuple[BenchmarkMetrics, List[int]]:
-    actual_output_lens = []
+    actual_output_lens: List[int] = []
     total_input = 0
     completed = 0
-    tpots = []
-    ttfts = []
+    itls: List[float] = []
+    tpots: List[float] = []
+    ttfts: List[float] = []
     for i in range(len(outputs)):
         if outputs[i].success:
-            output_len = len(tokenizer(outputs[i].generated_text).input_ids)
+            # We use the tokenizer to count the number of output tokens for all
+            # serving backends instead of looking at len(outputs[i].itl) since
+            # multiple output tokens may be bundled together
+            # Note : this may inflate the output token count slightly
+            output_len = len(
+                tokenizer(outputs[i].generated_text,
+                          add_special_tokens=False).input_ids)
             actual_output_lens.append(output_len)
             total_input += input_requests[i][1]
             if output_len > 1:
                 tpots.append(
                     (outputs[i].latency - outputs[i].ttft) / (output_len - 1))
+            itls += outputs[i].itl
             ttfts.append(outputs[i].ttft)
             completed += 1
         else:
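Two things change in calculate_metrics: output tokens are recounted with add_special_tokens=False, so re-encoding the generated text no longer adds BOS/EOS markers to the count, and per-chunk inter-token latencies (itls) are collected alongside the per-request TPOT. The TPOT line divides by output_len - 1 because the first token's wait is already covered by TTFT; a worked instance of the formula, with made-up numbers:

    # Hypothetical request: 2.0 s total latency, 0.5 s TTFT, 31 output tokens.
    latency, ttft, output_len = 2.0, 0.5, 31
    tpot = (latency - ttft) / (output_len - 1)  # 1.5 s over 30 decode steps
    print(f"{tpot * 1000:.1f} ms/token")        # 50.0 ms/token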
@@ -230,10 +277,16 @@ def calculate_metrics(
         mean_ttft_ms=np.mean(ttfts or 0) *
         1000,  # ttfts is empty if streaming is not supported by backend
         median_ttft_ms=np.median(ttfts or 0) * 1000,
+        std_ttft_ms=np.std(ttfts or 0) * 1000,
         p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
         mean_tpot_ms=np.mean(tpots or 0) * 1000,
         median_tpot_ms=np.median(tpots or 0) * 1000,
+        std_tpot_ms=np.std(tpots or 0) * 1000,
         p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
+        mean_itl_ms=np.mean(itls or 0) * 1000,
+        median_itl_ms=np.median(itls or 0) * 1000,
+        std_itl_ms=np.std(itls or 0) * 1000,
+        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
     )
 
     return metrics, actual_output_lens
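The `or 0` in these aggregations is quiet error handling: an empty list is falsy, so when a backend does not stream (leaving ttfts, tpots, or itls empty) numpy aggregates the scalar 0 instead of warning on mean([]) or raising on percentile([]). In isolation:

    import numpy as np

    ttfts = []  # backend did not stream, so no TTFT samples were recorded
    print(np.mean(ttfts or 0) * 1000)            # 0.0, no RuntimeWarning
    print(np.percentile(ttfts or 0, 99) * 1000)  # 0.0, no IndexError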
@@ -251,7 +304,7 @@ async def benchmark(
     disable_tqdm: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
-        request_func = ASYNC_REQUEST_FUNCS.get(backend)
+        request_func = ASYNC_REQUEST_FUNCS[backend]
     else:
         raise ValueError(f"Unknown backend: {backend}")
 
@@ -278,7 +331,7 @@ async def benchmark(
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
 
     benchmark_start_time = time.perf_counter()
-    tasks = []
+    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
         prompt, prompt_len, output_len = request
         request_func_input = RequestFuncInput(
@@ -296,7 +349,7 @@ async def benchmark(
                              pbar=pbar)))
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
 
-    if not disable_tqdm:
+    if pbar is not None:
         pbar.close()
 
     benchmark_duration = time.perf_counter() - benchmark_start_time
@@ -333,6 +386,10 @@ async def benchmark(
     print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
                                     metrics.median_tpot_ms))
     print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
+    print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
+    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
+    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
+    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
     print("=" * 50)
 
     result = {
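ITL gets its own section in the printed summary: TPOT above is one averaged number per request, while ITL keeps every observed gap between streamed chunks, so its distribution exposes stutter that a per-request mean hides. The divider line uses nested format specs, with '-' as the fill character and '^' centering the label in a 50-column field:

    print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
    # ---------------Inter-token Latency----------------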
@@ -345,10 +402,16 @@ async def benchmark(
         "output_throughput": metrics.output_throughput,
         "mean_ttft_ms": metrics.mean_ttft_ms,
         "median_ttft_ms": metrics.median_ttft_ms,
+        "std_ttft_ms": metrics.std_ttft_ms,
         "p99_ttft_ms": metrics.p99_ttft_ms,
         "mean_tpot_ms": metrics.mean_tpot_ms,
         "median_tpot_ms": metrics.median_tpot_ms,
+        "std_tpot_ms": metrics.std_tpot_ms,
         "p99_tpot_ms": metrics.p99_tpot_ms,
+        "mean_itl_ms": metrics.mean_itl_ms,
+        "median_itl_ms": metrics.median_itl_ms,
+        "std_itl_ms": metrics.std_itl_ms,
+        "p99_itl_ms": metrics.p99_itl_ms,
         "input_lens": [output.prompt_len for output in outputs],
         "output_lens": actual_output_lens,
         "ttfts": [output.ttft for output in outputs],
@@ -427,6 +490,15 @@ def main(args: argparse.Namespace):
                           for prompt, prompt_formatted, prompt_len,
                           output_len in input_requests]
 
+    elif args.dataset_name == "random":
+        input_requests = sample_random_requests(
+            input_len=args.random_input_len,
+            output_len=args.random_output_len,
+            num_prompts=args.num_prompts,
+            range_ratio=args.random_range_ratio,
+            tokenizer=tokenizer,
+        )
+
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
 
@@ -445,7 +517,7 @@ def main(args: argparse.Namespace):
 
     # Save config and results to json
     if args.save_result:
-        result_json = {}
+        result_json: Dict[str, Any] = {}
 
         # Setup
         current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -478,6 +550,8 @@ def main(args: argparse.Namespace):
         # Save to file
         base_model_id = model_id.split("/")[-1]
         file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"  #noqa
+        if args.result_filename:
+            file_name = args.result_filename
         if args.result_dir:
             file_name = os.path.join(args.result_dir, file_name)
         with open(file_name, "w") as outfile:
@@ -485,7 +559,7 @@ def main(args: argparse.Namespace):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
+    parser = FlexibleArgumentParser(
         description="Benchmark the online serving throughput.")
     parser.add_argument(
         "--backend",
@@ -518,7 +592,7 @@ if __name__ == "__main__":
         "--dataset-name",
         type=str,
         default="sharegpt",
-        choices=["sharegpt", "sonnet"],
+        choices=["sharegpt", "sonnet", "random"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument("--dataset-path",
@@ -535,7 +609,7 @@ if __name__ == "__main__":
         "--tokenizer",
         type=str,
         help=
-        "Name or path of the tokenizer, if not using the default tokenizer.",
+        "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
     )
     parser.add_argument(
         "--best-of",
@@ -578,6 +652,27 @@ if __name__ == "__main__":
         help=
         "Number of prefix tokens per request, used only for sonnet dataset.",
     )
+    parser.add_argument(
+        "--random-input-len",
+        type=int,
+        default=1024,
+        help=
+        "Number of input tokens per request, used only for random sampling.",
+    )
+    parser.add_argument(
+        "--random-output-len",
+        type=int,
+        default=128,
+        help=
+        "Number of output tokens per request, used only for random sampling.",
+    )
+    parser.add_argument(
+        "--random-range-ratio",
+        type=float,
+        default=1.0,
+        help="Range of sampled ratio of input/output length, "
+        "used only for random sampling.",
+    )
     parser.add_argument(
         "--request-rate",
         type=float,
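On the semantics of --random-range-ratio: each request's input (and output) length is drawn uniformly from [int(len * ratio), len], so the default of 1.0 pins every request to exactly the requested lengths, and smaller ratios widen the spread downward. Numerically, with --random-input-len 1024 and ratio 0.5 (the printed values below are illustrative):

    import numpy as np

    # randint's upper bound is exclusive, hence the + 1 in the sampler above.
    lens = np.random.randint(int(1024 * 0.5), 1024 + 1, size=5)
    print(lens)  # e.g. [ 734  980  512  861 1023]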
@@ -618,6 +713,15 @@ if __name__ == "__main__":
         help="Specify directory to save benchmark json results."
         "If not specified, results are saved in the current directory.",
     )
+    parser.add_argument(
+        "--result-filename",
+        type=str,
+        default=None,
+        help="Specify the filename to save benchmark json results."
+        "If not specified, results will be saved in "
+        "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
+        " format.",
+    )
 
     args = parser.parse_args()
     main(args)