throughput.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. """Benchmark offline inference throughput."""
  2. import argparse
  3. import json
  4. import random
  5. import time
  6. from typing import List, Optional, Tuple
  7. import torch
  8. from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
  9. from tqdm import tqdm
  10. from aphrodite import LLM, SamplingParams
  11. from aphrodite.transformers_utils.tokenizer import get_tokenizer
  12. def sample_requests(
  13. dataset_path: str,
  14. num_requests: int,
  15. tokenizer: PreTrainedTokenizerBase,
  16. ) -> List[Tuple[str, int, int]]:
  17. # Load the dataset.
  18. with open(dataset_path) as f:
  19. dataset = json.load(f)
  20. # Filter out the conversations with less than 2 turns.
  21. dataset = [data for data in dataset if len(data["conversations"]) >= 2]
  22. # Only keep the first two turns of each conversation.
  23. dataset = [(data["conversations"][0]["value"],
  24. data["conversations"][1]["value"]) for data in dataset]
  25. # Tokenize the prompts and completions.
  26. prompts = [prompt for prompt, _ in dataset]
  27. prompt_token_ids = tokenizer(prompts).input_ids
  28. completions = [completion for _, completion in dataset]
  29. completion_token_ids = tokenizer(completions).input_ids
  30. tokenized_dataset = []
  31. for i in range(len(dataset)):
  32. output_len = len(completion_token_ids[i])
  33. tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
  34. # Filter out too long sequences.
  35. filtered_dataset: List[Tuple[str, int, int]] = []
  36. for prompt, prompt_token_ids, output_len in tokenized_dataset:
  37. prompt_len = len(prompt_token_ids)
  38. if prompt_len < 4 or output_len < 4:
  39. # Prune too short sequences.
  40. continue
  41. if prompt_len > 1024 or prompt_len + output_len > 2048:
  42. # Prune too long sequences.
  43. continue
  44. filtered_dataset.append((prompt, prompt_len, output_len))
  45. # Sample the requests.
  46. sampled_requests = random.sample(filtered_dataset, num_requests)
  47. return sampled_requests
  48. def run_aphrodite(
  49. requests: List[Tuple[str, int, int]],
  50. model: str,
  51. tokenizer: str,
  52. quantization: Optional[str],
  53. tensor_parallel_size: int,
  54. seed: int,
  55. n: int,
  56. use_beam_search: bool,
  57. trust_remote_code: bool,
  58. dtype: str,
  59. kv_cache_dtype: str,
  60. disable_custom_all_reduce: bool,
  61. context_shift: bool,
  62. enforce_eager: bool,
  63. enable_chunked_prefill: bool,
  64. max_num_batched_tokens: int,
  65. speculative_model: Optional[str] = None,
  66. num_speculative_tokens: Optional[int] = None,
  67. use_v2_block_manager: bool = False,
  68. ) -> float:
  69. llm = LLM(
  70. model=model,
  71. tokenizer=tokenizer,
  72. quantization=quantization,
  73. tensor_parallel_size=tensor_parallel_size,
  74. seed=seed,
  75. trust_remote_code=trust_remote_code,
  76. dtype=dtype,
  77. kv_cache_dtype=kv_cache_dtype,
  78. disable_custom_all_reduce=disable_custom_all_reduce,
  79. context_shift=context_shift,
  80. enforce_eager=enforce_eager,
  81. enable_chunked_prefill=enable_chunked_prefill,
  82. max_num_batched_tokens=max_num_batched_tokens,
  83. speculative_model=speculative_model,
  84. num_speculative_tokens=num_speculative_tokens,
  85. use_v2_block_manager=use_v2_block_manager,
  86. )
  87. # Add the requests to the engine.
  88. for prompt, _, output_len in requests:
  89. sampling_params = SamplingParams(
  90. n=n,
  91. temperature=0.0 if use_beam_search else 1.0,
  92. top_p=1.0,
  93. use_beam_search=use_beam_search,
  94. ignore_eos=True,
  95. max_tokens=output_len,
  96. )
  97. # FIXME: Do not use internal method.
  98. llm._add_request( # pylint: disable=protected-access
  99. prompt=prompt,
  100. prompt_token_ids=None,
  101. sampling_params=sampling_params,
  102. )
  103. start = time.perf_counter()
  104. # FIXME Do use internal method.
  105. llm._run_engine(use_tqdm=True) # pylint: disable=protected-access
  106. end = time.perf_counter()
  107. return end - start
  108. def run_hf(
  109. requests: List[Tuple[str, int, int]],
  110. model: str,
  111. tokenizer: PreTrainedTokenizerBase,
  112. n: int,
  113. use_beam_search: bool,
  114. max_batch_size: int,
  115. trust_remote_code: bool,
  116. ) -> float:
  117. assert not use_beam_search
  118. llm = AutoModelForCausalLM.from_pretrained(
  119. model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
  120. if llm.config.model_type == "llama":
  121. # To enable padding in the HF backend.
  122. tokenizer.pad_token = tokenizer.eos_token
  123. llm = llm.cuda()
  124. pbar = tqdm(total=len(requests))
  125. start = time.perf_counter()
  126. batch: List[str] = []
  127. max_prompt_len = 0
  128. max_output_len = 0
  129. for i in range(len(requests)):
  130. prompt, prompt_len, output_len = requests[i]
  131. # Add the prompt to the batch.
  132. batch.append(prompt)
  133. max_prompt_len = max(max_prompt_len, prompt_len)
  134. max_output_len = max(max_output_len, output_len)
  135. if len(batch) < max_batch_size and i != len(requests) - 1:
  136. # Check if we can add more requests to the batch.
  137. _, next_prompt_len, next_output_len = requests[i + 1]
  138. if (max(max_prompt_len, next_prompt_len) +
  139. max(max_output_len, next_output_len)) <= 2048:
  140. # We can add more requests to the batch.
  141. continue
  142. # Generate the sequences.
  143. input_ids = tokenizer(batch, return_tensors="pt",
  144. padding=True).input_ids
  145. llm_outputs = llm.generate(
  146. input_ids=input_ids.cuda(),
  147. do_sample=not use_beam_search,
  148. num_return_sequences=n,
  149. temperature=1.0,
  150. top_p=1.0,
  151. use_cache=True,
  152. max_new_tokens=max_output_len,
  153. )
  154. # Include the decoding time.
  155. tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
  156. pbar.update(len(batch))
  157. # Clear the batch.
  158. batch = []
  159. max_prompt_len = 0
  160. max_output_len = 0
  161. end = time.perf_counter()
  162. return end - start
  163. def main(args: argparse.Namespace): # pylint: disable=redefined-outer-name
  164. print(args)
  165. random.seed(args.seed)
  166. # Sample the requests.
  167. tokenizer = get_tokenizer(args.tokenizer,
  168. trust_remote_code=args.trust_remote_code)
  169. requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
  170. if args.backend == "aphrodite":
  171. elapsed_time = run_aphrodite(
  172. requests, args.model, args.tokenizer, args.quantization,
  173. args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
  174. args.trust_remote_code, args.dtype, args.kv_cache_dtype,
  175. args.disable_custom_all_reduce, args.context_shift,
  176. args.enforce_eager, args.enable_chunked_prefill,
  177. args.max_num_batched_tokens)
  178. elif args.backend == "hf":
  179. assert args.tensor_parallel_size == 1
  180. elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
  181. args.use_beam_search, args.hf_max_batch_size,
  182. args.trust_remote_code)
  183. else:
  184. raise ValueError(f"Unknown backend: {args.backend}")
  185. total_input_tokens = sum(prompt_len for _, prompt_len, _ in requests)
  186. total_output_tokens = sum(output_len for _, _, output_len in requests)
  187. print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
  188. f"Input tokens/s: {total_input_tokens / elapsed_time:.2f}, "
  189. f"Output tokens/s: {total_output_tokens / elapsed_time:.2f}")
  190. if __name__ == "__main__":
  191. parser = argparse.ArgumentParser(description="Benchmark the throughput.")
  192. parser.add_argument("--backend",
  193. type=str,
  194. choices=["aphrodite", "hf"],
  195. default="aphrodite")
  196. parser.add_argument("--dataset",
  197. type=str,
  198. required=True,
  199. help="Path to the dataset.")
  200. parser.add_argument("--model",
  201. type=str,
  202. default="EleutherAI/pythia-70m-deduped")
  203. parser.add_argument("--tokenizer", type=str, default=None)
  204. parser.add_argument(
  205. "--quantization",
  206. "-q",
  207. choices=["awq", "gguf", "bnb", "gptq", "squeezellm", "marlin", None],
  208. default=None)
  209. parser.add_argument("--gpu-memory-utilization", type=float, default=0.88)
  210. parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
  211. parser.add_argument("--n",
  212. type=int,
  213. default=1,
  214. help="Number of generated sequences per prompt.")
  215. parser.add_argument("--use-beam-search", action="store_true")
  216. parser.add_argument("--num-prompts",
  217. type=int,
  218. default=1000,
  219. help="Number of prompts to process.")
  220. parser.add_argument("--seed", type=int, default=0)
  221. parser.add_argument("--hf-max-batch-size",
  222. type=int,
  223. default=None,
  224. help="Maximum batch size for HF backend.")
  225. parser.add_argument("--trust-remote-code",
  226. action="store_true",
  227. help="trust remote code from huggingface")
  228. parser.add_argument(
  229. "--dtype",
  230. type=str,
  231. default="auto",
  232. choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
  233. help="data type for model weights and activations. "
  234. "The 'auto' option will use FP16 precision "
  235. "for FP32 and FP16 models, and BF16 precision "
  236. "for BF16 models.")
  237. parser.add_argument("--kv-cache-dtype",
  238. type=str,
  239. default="auto",
  240. choices=["auto", "fp8_e5m2"],
  241. help="The Data Type for the KV cache.")
  242. parser.add_argument(
  243. "--disable-custom-all-reduce",
  244. action="store_true",
  245. help="disable custom all reduce for the Aphrodite backend")
  246. parser.add_argument(
  247. "--context-shift",
  248. action="store_true",
  249. help="enable context shifting for the Aphrodite backend")
  250. parser.add_argument("--enforce-eager",
  251. type=lambda x: (str(x).lower() == 'true'),
  252. default=True,
  253. help="enforce eager mode for the Aphrodite backend")
  254. parser.add_argument(
  255. "--enable-chunked-prefill",
  256. action="store_true",
  257. help="enable chunked prefill for the Aphrodite backend")
  258. parser.add_argument("--max-num-batched-tokens",
  259. type=int,
  260. help="maximum number of batched tokens for the "
  261. "Aphrodite backend")
  262. parser.add_argument("--speculative-model",
  263. type=str,
  264. help="speculative model for the Aphrodite backend")
  265. parser.add_argument("--num-speculative-tokens",
  266. type=int,
  267. help="number of speculative tokens for the "
  268. "Aphrodite backend")
  269. parser.add_argument("--use-v2-block-manager",
  270. action="store_true",
  271. help="use v2 block manager for the Aphrodite backend")
  272. args = parser.parse_args()
  273. if args.backend == "aphrodite":
  274. if args.hf_max_batch_size is not None:
  275. raise ValueError("HF max batch size is only for HF backend.")
  276. elif args.backend == "hf":
  277. if args.hf_max_batch_size is None:
  278. raise ValueError("HF max batch size is required for HF backend.")
  279. if args.quantization is not None:
  280. raise ValueError("Quantization is only for aphrodite backend.")
  281. if args.tokenizer is None:
  282. args.tokenizer = args.model
  283. main(args)