latency.py

  1. """Benchmark the latency of processing a single batch of requests."""
  2. import argparse
  3. import json
  4. import time
  5. from pathlib import Path
  6. from typing import List, Optional
  7. import numpy as np
  8. import torch
  9. from tqdm import tqdm
  10. from aphrodite import LLM, SamplingParams
  11. from aphrodite.common.utils import FlexibleArgumentParser
  12. from aphrodite.engine.args_tools import DEVICE_OPTIONS, EngineArgs
  13. from aphrodite.inputs import PromptType
  14. from aphrodite.quantization import QUANTIZATION_METHODS


def main(args: argparse.Namespace):
    print(args)

    # NOTE: If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(
        model=args.model,
        speculative_model=args.speculative_model,
        num_speculative_tokens=args.num_speculative_tokens,
        speculative_draft_tensor_parallel_size=\
        args.speculative_draft_tensor_parallel_size,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
        quantization_param_path=args.quantization_param_path,
        device=args.device,
        ray_workers_use_nsight=args.ray_workers_use_nsight,
        use_v2_block_manager=args.use_v2_block_manager,
        enable_chunked_prefill=args.enable_chunked_prefill,
        download_dir=args.download_dir,
        block_size=args.block_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
        load_format=args.load_format,
        distributed_executor_backend=args.distributed_executor_backend,
        enable_prefix_caching=args.enable_prefix_caching,
    )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]
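    # The prompts are supplied as pre-tokenized "prompt_token_ids", so
    # tokenization is excluded from the measured latency; the token values
    # are random and only the (batch_size, input_len) shape matters here.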

    def run_to_completion(profile_dir: Optional[str] = None):
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(dummy_prompts,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
        else:
            start_time = time.perf_counter()
            llm.generate(dummy_prompts,
                         sampling_params=sampling_params,
                         use_tqdm=False)
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "aphrodite_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    latencies = np.array(latencies)
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f'Avg latency: {np.mean(latencies)} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": np.mean(latencies),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
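        # Illustrative shape of the saved JSON (values are examples, not
        # real measurements):
        #   {"avg_latency": 0.42,
        #    "latencies": [0.41, 0.43, ...],
        #    "percentiles": {"10": 0.40, "25": 0.41, ..., "99": 0.45}}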


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--speculative-model', type=str, default=None)
    parser.add_argument('--num-speculative-tokens', type=int, default=None)
    parser.add_argument('--speculative-draft-tensor-parallel-size',
                        '-spec-draft-tp',
                        type=int,
                        default=None)
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=10,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=30,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument("--device",
                        type=str,
                        default="auto",
                        choices=DEVICE_OPTIONS,
                        help='device type for Aphrodite execution')
    parser.add_argument('--block-size',
                        type=int,
                        default=16,
                        help='block size of key/value cache')
    parser.add_argument(
        '--enable-chunked-prefill',
        action='store_true',
        help='If True, the prefill requests can be chunked based on the '
        'max_num_batched_tokens')
    parser.add_argument("--enable-prefix-caching",
                        action='store_true',
                        help="Enable automatic prefix caching")
    parser.add_argument('--use-v2-block-manager', action='store_true')
    parser.add_argument(
        "--ray-workers-use-nsight",
        action='store_true',
        help="If specified, use nsight to profile ray workers",
    )
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the latency results in JSON format.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument(
        '--load-format',
        type=str,
        default=EngineArgs.load_format,
        choices=[
            'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
            'bitsandbytes'
        ],
        help='The format of the model weights to load.\n\n'
        '* "auto" will try to load the weights in the safetensors format '
        'and fall back to the pytorch bin format if safetensors format '
        'is not available.\n'
        '* "pt" will load the weights in the pytorch bin format.\n'
        '* "safetensors" will load the weights in the safetensors format.\n'
        '* "npcache" will load the weights in pytorch format and store '
        'a numpy cache to speed up the loading.\n'
        '* "dummy" will initialize the weights with random values, '
        'which is mainly for profiling.\n'
        '* "tensorizer" will load the weights using tensorizer from '
        'CoreWeave. See the Tensorize Aphrodite Model script in the Examples '
        'section for more information.\n'
        '* "bitsandbytes" will load the weights using bitsandbytes '
        'quantization.\n')
    parser.add_argument(
        '--distributed-executor-backend',
        choices=['ray', 'mp'],
        default=None,
        help='Backend to use for distributed serving. When more than 1 GPU '
        'is used, will be automatically set to "ray" if installed '
        'or "mp" (multiprocessing) otherwise.')
    args = parser.parse_args()
    main(args)
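
# Example invocation (model name and values are illustrative; every flag used
# here is defined above):
#   python latency.py --model facebook/opt-125m --input-len 32 \
#       --output-len 128 --batch-size 8 --num-iters 30 --output-json out.json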