  1. """Benchmark online serving throughput.
  2. On the server side, run one of the following commands:
  3. Aphrodite OpenAI API server
  4. aphrodite run <your_model> \
  5. --swap-space 16 \
  6. --disable-log-requests
  7. (TGI backend)
  8. ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
  9. On the client side, run:
  10. python tests/benchmarks/serving.py \
  11. --backend <backend> \
  12. --model <your_model> \
  13. --dataset-name sharegpt \
  14. --dataset-path <path to dataset> \
  15. --request-rate <request_rate> \ # By default <request_rate> is inf
  16. --num-prompts <num_prompts> # By default <num_prompts> is 1000
  17. when using tgi backend, add
  18. --endpoint /generate_stream
  19. to the end of the command above.
  20. """
import argparse
import asyncio
import json
import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
                                  RequestFuncOutput)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

try:
    from aphrodite.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer

try:
    from aphrodite.common.utils import FlexibleArgumentParser
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser


@dataclass
class BenchmarkMetrics:
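    """Aggregated statistics for one benchmark run.

    TTFT is time to first token, TPOT is time per output token (excluding
    the first token), and ITL is inter-token latency; all latency fields
    are in milliseconds.
    """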
    completed: int
    total_input: int
    total_output: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p99_itl_ms: float


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
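    """Sample (prompt, prompt_len, output_len) tuples from a ShareGPT-format
    JSON file, skipping prompts that are too short or too long."""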
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
                         ) if fixed_output_len is None else fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


def sample_sonnet_requests(
    dataset_path: str,
    num_requests: int,
    input_len: int,
    output_len: int,
    prefix_len: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int]]:
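    """Build prompts from the sonnet line file. Each prompt shares roughly
    the first `prefix_len` tokens (fixed poem lines) and is padded to about
    `input_len` tokens with randomly sampled lines. Returns
    (prompt, prompt_formatted, prompt_len, output_len) tuples."""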
    assert (
        input_len > prefix_len
    ), "'args.sonnet-input-len' must be greater than 'args.sonnet-prefix-len'."

    # Load the dataset.
    with open(dataset_path) as f:
        poem_lines = f.readlines()

    # Tokenize the poem lines.
    poem_token_ids = tokenizer(poem_lines).input_ids
    average_poem_len = sum(
        len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)

    # Base prefix for all requests.
    base_prompt = "Pick as many lines as you can from these poem lines:\n"
    base_message = [{
        "role": "user",
        "content": base_prompt,
    }]
    base_prompt_formatted = tokenizer.apply_chat_template(
        base_message, add_generation_prompt=True, tokenize=False)
    base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)

    assert (
        input_len > base_prompt_offset
    ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
    num_input_lines = round(
        (input_len - base_prompt_offset) / average_poem_len)

    # First approximately `prefix_len` number of tokens in the
    # prompt are fixed poem lines.
    assert (
        prefix_len > base_prompt_offset
    ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."

    num_prefix_lines = round(
        (prefix_len - base_prompt_offset) / average_poem_len)
    prefix_lines = poem_lines[:num_prefix_lines]

    # Sample the rest of lines per request.
    sampled_requests: List[Tuple[str, str, int, int]] = []
    for _ in range(num_requests):
        sampled_lines = "".join(
            prefix_lines +
            random.sample(poem_lines, num_input_lines - num_prefix_lines))

        prompt = f"{base_prompt}{sampled_lines}"
        message = [
            {
                "role": "user",
                "content": prompt,
            },
        ]
        prompt_formatted = tokenizer.apply_chat_template(
            message, add_generation_prompt=True, tokenize=False)
        prompt_len = len(tokenizer(prompt_formatted).input_ids)
        sampled_requests.append(
            (prompt, prompt_formatted, prompt_len, output_len))

    return sampled_requests


def sample_random_requests(
        input_len: int, output_len: int, num_prompts: int, range_ratio: float,
        tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:
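    """Generate synthetic prompts of random token IDs. Input and output
    lengths are sampled uniformly from [len * range_ratio, len]."""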
    input_lens = np.random.randint(
        int(input_len * range_ratio),
        input_len + 1,
        size=num_prompts,
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio),
        output_len + 1,
        size=num_prompts,
    )
    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
    input_requests = []
    for i in range(num_prompts):
        prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size
                                   for j in range(input_lens[i])])
        input_requests.append(
            (prompt, int(input_lens[i]), int(output_lens[i])))

    return input_requests


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
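    """Yield requests one at a time. If `request_rate` is finite, the wait
    between requests is drawn from an exponential distribution, so request
    arrivals follow a Poisson process at the given rate."""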
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue
        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
) -> Tuple[BenchmarkMetrics, List[int]]:
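    """Aggregate per-request outputs into a BenchmarkMetrics summary and
    return it along with the per-request output token counts."""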
    actual_output_lens: List[int] = []
    total_input = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    for i in range(len(outputs)):
        if outputs[i].success:
            # We use the tokenizer to count the number of output tokens for
            # all serving backends instead of looking at len(outputs[i].itl),
            # since multiple output tokens may be bundled together.
            # Note: this may inflate the output token count slightly.
            output_len = len(
                tokenizer(outputs[i].generated_text,
                          add_special_tokens=False).input_ids)
            actual_output_lens.append(output_len)
            total_input += input_requests[i][1]
            if output_len > 1:
                tpots.append(
                    (outputs[i].latency - outputs[i].ttft) / (output_len - 1))
            itls += outputs[i].itl
            ttfts.append(outputs[i].ttft)
            completed += 1
        else:
            actual_output_lens.append(0)

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "of the benchmark arguments.",
            stacklevel=2)
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(actual_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(actual_output_lens) / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0) *
        1000,  # ttfts is empty if streaming is not supported by the backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
    )

    return metrics, actual_output_lens


async def benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
    use_beam_search: bool,
    request_rate: float,
    disable_tqdm: bool,
):
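    """Run the benchmark: send a single warm-up request, then dispatch all
    requests at the configured rate, gather the responses, and print and
    return the aggregated metrics."""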
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        best_of=best_of,
        use_beam_search=use_beam_search,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}")
    else:
        print("Initial test run completed. Starting main benchmark run...")
    print(f"Traffic request rate: {request_rate}")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            best_of=best_of,
            use_beam_search=use_beam_search,
        )
        tasks.append(
            asyncio.create_task(
                request_func(request_func_input=request_func_input,
                             pbar=pbar)))
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, actual_output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
    )
  317. print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
  318. print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
  319. print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
  320. benchmark_duration))
  321. print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
  322. print("{:<40} {:<10}".format("Total generated tokens:",
  323. metrics.total_output))
  324. print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
  325. metrics.request_throughput))
  326. print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
  327. metrics.input_throughput))
  328. print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
  329. metrics.output_throughput))
  330. print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
  331. print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
  332. print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
  333. metrics.median_ttft_ms))
  334. print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
  335. print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
  336. n=50,
  337. c='-'))
  338. print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
  339. print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
  340. metrics.median_tpot_ms))
  341. print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
  342. print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
  343. print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
  344. print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
  345. print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
  346. print("=" * 50)
    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": actual_output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }
    return result


def main(args: argparse.Namespace):
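    """Entry point: build the request set from the selected dataset, run the
    async benchmark, and optionally save the results to a JSON file."""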
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"

    tokenizer = get_tokenizer(tokenizer_id,
                              trust_remote_code=args.trust_remote_code)

    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next "
            "release. Please use '--dataset-name' and "
            "'--dataset-path' in future runs.",
            stacklevel=2)
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )
    elif args.dataset_name == "sharegpt":
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )
    elif args.dataset_name == "sonnet":
        # For the openai-chat backend, do not apply the chat template here;
        # the backend sends the raw prompt as a chat message itself.
        if args.backend == "openai-chat":
            input_requests = sample_sonnet_requests(
                dataset_path=args.dataset_path,
                num_requests=args.num_prompts,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
            )
            input_requests = [(prompt, prompt_len, output_len)
                              for prompt, prompt_formatted, prompt_len,
                              output_len in input_requests]
        else:
            assert (
                tokenizer.chat_template or tokenizer.default_chat_template
            ), "Tokenizer/model must have chat template for sonnet dataset."
            input_requests = sample_sonnet_requests(
                dataset_path=args.dataset_path,
                num_requests=args.num_prompts,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
            )
            input_requests = [(prompt_formatted, prompt_len, output_len)
                              for prompt, prompt_formatted, prompt_len,
                              output_len in input_requests]
    elif args.dataset_name == "random":
        input_requests = sample_random_requests(
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
        )
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    benchmark_result = asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            best_of=args.best_of,
            use_beam_search=args.use_beam_search,
            request_rate=args.request_rate,
            disable_tqdm=args.disable_tqdm,
        ))

    # Save config and results to json
    if args.save_result:
        result_json: Dict[str, Any] = {}

        # Setup
        current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
        result_json["date"] = current_dt
        result_json["backend"] = backend
        result_json["model_id"] = model_id
        result_json["tokenizer_id"] = tokenizer_id
        result_json["best_of"] = args.best_of
        result_json["use_beam_search"] = args.use_beam_search
        result_json["num_prompts"] = args.num_prompts

        # Metadata
        if args.metadata:
            for item in args.metadata:
                if "=" in item:
                    kvstring = item.split("=")
                    result_json[kvstring[0].strip()] = kvstring[1].strip()
                else:
                    raise ValueError(
                        "Invalid metadata format. Please use KEY=VALUE format."
                    )

        # Traffic
        result_json["request_rate"] = (
            args.request_rate if args.request_rate < float("inf") else "inf")

        # Merge with benchmark result
        result_json = {**result_json, **benchmark_result}

        # Save to file
        base_model_id = model_id.split("/")[-1]
        file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"  # noqa
        if args.result_filename:
            file_name = args.result_filename
        if args.result_dir:
            file_name = os.path.join(args.result_dir, file_name)
        with open(file_name, "w") as outfile:
            json.dump(result_json, outfile)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        default="aphrodite",
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint",
        type=str,
        default="/v1/completions",
        help="API endpoint.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in the "
        "next release.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "sonnet", "random"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the model.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help=
        "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
    )
    parser.add_argument(
        "--best-of",
        type=int,
        default=1,
        help="Generates `best_of` sequences per prompt and "
        "returns the best one.",
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length "
        "from the ShareGPT dataset.")
    parser.add_argument(
        "--sonnet-input-len",
        type=int,
        default=550,
        help=
        "Number of input tokens per request, used only for sonnet dataset.",
    )
    parser.add_argument(
        "--sonnet-output-len",
        type=int,
        default=150,
        help=
        "Number of output tokens per request, used only for sonnet dataset.",
    )
    parser.add_argument(
        "--sonnet-prefix-len",
        type=int,
        default=200,
        help=
        "Number of prefix tokens per request, used only for sonnet dataset.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help=
        "Number of input tokens per request, used only for random sampling.",
    )
    parser.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help=
        "Number of output tokens per request, used only for random sampling.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=1.0,
        help="Range of sampled ratio of input/output length, "
        "used only for random sampling.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, "
        "then all the requests are sent at time 0. "
        "Otherwise, we use a Poisson process to synthesize "
        "the request arrival times.",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from Hugging Face.",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--save-result",
        action="store_true",
        help="Specify to save benchmark results to a json file.",
    )
    parser.add_argument(
        "--metadata",
        metavar="KEY=VALUE",
        nargs="*",
        help="Key-value pairs (e.g., --metadata version=0.3.3 tp=1) "
        "for metadata of this run to be saved in the result JSON file "
        "for record keeping purposes.",
    )
    parser.add_argument(
        "--result-dir",
        type=str,
        default=None,
        help="Specify directory to save benchmark json results. "
        "If not specified, results are saved in the current directory.",
    )
    parser.add_argument(
        "--result-filename",
        type=str,
        default=None,
        help="Specify the filename to save benchmark json results. "
        "If not specified, results will be saved in "
        "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
        " format.",
    )

    args = parser.parse_args()
    main(args)