serving.py

  1. """Benchmark online serving throughput.
  2. On the server side, run one of the following commands:
  3. Aphrodite OpenAI API server
  4. python -m aphrodite.endpoints.openai.api_server \
  5. --model <your_model> --swap-space 16 \
  6. --disable-log-requests
  7. (TGI backend)
  8. ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
  9. On the client side, run:
  10. python tests/benchmarks/serving.py \
  11. --backend <backend> \
  12. --model <your_model> \
  13. --dataset-name sharegpt \
  14. --dataset-path <path to dataset> \
  15. --request-rate <request_rate> \ # By default <request_rate> is inf
  16. --num-prompts <num_prompts> # By default <num_prompts> is 1000
  17. when using tgi backend, add
  18. --endpoint /generate_stream
  19. to the end of the command above.
  20. """
import argparse
import asyncio
import json
import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple

import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
                                  RequestFuncOutput)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

from aphrodite.transformers_utils.tokenizer import get_tokenizer


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    mean_ttft_ms: float
    median_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    p99_tpot_ms: float


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
                         ) if fixed_output_len is None else fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


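# Note: every request sampled below shares the same `prefix_len`-token block of
# poem lines (after the fixed instruction), so the sonnet dataset presumably
# serves to exercise shared-prefix / prefix-caching behavior in the backend.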
def sample_sonnet_requests(
    dataset_path: str,
    num_requests: int,
    input_len: int,
    output_len: int,
    prefix_len: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int]]:
    assert (
        input_len > prefix_len
    ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."

    # Load the dataset.
    with open(dataset_path) as f:
        poem_lines = f.readlines()

    # Tokenize the poem lines.
    poem_token_ids = tokenizer(poem_lines).input_ids
    average_poem_len = sum(
        len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)

    # Base prefix for all requests.
    base_prompt = "Pick as many lines as you can from these poem lines:\n"
    base_message = [{
        "role": "user",
        "content": base_prompt,
    }]
    base_prompt_formatted = tokenizer.apply_chat_template(
        base_message, add_generation_prompt=True, tokenize=False)
    base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)

    assert (
        input_len > base_prompt_offset
    ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
    num_input_lines = round(
        (input_len - base_prompt_offset) / average_poem_len)

    # First approximately `prefix_len` number of tokens in the
    # prompt are fixed poem lines.
    assert (
        prefix_len > base_prompt_offset
    ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."

    num_prefix_lines = round(
        (prefix_len - base_prompt_offset) / average_poem_len)
    prefix_lines = poem_lines[:num_prefix_lines]

    # Sample the rest of lines per request.
    sampled_requests: List[Tuple[str, str, int, int]] = []
    for _ in range(num_requests):
        sampled_lines = "".join(
            prefix_lines +
            random.sample(poem_lines, num_input_lines - num_prefix_lines))

        prompt = f"{base_prompt}{sampled_lines}"
        message = [
            {
                "role": "user",
                "content": prompt,
            },
        ]
        prompt_formatted = tokenizer.apply_chat_template(
            message, add_generation_prompt=True, tokenize=False)
        prompt_len = len(tokenizer(prompt_formatted).input_ids)
        sampled_requests.append(
            (prompt, prompt_formatted, prompt_len, output_len))

    return sampled_requests


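# Inter-arrival times below are drawn from an exponential distribution with
# mean 1 / request_rate, so the emitted request stream approximates a Poisson
# arrival process at `request_rate` requests per second (see --request-rate).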
async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue
        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


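# Metric definitions used below:
#   TTFT: time from sending a request until its first output token arrives.
#   TPOT: (latency - TTFT) / (output_len - 1), i.e. the mean time per generated
#         token excluding the first one.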
def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
) -> Tuple[BenchmarkMetrics, List[int]]:
    actual_output_lens = []
    total_input = 0
    completed = 0
    tpots = []
    ttfts = []
    for i in range(len(outputs)):
        if outputs[i].success:
            output_len = len(tokenizer(outputs[i].generated_text).input_ids)
            actual_output_lens.append(output_len)
            total_input += input_requests[i][1]
            if output_len > 1:
                tpots.append(
                    (outputs[i].latency - outputs[i].ttft) / (output_len - 1))
            ttfts.append(outputs[i].ttft)
            completed += 1
        else:
            actual_output_lens.append(0)

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2)
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(actual_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(actual_output_lens) / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0) *
        1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
    )

    return metrics, actual_output_lens


async def benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
    use_beam_search: bool,
    request_rate: float,
    disable_tqdm: bool,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS.get(backend)
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        best_of=best_of,
        use_beam_search=use_beam_search,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}")
    else:
        print("Initial test run completed. Starting main benchmark run...")

    print(f"Traffic request rate: {request_rate}")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    tasks = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            best_of=best_of,
            use_beam_search=use_beam_search,
        )
        tasks.append(
            asyncio.create_task(
                request_func(request_func_input=request_func_input,
                             pbar=pbar)))
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if not disable_tqdm:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, actual_output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
    )

    print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
                                    benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:",
                                 metrics.total_output))
    print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                    metrics.request_throughput))
    print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
                                    metrics.input_throughput))
    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
                                    metrics.output_throughput))
    print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
                                    metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
                               n=50,
                               c='-'))
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
                                    metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("=" * 50)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": actual_output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }
    return result


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"

    tokenizer = get_tokenizer(tokenizer_id,
                              trust_remote_code=args.trust_remote_code)

    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next "
            "release. Please use '--dataset-name' and "
            "'--dataset-path' in future runs.",
            stacklevel=2)
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )
    elif args.dataset_name == "sharegpt":
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )
    elif args.dataset_name == "sonnet":
        # Do not format the prompt, pass to message directly
        if args.backend == "openai-chat":
            input_requests = sample_sonnet_requests(
                dataset_path=args.dataset_path,
                num_requests=args.num_prompts,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
            )
            input_requests = [(prompt, prompt_len, output_len)
                              for prompt, prompt_formatted, prompt_len,
                              output_len in input_requests]
        else:
            assert (
                tokenizer.chat_template or tokenizer.default_chat_template
            ), "Tokenizer/model must have chat template for sonnet dataset."
            input_requests = sample_sonnet_requests(
                dataset_path=args.dataset_path,
                num_requests=args.num_prompts,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
            )
            input_requests = [(prompt_formatted, prompt_len, output_len)
                              for prompt, prompt_formatted, prompt_len,
                              output_len in input_requests]
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    benchmark_result = asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            best_of=args.best_of,
            use_beam_search=args.use_beam_search,
            request_rate=args.request_rate,
            disable_tqdm=args.disable_tqdm,
        ))

    # Save config and results to json
    if args.save_result:
        result_json = {}

        # Setup
        current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
        result_json["date"] = current_dt
        result_json["backend"] = backend
        result_json["model_id"] = model_id
        result_json["tokenizer_id"] = tokenizer_id
        result_json["best_of"] = args.best_of
        result_json["use_beam_search"] = args.use_beam_search
        result_json["num_prompts"] = args.num_prompts

        # Metadata
        if args.metadata:
            for item in args.metadata:
                if "=" in item:
                    kvstring = item.split("=")
                    result_json[kvstring[0].strip()] = kvstring[1].strip()
                else:
                    raise ValueError(
                        "Invalid metadata format. Please use KEY=VALUE format."
                    )

        # Traffic
        result_json["request_rate"] = (
            args.request_rate if args.request_rate < float("inf") else "inf")

        # Merge with benchmark result
        result_json = {**result_json, **benchmark_result}

        # Save to file
        base_model_id = model_id.split("/")[-1]
        file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"  # noqa
        if args.result_dir:
            file_name = os.path.join(args.result_dir, file_name)
        with open(file_name, "w") as outfile:
            json.dump(result_json, outfile)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        default="aphrodite",
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint",
        type=str,
        default="/v1/completions",
        help="API endpoint.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in the "
        "next release.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "sonnet"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the model.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help=
        "Name or path of the tokenizer, if not using the default tokenizer.",
    )
    parser.add_argument(
        "--best-of",
        type=int,
        default=1,
        help="Generates `best_of` sequences per prompt and "
        "returns the best one.",
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length "
        "from the ShareGPT dataset.")
    parser.add_argument(
        "--sonnet-input-len",
        type=int,
        default=550,
        help=
        "Number of input tokens per request, used only for sonnet dataset.",
    )
    parser.add_argument(
        "--sonnet-output-len",
        type=int,
        default=150,
        help=
        "Number of output tokens per request, used only for sonnet dataset.",
    )
    parser.add_argument(
        "--sonnet-prefix-len",
        type=int,
        default=200,
        help=
        "Number of prefix tokens per request, used only for sonnet dataset.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, "
        "then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize "
        "the request arrival times.",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from huggingface",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--save-result",
        action="store_true",
        help="Specify to save benchmark results to a json file",
    )
    parser.add_argument(
        "--metadata",
        metavar="KEY=VALUE",
        nargs="*",
        help="Key-value pairs (e.g., --metadata version=0.3.3 tp=1) "
        "for metadata of this run to be saved in the result JSON file "
        "for record keeping purposes.",
    )
    parser.add_argument(
        "--result-dir",
        type=str,
        default=None,
        help="Specify directory to save benchmark json results. "
        "If not specified, results are saved in the current directory.",
    )

    args = parser.parse_args()
    main(args)