
refactor throughput benchmark script

AlpinDale · 7 months ago · commit 033797fd55
1 changed file with 187 additions and 105 deletions

tests/benchmarks/throughput.py (+187 -105)

@@ -6,18 +6,22 @@ import time
 from typing import List, Optional, Tuple
 
 import torch
-from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 from tqdm import tqdm
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          PreTrainedTokenizerBase)
 
-from aphrodite import LLM, SamplingParams
-from aphrodite.transformers_utils.tokenizer import get_tokenizer
+from aphrodite.quantization import QUANTIZATION_METHODS
 
 
 def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int],
 ) -> List[Tuple[str, int, int]]:
+    if fixed_output_len is not None and fixed_output_len < 4:
+        raise ValueError("output_len too small")
+
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -27,20 +31,23 @@ def sample_requests(
     dataset = [(data["conversations"][0]["value"],
                 data["conversations"][1]["value"]) for data in dataset]
 
-    # Tokenize the prompts and completions.
-    prompts = [prompt for prompt, _ in dataset]
-    prompt_token_ids = tokenizer(prompts).input_ids
-    completions = [completion for _, completion in dataset]
-    completion_token_ids = tokenizer(completions).input_ids
-    tokenized_dataset = []
-    for i in range(len(dataset)):
-        output_len = len(completion_token_ids[i])
-        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+    # Shuffle the dataset.
+    random.shuffle(dataset)
 
-    # Filter out too long sequences.
+    # Filter out sequences that are too long or too short
     filtered_dataset: List[Tuple[str, int, int]] = []
-    for prompt, prompt_token_ids, output_len in tokenized_dataset:
+    for i in range(len(dataset)):
+        if len(filtered_dataset) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = dataset[i][0]
+        prompt_token_ids = tokenizer(prompt).input_ids
+        completion = dataset[i][1]
+        completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
+        output_len = len(completion_token_ids
+                         ) if fixed_output_len is None else fixed_output_len
         if prompt_len < 4 or output_len < 4:
             # Prune too short sequences.
             continue
@@ -49,9 +56,7 @@ def sample_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))
 
-    # Sample the requests.
-    sampled_requests = random.sample(filtered_dataset, num_requests)
-    return sampled_requests
+    return filtered_dataset
 
 
 def run_aphrodite(
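
The sampler now shuffles first and tokenizes entries lazily until it has collected num_requests, with an optional fixed_output_len override. A minimal sketch of driving it, assuming a ShareGPT-style JSON file; the import path, tokenizer name, and dataset path are placeholders rather than values from this commit:

    # Hedged sketch of the refactored sampler; import path, tokenizer name,
    # and dataset path are placeholders, not from this commit.
    from transformers import AutoTokenizer

    from tests.benchmarks.throughput import sample_requests  # assumed import path

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder
    requests = sample_requests(
        dataset_path="sharegpt.json",  # placeholder ShareGPT-style JSON path
        num_requests=256,
        tokenizer=tokenizer,
        fixed_output_len=None,  # None: keep each completion's own length
    )
    # Entries are (prompt, prompt_len, output_len), already shuffled and filtered.
    for prompt, prompt_len, output_len in requests[:3]:
        print(prompt_len, output_len)
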
@@ -65,16 +70,18 @@ def run_aphrodite(
     use_beam_search: bool,
     trust_remote_code: bool,
     dtype: str,
+    max_model_len: Optional[int],
+    enforce_eager: bool,
     kv_cache_dtype: str,
-    disable_custom_all_reduce: bool,
+    quantization_param_path: Optional[str],
+    device: str,
     enable_prefix_caching: bool,
-    enforce_eager: bool,
     enable_chunked_prefill: bool,
     max_num_batched_tokens: int,
-    speculative_model: Optional[str] = None,
-    num_speculative_tokens: Optional[int] = None,
-    use_v2_block_manager: bool = False,
+    gpu_memory_utilization: float = 0.9,
+    download_dir: Optional[str] = None,
 ) -> float:
+    from aphrodite import LLM, SamplingParams
     llm = LLM(
         model=model,
         tokenizer=tokenizer,
@@ -83,37 +90,35 @@ def run_aphrodite(
         seed=seed,
         trust_remote_code=trust_remote_code,
         dtype=dtype,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=gpu_memory_utilization,
+        enforce_eager=enforce_eager,
         kv_cache_dtype=kv_cache_dtype,
-        disable_custom_all_reduce=disable_custom_all_reduce,
+        quantization_param_path=quantization_param_path,
+        device=device,
         enable_prefix_caching=enable_prefix_caching,
-        enforce_eager=enforce_eager,
+        download_dir=download_dir,
         enable_chunked_prefill=enable_chunked_prefill,
         max_num_batched_tokens=max_num_batched_tokens,
-        speculative_model=speculative_model,
-        num_speculative_tokens=num_speculative_tokens,
-        use_v2_block_manager=use_v2_block_manager,
     )
 
     # Add the requests to the engine.
+    prompts = []
+    sampling_params = []
     for prompt, _, output_len in requests:
-        sampling_params = SamplingParams(
-            n=n,
-            temperature=0.0 if use_beam_search else 1.0,
-            top_p=1.0,
-            use_beam_search=use_beam_search,
-            ignore_eos=True,
-            max_tokens=output_len,
-        )
-        # FIXME: Do not use internal method.
-        llm._add_request(  # pylint: disable=protected-access
-            prompt=prompt,
-            prompt_token_ids=None,
-            params=sampling_params,
-        )
+        prompts.append(prompt)
+        sampling_params.append(
+            SamplingParams(
+                n=n,
+                temperature=0.0 if use_beam_search else 1.0,
+                top_p=1.0,
+                use_beam_search=use_beam_search,
+                ignore_eos=True,
+                max_tokens=output_len,
+            ))
 
     start = time.perf_counter()
-    # FIXME Do use internal method.
-    llm._run_engine(use_tqdm=True)  # pylint: disable=protected-access
+    llm.generate(prompts, sampling_params, use_tqdm=True)
     end = time.perf_counter()
     return end - start
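
The hunk above drops the internal _add_request/_run_engine calls in favor of the public generate() API with one SamplingParams per prompt. A minimal standalone sketch of that pattern; the model name and sampling values are placeholders chosen only for illustration:

    # Hedged sketch of the public generate() pattern adopted above; the model
    # and sampling values are placeholders, not part of this commit.
    from aphrodite import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", enforce_eager=True)  # placeholder model
    prompts = ["Hello, my name is", "The capital of France is"]
    # One SamplingParams per prompt, mirroring run_aphrodite above.
    sampling_params = [
        SamplingParams(n=1, temperature=1.0, top_p=1.0, ignore_eos=True,
                       max_tokens=32)
        for _ in prompts
    ]
    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
    for output in outputs:
        print(output.outputs[0].text)
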
 
@@ -178,28 +183,58 @@ def run_hf(
     return end - start
 
 
-def main(args: argparse.Namespace):  # pylint: disable=redefined-outer-name
+def run_mii(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tensor_parallel_size: int,
+    output_len: int,
+) -> float:
+    from mii import client, serve
+    llm = serve(model, tensor_parallel=tensor_parallel_size)
+    prompts = [prompt for prompt, _, _ in requests]
+
+    start = time.perf_counter()
+    llm.generate(prompts, max_new_tokens=output_len)
+    end = time.perf_counter()
+    client = client(model)
+    client.terminate_server()
+    return end - start
+
+
+def main(args: argparse.Namespace):
     print(args)
     random.seed(args.seed)
 
     # Sample the requests.
-    tokenizer = get_tokenizer(args.tokenizer,
-                              trust_remote_code=args.trust_remote_code)
-    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.tokenizer, trust_remote_code=args.trust_remote_code)
+    if args.dataset is None:
+        # Synthesize a prompt with the given input length.
+        prompt = "hi" * (args.input_len - 1)
+        requests = [(prompt, args.input_len, args.output_len)
+                    for _ in range(args.num_prompts)]
+    else:
+        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
+                                   args.output_len)
 
     if args.backend == "aphrodite":
         elapsed_time = run_aphrodite(
             requests, args.model, args.tokenizer, args.quantization,
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
-            args.trust_remote_code, args.dtype, args.kv_cache_dtype,
-            args.disable_custom_all_reduce, args.enable_prefix_caching,
-            args.enforce_eager, args.enable_chunked_prefill,
-            args.max_num_batched_tokens)
+            args.trust_remote_code, args.dtype, args.max_model_len,
+            args.enforce_eager, args.kv_cache_dtype,
+            args.quantization_param_path, args.device,
+            args.enable_prefix_caching, args.enable_chunked_prefill,
+            args.max_num_batched_tokens, args.gpu_memory_utilization,
+            args.download_dir)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                               args.use_beam_search, args.hf_max_batch_size,
                               args.trust_remote_code)
+    elif args.backend == "mii":
+        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
+                               args.output_len)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
     total_input_tokens = sum(prompt_len for _, prompt_len, _ in requests)
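
When no dataset is given, main() synthesizes each prompt as "hi" * (input_len - 1). A hedged sanity check of how close that comes to the requested token count; the tokenizer here is a placeholder and exact counts vary by vocabulary:

    # Hedged sanity check for the synthetic-prompt path above: "hi" * (n - 1)
    # is meant to tokenize to roughly n tokens; the exact count depends on the
    # tokenizer, so the model name is only a placeholder.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    input_len = 128
    prompt = "hi" * (input_len - 1)
    print(len(tokenizer(prompt).input_ids), "tokens for requested", input_len)
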
@@ -214,22 +249,27 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Benchmark the throughput.")
     parser.add_argument("--backend",
                         type=str,
-                        choices=["aphrodite", "hf"],
+                        choices=["aphrodite", "hf", "mii"],
                         default="aphrodite")
     parser.add_argument("--dataset",
                         type=str,
-                        required=True,
+                        default=None,
                         help="Path to the dataset.")
-    parser.add_argument("--model",
-                        type=str,
-                        default="EleutherAI/pythia-70m-deduped")
+    parser.add_argument("--input-len",
+                        type=int,
+                        default=None,
+                        help="Input prompt length for each request")
+    parser.add_argument("--output-len",
+                        type=int,
+                        default=None,
+                        help="Output length for each request. Overrides the "
+                        "output length from the dataset.")
+    parser.add_argument("--model", type=str, default="facebook/opt-125m")
     parser.add_argument("--tokenizer", type=str, default=None)
-    parser.add_argument(
-        "--quantization",
-        "-q",
-        choices=["awq", "gguf", "bnb", "gptq", "squeezellm", "marlin", None],
-        default=None)
-    parser.add_argument("--gpu-memory-utilization", type=float, default=0.88)
+    parser.add_argument('--quantization',
+                        '-q',
+                        choices=[*QUANTIZATION_METHODS, None],
+                        default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
                         type=int,
@@ -245,53 +285,84 @@ if __name__ == "__main__":
                         type=int,
                         default=None,
                         help="Maximum batch size for HF backend.")
-    parser.add_argument("--trust-remote-code",
-                        action="store_true",
-                        help="trust remote code from huggingface")
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
+                        help='trust remote code from huggingface')
     parser.add_argument(
-        "--dtype",
-        type=str,
-        default="auto",
-        choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
-        help="data type for model weights and activations. "
-        "The 'auto' option will use FP16 precision "
-        "for FP32 and FP16 models, and BF16 precision "
-        "for BF16 models.")
-    parser.add_argument("--kv-cache-dtype",
-                        type=str,
-                        default="auto",
-                        choices=["auto", "fp8_e5m2"],
-                        help="The Data Type for the KV cache.")
+        '--max-model-len',
+        type=int,
+        default=None,
+        help='Maximum length of a sequence (including prompt and output). '
+        'If None, will be derived from the model.')
     parser.add_argument(
-        "--disable-custom-all-reduce",
-        action="store_true",
-        help="disable custom all reduce for the Aphrodite backend")
-    parser.add_argument("--enable-prefix-caching",
-                        action="store_true",
-                        help="enable prefix caching for the Aphrodite backend")
+        '--dtype',
+        type=str,
+        default='auto',
+        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
+        help='data type for model weights and activations. '
+        'The "auto" option will use FP16 precision '
+        'for FP32 and FP16 models, and BF16 precision '
+        'for BF16 models.')
+    parser.add_argument('--gpu-memory-utilization',
+                        type=float,
+                        default=0.9,
+                        help='the fraction of GPU memory to be used for '
+                        'the model executor, which can range from 0 to 1.'
+                        'If unspecified, will use the default value of 0.9.')
     parser.add_argument("--enforce-eager",
                         type=lambda x: (str(x).lower() == 'true'),
-                        default=True,
-                        help="enforce eager mode for the Aphrodite backend")
+                        help="enforce eager execution")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type. '
+        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
+        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
+        'common inference criteria.')
+    parser.add_argument(
+        '--quantization-param-path',
+        type=str,
+        default=None,
+        help='Path to the JSON file containing the KV cache scaling factors. '
+        'This should generally be supplied, when KV cache dtype is FP8. '
+        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
+        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
+        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
+        'instead supported for common inference criteria.')
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda", "cpu"],
+        help='device type for Aphrodite execution, supporting CUDA and CPU.')
     parser.add_argument(
-        "--enable-chunked-prefill",
-        action="store_true",
-        help="enable chunked prefill for the Aphrodite backend")
-    parser.add_argument("--max-num-batched-tokens",
+        "--enable-prefix-caching",
+        action='store_true',
+        help="enable automatic prefix caching for Aphrodite backend.")
+    parser.add_argument("--enable-chunked-prefill",
+                        action='store_true',
+                        help="enable chunked prefill for Aphrodite backend.")
+    parser.add_argument('--max-num-batched-tokens',
                         type=int,
-                        help="maximum number of batched tokens for the "
-                        "Aphrodite backend")
-    parser.add_argument("--speculative-model",
+                        default=None,
+                        help='maximum number of batched tokens per '
+                        'iteration')
+    parser.add_argument('--download-dir',
                         type=str,
-                        help="speculative model for the Aphrodite backend")
-    parser.add_argument("--num-speculative-tokens",
-                        type=int,
-                        help="number of speculative tokens for the "
-                        "Aphrodite backend")
-    parser.add_argument("--use-v2-block-manager",
-                        action="store_true",
-                        help="use v2 block manager for the Aphrodite backend")
+                        default=None,
+                        help='directory to download and load the weights, '
+                        'default to the default cache dir of huggingface')
     args = parser.parse_args()
+    if args.tokenizer is None:
+        args.tokenizer = args.model
+    if args.dataset is None:
+        assert args.input_len is not None
+        assert args.output_len is not None
+    else:
+        assert args.input_len is None
 
     if args.backend == "aphrodite":
         if args.hf_max_batch_size is not None:
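
--enforce-eager is parsed with a string-to-bool lambda rather than store_true, so it expects an explicit value, and with the old default=True removed, omitting the flag leaves args.enforce_eager as None. A quick illustration of the same conversion:

    # Same conversion used for --enforce-eager above; only the literal string
    # "true" (case-insensitive) maps to True.
    to_bool = lambda x: (str(x).lower() == 'true')
    print(to_bool("True"), to_bool("false"), to_bool("1"))  # True False False
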
@@ -300,8 +371,19 @@ if __name__ == "__main__":
         if args.hf_max_batch_size is None:
             raise ValueError("HF max batch size is required for HF backend.")
         if args.quantization is not None:
-            raise ValueError("Quantization is only for aphrodite backend.")
-    if args.tokenizer is None:
-        args.tokenizer = args.model
-
+            raise ValueError("Quantization is only for Aphrodite backend.")
+    elif args.backend == "mii":
+        if args.dtype != "auto":
+            raise ValueError("dtype must be auto for MII backend.")
+        if args.n != 1:
+            raise ValueError("n must be 1 for MII backend.")
+        if args.use_beam_search:
+            raise ValueError("Beam search is not supported for MII backend.")
+        if args.quantization is not None:
+            raise ValueError("Quantization is only for Aphrodite backend.")
+        if args.hf_max_batch_size is not None:
+            raise ValueError("HF max batch size is only for HF backend.")
+        if args.tokenizer != args.model:
+            raise ValueError("Tokenizer must be the same as the model for MII "
+                             "backend.")
     main(args)
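
For reference, the two ways of feeding requests after this refactor, shown as subprocess invocations; the dataset path and model name are placeholders, and --num-prompts is assumed to come from the unchanged portion of the parser:

    # Hedged usage sketch: dataset-driven vs. synthetic prompts. Dataset path
    # and model are placeholders; --num-prompts is assumed from the unchanged
    # part of the argument parser.
    import subprocess

    # 1) ShareGPT-style dataset; lengths come from the data, optionally
    #    overridden by --output-len.
    subprocess.run([
        "python", "tests/benchmarks/throughput.py",
        "--backend", "aphrodite",
        "--model", "facebook/opt-125m",
        "--dataset", "sharegpt.json",
        "--num-prompts", "256",
    ], check=True)

    # 2) No dataset: synthetic "hi"-repeated prompts; the asserts above now
    #    require both --input-len and --output-len.
    subprocess.run([
        "python", "tests/benchmarks/throughput.py",
        "--backend", "aphrodite",
        "--model", "facebook/opt-125m",
        "--input-len", "128",
        "--output-len", "128",
        "--num-prompts", "256",
    ], check=True)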