# throughput.py
"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple

import torch
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.quantization import QUANTIZATION_METHODS


  14. def sample_requests(
  15. dataset_path: str,
  16. num_requests: int,
  17. tokenizer: PreTrainedTokenizerBase,
  18. fixed_output_len: Optional[int],
  19. ) -> List[Tuple[str, int, int]]:
  20. if fixed_output_len is not None and fixed_output_len < 4:
  21. raise ValueError("output_len too small")
  22. # Load the dataset.
  23. with open(dataset_path) as f:
  24. dataset = json.load(f)
  25. # Filter out the conversations with less than 2 turns.
  26. dataset = [data for data in dataset if len(data["conversations"]) >= 2]
  27. # Only keep the first two turns of each conversation.
  28. dataset = [(data["conversations"][0]["value"],
  29. data["conversations"][1]["value"]) for data in dataset]
  30. # Shuffle the dataset.
  31. random.shuffle(dataset)
  32. # Filter out sequences that are too long or too short
  33. filtered_dataset: List[Tuple[str, int, int]] = []
  34. for i in range(len(dataset)):
  35. if len(filtered_dataset) == num_requests:
  36. break
  37. # Tokenize the prompts and completions.
  38. prompt = dataset[i][0]
  39. prompt_token_ids = tokenizer(prompt).input_ids
  40. completion = dataset[i][1]
  41. completion_token_ids = tokenizer(completion).input_ids
  42. prompt_len = len(prompt_token_ids)
  43. output_len = len(completion_token_ids
  44. ) if fixed_output_len is None else fixed_output_len
  45. if prompt_len < 4 or output_len < 4:
  46. # Prune too short sequences.
  47. continue
  48. if prompt_len > 1024 or prompt_len + output_len > 2048:
  49. # Prune too long sequences.
  50. continue
  51. filtered_dataset.append((prompt, prompt_len, output_len))
  52. return filtered_dataset
  53. def run_aphrodite(
  54. requests: List[Tuple[str, int, int]],
  55. model: str,
  56. tokenizer: str,
  57. quantization: Optional[str],
  58. quant_llm_fp_bits: Optional[int],
  59. tensor_parallel_size: int,
  60. seed: int,
  61. n: int,
  62. use_beam_search: bool,
  63. trust_remote_code: bool,
  64. dtype: str,
  65. max_model_len: Optional[int],
  66. enforce_eager: bool,
  67. max_seq_len_to_capture: int,
  68. kv_cache_dtype: str,
  69. quantization_param_path: Optional[str],
  70. device: str,
  71. enable_prefix_caching: bool,
  72. enable_chunked_prefill: bool,
  73. max_num_batched_tokens: int,
  74. distributed_executor_backend: Optional[str],
  75. gpu_memory_utilization: float = 0.9,
  76. download_dir: Optional[str] = None,
  77. load_format: str = EngineArgs.load_format,
  78. max_num_seqs: Optional[int] = None,
  79. num_scheduler_steps: Optional[int] = None,
  80. ) -> float:
  81. from aphrodite import LLM, SamplingParams
  82. llm = LLM(
  83. model=model,
  84. tokenizer=tokenizer,
  85. quantization=quantization,
  86. quant_llm_fp_bits=quant_llm_fp_bits,
  87. tensor_parallel_size=tensor_parallel_size,
  88. seed=seed,
  89. trust_remote_code=trust_remote_code,
  90. dtype=dtype,
  91. max_model_len=max_model_len,
  92. gpu_memory_utilization=gpu_memory_utilization,
  93. enforce_eager=enforce_eager,
  94. max_seq_len_to_capture=max_seq_len_to_capture,
  95. kv_cache_dtype=kv_cache_dtype,
  96. quantization_param_path=quantization_param_path,
  97. device=device,
  98. enable_prefix_caching=enable_prefix_caching,
  99. download_dir=download_dir,
  100. enable_chunked_prefill=enable_chunked_prefill,
  101. max_num_batched_tokens=max_num_batched_tokens,
  102. distributed_executor_backend=distributed_executor_backend,
  103. load_format=load_format,
  104. max_num_seqs=max_num_seqs,
  105. num_scheduler_steps=num_scheduler_steps,
  106. )
  107. # Add the requests to the engine.
  108. prompts: List[str] = []
  109. sampling_params: List[SamplingParams] = []
  110. for prompt, _, output_len in requests:
  111. prompts.append(prompt)
  112. sampling_params.append(
  113. SamplingParams(
  114. n=n,
  115. temperature=0.0 if use_beam_search else 1.0,
  116. top_p=1.0,
  117. use_beam_search=use_beam_search,
  118. ignore_eos=True,
  119. max_tokens=output_len,
  120. ))
  121. start = time.perf_counter()
  122. llm.generate(prompts, sampling_params, use_tqdm=True)
  123. end = time.perf_counter()
  124. return end - start
  125. def run_hf(
  126. requests: List[Tuple[str, int, int]],
  127. model: str,
  128. tokenizer: PreTrainedTokenizerBase,
  129. n: int,
  130. use_beam_search: bool,
  131. max_batch_size: int,
  132. trust_remote_code: bool,
  133. ) -> float:
  134. assert not use_beam_search
  135. llm = AutoModelForCausalLM.from_pretrained(
  136. model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
  137. if llm.config.model_type == "llama":
  138. # To enable padding in the HF backend.
  139. tokenizer.pad_token = tokenizer.eos_token
  140. llm = llm.cuda()
  141. pbar = tqdm(total=len(requests))
  142. start = time.perf_counter()
  143. batch: List[str] = []
  144. max_prompt_len = 0
  145. max_output_len = 0
  146. for i in range(len(requests)):
  147. prompt, prompt_len, output_len = requests[i]
  148. # Add the prompt to the batch.
  149. batch.append(prompt)
  150. max_prompt_len = max(max_prompt_len, prompt_len)
  151. max_output_len = max(max_output_len, output_len)
  152. if len(batch) < max_batch_size and i != len(requests) - 1:
  153. # Check if we can add more requests to the batch.
  154. _, next_prompt_len, next_output_len = requests[i + 1]
  155. if (max(max_prompt_len, next_prompt_len) +
  156. max(max_output_len, next_output_len)) <= 2048:
  157. # We can add more requests to the batch.
  158. continue
  159. # Generate the sequences.
  160. input_ids = tokenizer(batch, return_tensors="pt",
  161. padding=True).input_ids
  162. llm_outputs = llm.generate(
  163. input_ids=input_ids.cuda(),
  164. do_sample=not use_beam_search,
  165. num_return_sequences=n,
  166. temperature=1.0,
  167. top_p=1.0,
  168. use_cache=True,
  169. max_new_tokens=max_output_len,
  170. )
  171. # Include the decoding time.
  172. tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
  173. pbar.update(len(batch))
  174. # Clear the batch.
  175. batch = []
  176. max_prompt_len = 0
  177. max_output_len = 0
  178. end = time.perf_counter()
  179. return end - start
  180. def run_mii(
  181. requests: List[Tuple[str, int, int]],
  182. model: str,
  183. tensor_parallel_size: int,
  184. output_len: int,
  185. ) -> float:
  186. from mii import client, serve
  187. llm = serve(model, tensor_parallel=tensor_parallel_size)
  188. prompts = [prompt for prompt, _, _ in requests]
  189. start = time.perf_counter()
  190. llm.generate(prompts, max_new_tokens=output_len)
  191. end = time.perf_counter()
  192. client = client(model)
  193. client.terminate_server()
  194. return end - start
  195. def main(args: argparse.Namespace):
  196. print(args)
  197. random.seed(args.seed)
  198. # Sample the requests.
  199. tokenizer = AutoTokenizer.from_pretrained(
  200. args.tokenizer, trust_remote_code=args.trust_remote_code)
  201. if args.dataset is None:
  202. # Synthesize a prompt with the given input length.
  203. prompt = "hi" * (args.input_len - 1)
  204. requests = [(prompt, args.input_len, args.output_len)
  205. for _ in range(args.num_prompts)]
  206. else:
  207. requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
  208. args.output_len)
  209. if args.backend == "aphrodite":
  210. elapsed_time = run_aphrodite(
  211. requests, args.model, args.tokenizer, args.quantization,
  212. args.quant_llm_fp_bits,
  213. args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
  214. args.trust_remote_code, args.dtype, args.max_model_len,
  215. args.enforce_eager, args.max_seq_len_to_capture,
  216. args.kv_cache_dtype, args.quantization_param_path, args.device,
  217. args.enable_prefix_caching, args.enable_chunked_prefill,
  218. args.max_num_batched_tokens, args.distributed_executor_backend,
  219. args.gpu_memory_utilization, args.download_dir, args.load_format,
  220. args.max_num_seqs, args.num_scheduler_steps)
  221. elif args.backend == "hf":
  222. assert args.tensor_parallel_size == 1
  223. elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
  224. args.use_beam_search, args.hf_max_batch_size,
  225. args.trust_remote_code)
  226. elif args.backend == "mii":
  227. elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
  228. args.output_len)
  229. else:
  230. raise ValueError(f"Unknown backend: {args.backend}")
  231. total_num_tokens = sum(prompt_len + output_len
  232. for _, prompt_len, output_len in requests)
  233. print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
  234. f"{total_num_tokens / elapsed_time:.2f} tokens/s")
  235. # Output JSON results if specified
  236. if args.output_json:
  237. results = {
  238. "elapsed_time": elapsed_time,
  239. "num_requests": len(requests),
  240. "total_num_tokens": total_num_tokens,
  241. "requests_per_second": len(requests) / elapsed_time,
  242. "tokens_per_second": total_num_tokens / elapsed_time,
  243. }
  244. with open(args.output_json, "w") as f:
  245. json.dump(results, f, indent=4)
  246. if __name__ == "__main__":
  247. parser = FlexibleArgumentParser(description="Benchmark the throughput.")
  248. parser.add_argument("--backend",
  249. type=str,
  250. choices=["aphrodite", "hf", "mii"],
  251. default="aphrodite")
  252. parser.add_argument("--dataset",
  253. type=str,
  254. default=None,
  255. help="Path to the dataset.")
  256. parser.add_argument("--input-len",
  257. type=int,
  258. default=None,
  259. help="Input prompt length for each request")
  260. parser.add_argument("--output-len",
  261. type=int,
  262. default=None,
  263. help="Output length for each request. Overrides the "
  264. "output length from the dataset.")
  265. parser.add_argument("--model", type=str, default="facebook/opt-125m")
  266. parser.add_argument("--tokenizer", type=str, default=None)
  267. parser.add_argument('--quantization',
  268. '-q',
  269. choices=[*QUANTIZATION_METHODS, None],
  270. default=None)
  271. parser.add_argument('--quant-llm-fp-bits',
  272. type=int,
  273. default=None,
  274. choices=[4, 5, 6, 7],
  275. help="Number of bits for the FP quantization in "
  276. "QuantLLM")
  277. parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
  278. parser.add_argument("--n",
  279. type=int,
  280. default=1,
  281. help="Number of generated sequences per prompt.")
  282. parser.add_argument("--use-beam-search", action="store_true")
  283. parser.add_argument('--max-num-seqs',
  284. type=int,
  285. default=256,
  286. help='maximum number of batched requests per iteration')
  287. parser.add_argument('--num-scheduler-steps',
  288. type=int,
  289. default=1,
  290. help='number of scheduler steps for multi-step.')
  291. parser.add_argument("--num-prompts",
  292. type=int,
  293. default=1000,
  294. help="Number of prompts to process.")
  295. parser.add_argument("--seed", type=int, default=0)
  296. parser.add_argument("--hf-max-batch-size",
  297. type=int,
  298. default=None,
  299. help="Maximum batch size for HF backend.")
  300. parser.add_argument('--trust-remote-code',
  301. action='store_true',
  302. help='trust remote code from huggingface')
  303. parser.add_argument(
  304. '--max-model-len',
  305. type=int,
  306. default=None,
  307. help='Maximum length of a sequence (including prompt and output). '
  308. 'If None, will be derived from the model.')
  309. parser.add_argument(
  310. '--dtype',
  311. type=str,
  312. default='auto',
  313. choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
  314. help='data type for model weights and activations. '
  315. 'The "auto" option will use FP16 precision '
  316. 'for FP32 and FP16 models, and BF16 precision '
  317. 'for BF16 models.')
  318. parser.add_argument('--gpu-memory-utilization',
  319. type=float,
  320. default=0.9,
  321. help='the fraction of GPU memory to be used for '
  322. 'the model executor, which can range from 0 to 1.'
  323. 'If unspecified, will use the default value of 0.9.')
  324. parser.add_argument("--enforce-eager",
  325. action="store_true",
  326. help="enforce eager execution")
  327. parser.add_argument("--max-seq-len-to-capture",
  328. type=int,
  329. default=None,
  330. help="The maximum sequence length to capture for "
  331. "CUDA graphs.")
  332. parser.add_argument(
  333. '--kv-cache-dtype',
  334. type=str,
  335. choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
  336. default="auto",
  337. help='Data type for kv cache storage. If "auto", will use model '
  338. 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
  339. 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
  340. parser.add_argument(
  341. '--quantization-param-path',
  342. type=str,
  343. default=None,
  344. help='Path to the JSON file containing the KV cache scaling factors. '
  345. 'This should generally be supplied, when KV cache dtype is FP8. '
  346. 'Otherwise, KV cache scaling factors default to 1.0, which may cause '
  347. 'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
  348. 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
  349. 'instead supported for common inference criteria.')
  350. parser.add_argument(
  351. "--device",
  352. type=str,
  353. default="auto",
  354. choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
  355. help=
  356. 'device type for Aphrodite execution, supporting CUDA, OpenVINO and '
  357. 'CPU.')
  358. parser.add_argument(
  359. "--enable-prefix-caching",
  360. action='store_true',
  361. help="enable automatic prefix caching for Aphrodite backend.")
  362. parser.add_argument("--enable-chunked-prefill",
  363. action='store_true',
  364. help="enable chunked prefill for Aphrodite backend.")
  365. parser.add_argument('--max-num-batched-tokens',
  366. type=int,
  367. default=None,
  368. help='maximum number of batched tokens per '
  369. 'iteration')
  370. parser.add_argument('--download-dir',
  371. type=str,
  372. default=None,
  373. help='directory to download and load the weights, '
  374. 'default to the default cache dir of huggingface')
  375. parser.add_argument(
  376. '--output-json',
  377. type=str,
  378. default=None,
  379. help='Path to save the throughput results in JSON format.')
  380. parser.add_argument(
  381. '--distributed-executor-backend',
  382. choices=['ray', 'mp'],
  383. default=None,
  384. help='Backend to use for distributed serving. When more than 1 GPU '
  385. 'is used, will be automatically set to "ray" if installed '
  386. 'or "mp" (multiprocessing) otherwise.')
  387. parser.add_argument(
  388. '--load-format',
  389. type=str,
  390. default=EngineArgs.load_format,
  391. choices=[
  392. 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
  393. 'bitsandbytes'
  394. ],
  395. help='The format of the model weights to load.\n\n'
  396. '* "auto" will try to load the weights in the safetensors format '
  397. 'and fall back to the pytorch bin format if safetensors format '
  398. 'is not available.\n'
  399. '* "pt" will load the weights in the pytorch bin format.\n'
  400. '* "safetensors" will load the weights in the safetensors format.\n'
  401. '* "npcache" will load the weights in pytorch format and store '
  402. 'a numpy cache to speed up the loading.\n'
  403. '* "dummy" will initialize the weights with random values, '
  404. 'which is mainly for profiling.\n'
  405. '* "tensorizer" will load the weights using tensorizer from '
  406. 'CoreWeave. See the Tensorize Aphrodite Model script in the Examples'
  407. 'section for more information.\n'
  408. '* "bitsandbytes" will load the weights using bitsandbytes '
  409. 'quantization.\n')
  410. args = parser.parse_args()
  411. if args.tokenizer is None:
  412. args.tokenizer = args.model
  413. if args.dataset is None:
  414. assert args.input_len is not None
  415. assert args.output_len is not None
  416. else:
  417. assert args.input_len is None
  418. if args.backend == "aphrodite":
  419. if args.hf_max_batch_size is not None:
  420. raise ValueError("HF max batch size is only for HF backend.")
  421. elif args.backend == "hf":
  422. if args.hf_max_batch_size is None:
  423. raise ValueError("HF max batch size is required for HF backend.")
  424. if args.quantization is not None:
  425. raise ValueError("Quantization is only for Aphrodite backend.")
  426. elif args.backend == "mii":
  427. if args.dtype != "auto":
  428. raise ValueError("dtype must be auto for MII backend.")
  429. if args.n != 1:
  430. raise ValueError("n must be 1 for MII backend.")
  431. if args.use_beam_search:
  432. raise ValueError("Beam search is not supported for MII backend.")
  433. if args.quantization is not None:
  434. raise ValueError("Quantization is only for Aphrodite backend.")
  435. if args.hf_max_batch_size is not None:
  436. raise ValueError("HF max batch size is only for HF backend.")
  437. if args.tokenizer != args.model:
  438. raise ValueError("Tokenizer must be the same as the model for MII "
  439. "backend.")
  440. main(args)