throughput.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. """Benchmark offline inference throughput."""
  2. import argparse
  3. import json
  4. import random
  5. import time
  6. from typing import List, Optional, Tuple
  7. import torch
  8. import uvloop
  9. from tqdm import tqdm
  10. from transformers import (AutoModelForCausalLM, AutoTokenizer,
  11. PreTrainedTokenizerBase)
  12. from aphrodite.common.utils import (FlexibleArgumentParser,
  13. merge_async_iterators)
  14. from aphrodite.endpoints.openai.api_server import (
  15. build_engine_client_from_engine_args)
  16. from aphrodite.engine.args_tools import (DEVICE_OPTIONS, AsyncEngineArgs,
  17. EngineArgs)
  18. from aphrodite.quantization import QUANTIZATION_METHODS
  19. def sample_requests(
  20. dataset_path: str,
  21. num_requests: int,
  22. tokenizer: PreTrainedTokenizerBase,
  23. fixed_output_len: Optional[int],
  24. ) -> List[Tuple[str, int, int]]:
  25. if fixed_output_len is not None and fixed_output_len < 4:
  26. raise ValueError("output_len too small")
  27. # Load the dataset.
  28. with open(dataset_path) as f:
  29. dataset = json.load(f)
  30. # Filter out the conversations with less than 2 turns.
  31. dataset = [data for data in dataset if len(data["conversations"]) >= 2]
  32. # Only keep the first two turns of each conversation.
  33. dataset = [(data["conversations"][0]["value"],
  34. data["conversations"][1]["value"]) for data in dataset]
  35. # Shuffle the dataset.
  36. random.shuffle(dataset)
  37. # Filter out sequences that are too long or too short
  38. filtered_dataset: List[Tuple[str, int, int]] = []
  39. for i in range(len(dataset)):
  40. if len(filtered_dataset) == num_requests:
  41. break
  42. # Tokenize the prompts and completions.
  43. prompt = dataset[i][0]
  44. prompt_token_ids = tokenizer(prompt).input_ids
  45. completion = dataset[i][1]
  46. completion_token_ids = tokenizer(completion).input_ids
  47. prompt_len = len(prompt_token_ids)
  48. output_len = len(completion_token_ids
  49. ) if fixed_output_len is None else fixed_output_len
  50. if prompt_len < 4 or output_len < 4:
  51. # Prune too short sequences.
  52. continue
  53. if prompt_len > 1024 or prompt_len + output_len > 2048:
  54. # Prune too long sequences.
  55. continue
  56. filtered_dataset.append((prompt, prompt_len, output_len))
  57. return filtered_dataset
  58. def run_aphrodite(
  59. requests: List[Tuple[str, int, int]],
  60. model: str,
  61. tokenizer: str,
  62. quantization: Optional[str],
  63. tensor_parallel_size: int,
  64. seed: int,
  65. n: int,
  66. use_beam_search: bool,
  67. trust_remote_code: bool,
  68. dtype: str,
  69. max_model_len: Optional[int],
  70. enforce_eager: bool,
  71. kv_cache_dtype: str,
  72. quantization_param_path: Optional[str],
  73. device: str,
  74. enable_prefix_caching: bool,
  75. enable_chunked_prefill: bool,
  76. max_num_batched_tokens: int,
  77. distributed_executor_backend: Optional[str],
  78. gpu_memory_utilization: float = 0.9,
  79. num_scheduler_steps: int = 1,
  80. use_v2_block_manager: bool = False,
  81. download_dir: Optional[str] = None,
  82. load_format: str = EngineArgs.load_format,
  83. disable_async_output_proc: bool = False,
  84. ) -> float:
  85. from aphrodite import LLM, SamplingParams
  86. llm = LLM(
  87. model=model,
  88. tokenizer=tokenizer,
  89. quantization=quantization,
  90. tensor_parallel_size=tensor_parallel_size,
  91. seed=seed,
  92. trust_remote_code=trust_remote_code,
  93. dtype=dtype,
  94. max_model_len=max_model_len,
  95. gpu_memory_utilization=gpu_memory_utilization,
  96. enforce_eager=enforce_eager,
  97. kv_cache_dtype=kv_cache_dtype,
  98. quantization_param_path=quantization_param_path,
  99. device=device,
  100. enable_prefix_caching=enable_prefix_caching,
  101. download_dir=download_dir,
  102. enable_chunked_prefill=enable_chunked_prefill,
  103. max_num_batched_tokens=max_num_batched_tokens,
  104. distributed_executor_backend=distributed_executor_backend,
  105. load_format=load_format,
  106. num_scheduler_steps=num_scheduler_steps,
  107. use_v2_block_manager=use_v2_block_manager,
  108. disable_async_output_proc=disable_async_output_proc,
  109. )
  110. # Add the requests to the engine.
  111. prompts: List[str] = []
  112. sampling_params: List[SamplingParams] = []
  113. for prompt, _, output_len in requests:
  114. prompts.append(prompt)
  115. sampling_params.append(
  116. SamplingParams(
  117. n=n,
  118. temperature=0.0 if use_beam_search else 1.0,
  119. top_p=1.0,
  120. use_beam_search=use_beam_search,
  121. ignore_eos=True,
  122. max_tokens=output_len,
  123. ))
  124. start = time.perf_counter()
  125. llm.generate(prompts, sampling_params, use_tqdm=True)
  126. end = time.perf_counter()
  127. return end - start
  128. async def run_aphrodite_async(
  129. requests: List[Tuple[str, int, int]],
  130. model: str,
  131. tokenizer: str,
  132. quantization: Optional[str],
  133. tensor_parallel_size: int,
  134. seed: int,
  135. n: int,
  136. use_beam_search: bool,
  137. trust_remote_code: bool,
  138. dtype: str,
  139. max_model_len: Optional[int],
  140. enforce_eager: bool,
  141. kv_cache_dtype: str,
  142. quantization_param_path: Optional[str],
  143. device: str,
  144. enable_prefix_caching: bool,
  145. enable_chunked_prefill: bool,
  146. max_num_batched_tokens: int,
  147. distributed_executor_backend: Optional[str],
  148. gpu_memory_utilization: float = 0.9,
  149. num_scheduler_steps: int = 1,
  150. use_v2_block_manager: bool = False,
  151. download_dir: Optional[str] = None,
  152. load_format: str = EngineArgs.load_format,
  153. disable_async_output_proc: bool = False,
  154. disable_frontend_multiprocessing: bool = False,
  155. ) -> float:
  156. from aphrodite import SamplingParams
  157. engine_args = AsyncEngineArgs(
  158. model=model,
  159. tokenizer=tokenizer,
  160. quantization=quantization,
  161. tensor_parallel_size=tensor_parallel_size,
  162. seed=seed,
  163. trust_remote_code=trust_remote_code,
  164. dtype=dtype,
  165. max_model_len=max_model_len,
  166. gpu_memory_utilization=gpu_memory_utilization,
  167. enforce_eager=enforce_eager,
  168. kv_cache_dtype=kv_cache_dtype,
  169. quantization_param_path=quantization_param_path,
  170. device=device,
  171. enable_prefix_caching=enable_prefix_caching,
  172. download_dir=download_dir,
  173. enable_chunked_prefill=enable_chunked_prefill,
  174. max_num_batched_tokens=max_num_batched_tokens,
  175. distributed_executor_backend=distributed_executor_backend,
  176. load_format=load_format,
  177. num_scheduler_steps=num_scheduler_steps,
  178. use_v2_block_manager=use_v2_block_manager,
  179. disable_async_output_proc=disable_async_output_proc,
  180. disable_log_requests=True,
  181. )
  182. async with build_engine_client_from_engine_args(
  183. engine_args, disable_frontend_multiprocessing) as llm:
  184. # Add the requests to the engine.
  185. prompts: List[str] = []
  186. sampling_params: List[SamplingParams] = []
  187. for prompt, _, output_len in requests:
  188. prompts.append(prompt)
  189. sampling_params.append(
  190. SamplingParams(
  191. n=n,
  192. temperature=0.0 if use_beam_search else 1.0,
  193. top_p=1.0,
  194. use_beam_search=use_beam_search,
  195. ignore_eos=True,
  196. max_tokens=output_len,
  197. ))
  198. generators = []
  199. start = time.perf_counter()
  200. for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
  201. generator = llm.generate(prompt, sp, request_id=f"test{i}")
  202. generators.append(generator)
  203. all_gens = merge_async_iterators(*generators)
  204. async for i, res in all_gens:
  205. pass
  206. end = time.perf_counter()
  207. return end - start
  208. def run_hf(
  209. requests: List[Tuple[str, int, int]],
  210. model: str,
  211. tokenizer: PreTrainedTokenizerBase,
  212. n: int,
  213. use_beam_search: bool,
  214. max_batch_size: int,
  215. trust_remote_code: bool,
  216. ) -> float:
  217. assert not use_beam_search
  218. llm = AutoModelForCausalLM.from_pretrained(
  219. model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
  220. if llm.config.model_type == "llama":
  221. # To enable padding in the HF backend.
  222. tokenizer.pad_token = tokenizer.eos_token
  223. llm = llm.cuda()
  224. pbar = tqdm(total=len(requests))
  225. start = time.perf_counter()
  226. batch: List[str] = []
  227. max_prompt_len = 0
  228. max_output_len = 0
  229. for i in range(len(requests)):
  230. prompt, prompt_len, output_len = requests[i]
  231. # Add the prompt to the batch.
  232. batch.append(prompt)
  233. max_prompt_len = max(max_prompt_len, prompt_len)
  234. max_output_len = max(max_output_len, output_len)
  235. if len(batch) < max_batch_size and i != len(requests) - 1:
  236. # Check if we can add more requests to the batch.
  237. _, next_prompt_len, next_output_len = requests[i + 1]
  238. if (max(max_prompt_len, next_prompt_len) +
  239. max(max_output_len, next_output_len)) <= 2048:
  240. # We can add more requests to the batch.
  241. continue
  242. # Generate the sequences.
  243. input_ids = tokenizer(batch, return_tensors="pt",
  244. padding=True).input_ids
  245. llm_outputs = llm.generate(
  246. input_ids=input_ids.cuda(),
  247. do_sample=not use_beam_search,
  248. num_return_sequences=n,
  249. temperature=1.0,
  250. top_p=1.0,
  251. use_cache=True,
  252. max_new_tokens=max_output_len,
  253. )
  254. # Include the decoding time.
  255. tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
  256. pbar.update(len(batch))
  257. # Clear the batch.
  258. batch = []
  259. max_prompt_len = 0
  260. max_output_len = 0
  261. end = time.perf_counter()
  262. return end - start
  263. def run_mii(
  264. requests: List[Tuple[str, int, int]],
  265. model: str,
  266. tensor_parallel_size: int,
  267. output_len: int,
  268. ) -> float:
  269. from mii import client, serve
  270. llm = serve(model, tensor_parallel=tensor_parallel_size)
  271. prompts = [prompt for prompt, _, _ in requests]
  272. start = time.perf_counter()
  273. llm.generate(prompts, max_new_tokens=output_len)
  274. end = time.perf_counter()
  275. client = client(model)
  276. client.terminate_server()
  277. return end - start
  278. def main(args: argparse.Namespace):
  279. print(args)
  280. random.seed(args.seed)
  281. # Sample the requests.
  282. tokenizer = AutoTokenizer.from_pretrained(
  283. args.tokenizer, trust_remote_code=args.trust_remote_code)
  284. if args.dataset is None:
  285. # Synthesize a prompt with the given input length.
  286. prompt = "hi" * (args.input_len - 1)
  287. requests = [(prompt, args.input_len, args.output_len)
  288. for _ in range(args.num_prompts)]
  289. else:
  290. requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
  291. args.output_len)
  292. if args.backend == "aphrodite":
  293. run_args = [
  294. requests, args.model, args.tokenizer, args.quantization,
  295. args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
  296. args.trust_remote_code, args.dtype, args.max_model_len,
  297. args.enforce_eager, args.kv_cache_dtype,
  298. args.quantization_param_path, args.device,
  299. args.enable_prefix_caching, args.enable_chunked_prefill,
  300. args.max_num_batched_tokens, args.distributed_executor_backend,
  301. args.gpu_memory_utilization, args.num_scheduler_steps,
  302. args.use_v2_block_manager, args.download_dir, args.load_format,
  303. args.disable_async_output_proc
  304. ]
  305. if args.async_engine:
  306. run_args.append(args.disable_frontend_multiprocessing)
  307. elapsed_time = uvloop.run(run_aphrodite_async(*run_args))
  308. else:
  309. elapsed_time = run_aphrodite(*run_args)
  310. elif args.backend == "hf":
  311. assert args.tensor_parallel_size == 1
  312. elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
  313. args.use_beam_search, args.hf_max_batch_size,
  314. args.trust_remote_code)
  315. elif args.backend == "mii":
  316. elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
  317. args.output_len)
  318. else:
  319. raise ValueError(f"Unknown backend: {args.backend}")
  320. total_num_tokens = sum(prompt_len + output_len
  321. for _, prompt_len, output_len in requests)
  322. print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
  323. f"{total_num_tokens / elapsed_time:.2f} tokens/s")
  324. # Output JSON results if specified
  325. if args.output_json:
  326. results = {
  327. "elapsed_time": elapsed_time,
  328. "num_requests": len(requests),
  329. "total_num_tokens": total_num_tokens,
  330. "requests_per_second": len(requests) / elapsed_time,
  331. "tokens_per_second": total_num_tokens / elapsed_time,
  332. }
  333. with open(args.output_json, "w") as f:
  334. json.dump(results, f, indent=4)
  335. if __name__ == "__main__":
  336. parser = FlexibleArgumentParser(description="Benchmark the throughput.")
  337. parser.add_argument("--backend",
  338. type=str,
  339. choices=["aphrodite", "hf", "mii"],
  340. default="aphrodite")
  341. parser.add_argument("--dataset",
  342. type=str,
  343. default=None,
  344. help="Path to the dataset.")
  345. parser.add_argument("--input-len",
  346. type=int,
  347. default=None,
  348. help="Input prompt length for each request")
  349. parser.add_argument("--output-len",
  350. type=int,
  351. default=None,
  352. help="Output length for each request. Overrides the "
  353. "output length from the dataset.")
  354. parser.add_argument("--model", type=str, default="facebook/opt-125m")
  355. parser.add_argument("--tokenizer", type=str, default=None)
  356. parser.add_argument('--quantization',
  357. '-q',
  358. choices=[*QUANTIZATION_METHODS, None],
  359. default=None)
  360. parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
  361. parser.add_argument("--n",
  362. type=int,
  363. default=1,
  364. help="Number of generated sequences per prompt.")
  365. parser.add_argument("--use-beam-search", action="store_true")
  366. parser.add_argument("--num-prompts",
  367. type=int,
  368. default=1000,
  369. help="Number of prompts to process.")
  370. parser.add_argument("--seed", type=int, default=0)
  371. parser.add_argument("--hf-max-batch-size",
  372. type=int,
  373. default=None,
  374. help="Maximum batch size for HF backend.")
  375. parser.add_argument('--trust-remote-code',
  376. action='store_true',
  377. help='trust remote code from huggingface')
  378. parser.add_argument(
  379. '--max-model-len',
  380. type=int,
  381. default=None,
  382. help='Maximum length of a sequence (including prompt and output). '
  383. 'If None, will be derived from the model.')
  384. parser.add_argument(
  385. '--dtype',
  386. type=str,
  387. default='auto',
  388. choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
  389. help='data type for model weights and activations. '
  390. 'The "auto" option will use FP16 precision '
  391. 'for FP32 and FP16 models, and BF16 precision '
  392. 'for BF16 models.')
  393. parser.add_argument('--gpu-memory-utilization',
  394. type=float,
  395. default=0.9,
  396. help='the fraction of GPU memory to be used for '
  397. 'the model executor, which can range from 0 to 1.'
  398. 'If unspecified, will use the default value of 0.9.')
  399. parser.add_argument("--enforce-eager",
  400. action="store_true",
  401. help="enforce eager execution")
  402. parser.add_argument(
  403. '--kv-cache-dtype',
  404. type=str,
  405. choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
  406. default="auto",
  407. help='Data type for kv cache storage. If "auto", will use model '
  408. 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
  409. 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
  410. parser.add_argument(
  411. '--quantization-param-path',
  412. type=str,
  413. default=None,
  414. help='Path to the JSON file containing the KV cache scaling factors. '
  415. 'This should generally be supplied, when KV cache dtype is FP8. '
  416. 'Otherwise, KV cache scaling factors default to 1.0, which may cause '
  417. 'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
  418. 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
  419. 'instead supported for common inference criteria.')
  420. parser.add_argument(
  421. "--device",
  422. type=str,
  423. default="auto",
  424. choices=DEVICE_OPTIONS,
  425. help='device type for Aphrodite execution, supporting CUDA, OpenVINO '
  426. 'and CPU.')
  427. parser.add_argument(
  428. "--num-scheduler-steps",
  429. type=int,
  430. default=1,
  431. help="Maximum number of forward steps per scheduler call.")
  432. parser.add_argument("--use-v2-block-manager",
  433. action='store_true',
  434. help="Enable block manager v2.")
  435. parser.add_argument(
  436. "--enable-prefix-caching",
  437. action='store_true',
  438. help="Enable automatic prefix caching for Aphrodite backend.")
  439. parser.add_argument("--enable-chunked-prefill",
  440. action='store_true',
  441. help="enable chunked prefill for Aphrodite backend.")
  442. parser.add_argument('--max-num-batched-tokens',
  443. type=int,
  444. default=None,
  445. help='maximum number of batched tokens per '
  446. 'iteration')
  447. parser.add_argument('--download-dir',
  448. type=str,
  449. default=None,
  450. help='directory to download and load the weights, '
  451. 'default to the default cache dir of huggingface')
  452. parser.add_argument(
  453. '--output-json',
  454. type=str,
  455. default=None,
  456. help='Path to save the throughput results in JSON format.')
  457. parser.add_argument(
  458. '--distributed-executor-backend',
  459. choices=['ray', 'mp'],
  460. default=None,
  461. help='Backend to use for distributed serving. When more than 1 GPU '
  462. 'is used, will be automatically set to "ray" if installed '
  463. 'or "mp" (multiprocessing) otherwise.')
  464. parser.add_argument(
  465. '--load-format',
  466. type=str,
  467. default=EngineArgs.load_format,
  468. choices=[
  469. 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
  470. 'bitsandbytes'
  471. ],
  472. help='The format of the model weights to load.\n\n'
  473. '* "auto" will try to load the weights in the safetensors format '
  474. 'and fall back to the pytorch bin format if safetensors format '
  475. 'is not available.\n'
  476. '* "pt" will load the weights in the pytorch bin format.\n'
  477. '* "safetensors" will load the weights in the safetensors format.\n'
  478. '* "npcache" will load the weights in pytorch format and store '
  479. 'a numpy cache to speed up the loading.\n'
  480. '* "dummy" will initialize the weights with random values, '
  481. 'which is mainly for profiling.\n'
  482. '* "tensorizer" will load the weights using tensorizer from '
  483. 'CoreWeave. See the Tensorize Aphrodite Model script in the Examples'
  484. 'section for more information.\n'
  485. '* "bitsandbytes" will load the weights using bitsandbytes '
  486. 'quantization.\n')
  487. parser.add_argument(
  488. "--disable-async-output-proc",
  489. action='store_true',
  490. default=False,
  491. help="Disable async output processor for Aphrodite backend.")
  492. parser.add_argument("--async-engine",
  493. action='store_true',
  494. default=False,
  495. help="Use Aphrodite async engine rather than LLM class."
  496. )
  497. parser.add_argument("--disable-frontend-multiprocessing",
  498. action='store_true',
  499. default=False,
  500. help="Disable decoupled async engine frontend.")
  501. args = parser.parse_args()
  502. if args.tokenizer is None:
  503. args.tokenizer = args.model
  504. if args.dataset is None:
  505. assert args.input_len is not None
  506. assert args.output_len is not None
  507. else:
  508. assert args.input_len is None
  509. if args.backend == "aphrodite":
  510. if args.hf_max_batch_size is not None:
  511. raise ValueError("HF max batch size is only for HF backend.")
  512. elif args.backend == "hf":
  513. if args.hf_max_batch_size is None:
  514. raise ValueError("HF max batch size is required for HF backend.")
  515. if args.quantization is not None:
  516. raise ValueError("Quantization is only for Aphrodite backend.")
  517. elif args.backend == "mii":
  518. if args.dtype != "auto":
  519. raise ValueError("dtype must be auto for MII backend.")
  520. if args.n != 1:
  521. raise ValueError("n must be 1 for MII backend.")
  522. if args.use_beam_search:
  523. raise ValueError("Beam search is not supported for MII backend.")
  524. if args.quantization is not None:
  525. raise ValueError("Quantization is only for Aphrodite backend.")
  526. if args.hf_max_batch_size is not None:
  527. raise ValueError("HF max batch size is only for HF backend.")
  528. if args.tokenizer != args.model:
  529. raise ValueError("Tokenizer must be the same as the model for MII "
  530. "backend.")
  531. main(args)