@@ -6,18 +6,22 @@ import time
 from typing import List, Optional, Tuple
 
 import torch
-from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 from tqdm import tqdm
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          PreTrainedTokenizerBase)
 
-from aphrodite import LLM, SamplingParams
-from aphrodite.transformers_utils.tokenizer import get_tokenizer
+from aphrodite.quantization import QUANTIZATION_METHODS
 
 
 def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int],
 ) -> List[Tuple[str, int, int]]:
+    if fixed_output_len is not None and fixed_output_len < 4:
+        raise ValueError("output_len too small")
+
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -27,20 +31,23 @@ def sample_requests(
     dataset = [(data["conversations"][0]["value"],
                 data["conversations"][1]["value"]) for data in dataset]
 
-    # Tokenize the prompts and completions.
-    prompts = [prompt for prompt, _ in dataset]
-    prompt_token_ids = tokenizer(prompts).input_ids
-    completions = [completion for _, completion in dataset]
-    completion_token_ids = tokenizer(completions).input_ids
-    tokenized_dataset = []
-    for i in range(len(dataset)):
-        output_len = len(completion_token_ids[i])
-        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+    # Shuffle the dataset.
+    random.shuffle(dataset)
 
-    # Filter out too long sequences.
+    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
-    for prompt, prompt_token_ids, output_len in tokenized_dataset:
+    for i in range(len(dataset)):
+        if len(filtered_dataset) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = dataset[i][0]
+        prompt_token_ids = tokenizer(prompt).input_ids
+        completion = dataset[i][1]
+        completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
+        output_len = len(completion_token_ids
+                         ) if fixed_output_len is None else fixed_output_len
         if prompt_len < 4 or output_len < 4:
             # Prune too short sequences.
             continue
@@ -49,9 +56,7 @@ def sample_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))
 
-    # Sample the requests.
-    sampled_requests = random.sample(filtered_dataset, num_requests)
-    return sampled_requests
+    return filtered_dataset
 
 
 def run_aphrodite(
@@ -65,16 +70,18 @@ def run_aphrodite(
     use_beam_search: bool,
     trust_remote_code: bool,
     dtype: str,
+    max_model_len: Optional[int],
+    enforce_eager: bool,
     kv_cache_dtype: str,
-    disable_custom_all_reduce: bool,
+    quantization_param_path: Optional[str],
+    device: str,
     enable_prefix_caching: bool,
-    enforce_eager: bool,
     enable_chunked_prefill: bool,
     max_num_batched_tokens: int,
-    speculative_model: Optional[str] = None,
-    num_speculative_tokens: Optional[int] = None,
-    use_v2_block_manager: bool = False,
+    gpu_memory_utilization: float = 0.9,
+    download_dir: Optional[str] = None,
 ) -> float:
+    from aphrodite import LLM, SamplingParams
     llm = LLM(
         model=model,
         tokenizer=tokenizer,
@@ -83,37 +90,35 @@ def run_aphrodite(
         seed=seed,
         trust_remote_code=trust_remote_code,
         dtype=dtype,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=gpu_memory_utilization,
+        enforce_eager=enforce_eager,
         kv_cache_dtype=kv_cache_dtype,
-        disable_custom_all_reduce=disable_custom_all_reduce,
+        quantization_param_path=quantization_param_path,
+        device=device,
         enable_prefix_caching=enable_prefix_caching,
-        enforce_eager=enforce_eager,
+        download_dir=download_dir,
         enable_chunked_prefill=enable_chunked_prefill,
         max_num_batched_tokens=max_num_batched_tokens,
-        speculative_model=speculative_model,
-        num_speculative_tokens=num_speculative_tokens,
-        use_v2_block_manager=use_v2_block_manager,
     )
 
     # Add the requests to the engine.
+    prompts = []
+    sampling_params = []
     for prompt, _, output_len in requests:
-        sampling_params = SamplingParams(
-            n=n,
-            temperature=0.0 if use_beam_search else 1.0,
-            top_p=1.0,
-            use_beam_search=use_beam_search,
-            ignore_eos=True,
-            max_tokens=output_len,
-        )
-        # FIXME: Do not use internal method.
-        llm._add_request(  # pylint: disable=protected-access
-            prompt=prompt,
-            prompt_token_ids=None,
-            params=sampling_params,
-        )
+        prompts.append(prompt)
+        sampling_params.append(
+            SamplingParams(
+                n=n,
+                temperature=0.0 if use_beam_search else 1.0,
+                top_p=1.0,
+                use_beam_search=use_beam_search,
+                ignore_eos=True,
+                max_tokens=output_len,
+            ))
 
     start = time.perf_counter()
-    # FIXME Do use internal method.
-    llm._run_engine(use_tqdm=True)  # pylint: disable=protected-access
+    llm.generate(prompts, sampling_params, use_tqdm=True)
     end = time.perf_counter()
     return end - start
 
@@ -178,28 +183,58 @@ def run_hf(
     return end - start
 
 
-def main(args: argparse.Namespace):  # pylint: disable=redefined-outer-name
+def run_mii(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tensor_parallel_size: int,
+    output_len: int,
+) -> float:
+    from mii import client, serve
+    llm = serve(model, tensor_parallel=tensor_parallel_size)
+    prompts = [prompt for prompt, _, _ in requests]
+
+    start = time.perf_counter()
+    llm.generate(prompts, max_new_tokens=output_len)
+    end = time.perf_counter()
+    client = client(model)
+    client.terminate_server()
+    return end - start
+
+
+def main(args: argparse.Namespace):
     print(args)
     random.seed(args.seed)
 
     # Sample the requests.
-    tokenizer = get_tokenizer(args.tokenizer,
-                              trust_remote_code=args.trust_remote_code)
-    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.tokenizer, trust_remote_code=args.trust_remote_code)
+    if args.dataset is None:
+        # Synthesize a prompt with the given input length.
+        prompt = "hi" * (args.input_len - 1)
+        requests = [(prompt, args.input_len, args.output_len)
+                    for _ in range(args.num_prompts)]
+    else:
+        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
+                                   args.output_len)
 
     if args.backend == "aphrodite":
         elapsed_time = run_aphrodite(
             requests, args.model, args.tokenizer, args.quantization,
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
-            args.trust_remote_code, args.dtype, args.kv_cache_dtype,
-            args.disable_custom_all_reduce, args.enable_prefix_caching,
-            args.enforce_eager, args.enable_chunked_prefill,
-            args.max_num_batched_tokens)
+            args.trust_remote_code, args.dtype, args.max_model_len,
+            args.enforce_eager, args.kv_cache_dtype,
+            args.quantization_param_path, args.device,
+            args.enable_prefix_caching, args.enable_chunked_prefill,
+            args.max_num_batched_tokens, args.gpu_memory_utilization,
+            args.download_dir)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                               args.use_beam_search, args.hf_max_batch_size,
                               args.trust_remote_code)
+    elif args.backend == "mii":
+        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
+                               args.output_len)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
     total_input_tokens = sum(prompt_len for _, prompt_len, _ in requests)
@@ -214,22 +249,27 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Benchmark the throughput.")
     parser.add_argument("--backend",
                         type=str,
-                        choices=["aphrodite", "hf"],
+                        choices=["aphrodite", "hf", "mii"],
                         default="aphrodite")
     parser.add_argument("--dataset",
                         type=str,
-                        required=True,
+                        default=None,
                         help="Path to the dataset.")
-    parser.add_argument("--model",
-                        type=str,
-                        default="EleutherAI/pythia-70m-deduped")
+    parser.add_argument("--input-len",
+                        type=int,
+                        default=None,
+                        help="Input prompt length for each request")
+    parser.add_argument("--output-len",
+                        type=int,
+                        default=None,
+                        help="Output length for each request. Overrides the "
+                        "output length from the dataset.")
+    parser.add_argument("--model", type=str, default="facebook/opt-125m")
     parser.add_argument("--tokenizer", type=str, default=None)
-    parser.add_argument(
-        "--quantization",
-        "-q",
-        choices=["awq", "gguf", "bnb", "gptq", "squeezellm", "marlin", None],
-        default=None)
-    parser.add_argument("--gpu-memory-utilization", type=float, default=0.88)
+    parser.add_argument('--quantization',
+                        '-q',
+                        choices=[*QUANTIZATION_METHODS, None],
+                        default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
                         type=int,
@@ -245,53 +285,84 @@ if __name__ == "__main__":
                         type=int,
                         default=None,
                         help="Maximum batch size for HF backend.")
-    parser.add_argument("--trust-remote-code",
-                        action="store_true",
-                        help="trust remote code from huggingface")
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
+                        help='trust remote code from huggingface')
     parser.add_argument(
-        "--dtype",
-        type=str,
-        default="auto",
-        choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
-        help="data type for model weights and activations. "
-        "The 'auto' option will use FP16 precision "
-        "for FP32 and FP16 models, and BF16 precision "
-        "for BF16 models.")
-    parser.add_argument("--kv-cache-dtype",
-                        type=str,
-                        default="auto",
-                        choices=["auto", "fp8_e5m2"],
-                        help="The Data Type for the KV cache.")
+        '--max-model-len',
+        type=int,
+        default=None,
+        help='Maximum length of a sequence (including prompt and output). '
+        'If None, will be derived from the model.')
     parser.add_argument(
-        "--disable-custom-all-reduce",
-        action="store_true",
-        help="disable custom all reduce for the Aphrodite backend")
-    parser.add_argument("--enable-prefix-caching",
-                        action="store_true",
-                        help="enable prefix caching for the Aphrodite backend")
+        '--dtype',
+        type=str,
+        default='auto',
+        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
+        help='data type for model weights and activations. '
+        'The "auto" option will use FP16 precision '
+        'for FP32 and FP16 models, and BF16 precision '
+        'for BF16 models.')
+    parser.add_argument('--gpu-memory-utilization',
+                        type=float,
+                        default=0.9,
+                        help='the fraction of GPU memory to be used for '
+                        'the model executor, which can range from 0 to 1. '
+                        'If unspecified, will use the default value of 0.9.')
     parser.add_argument("--enforce-eager",
                         type=lambda x: (str(x).lower() == 'true'),
-                        default=True,
-                        help="enforce eager mode for the Aphrodite backend")
+                        help="enforce eager execution")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type. '
+        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
+        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
+        'common inference criteria.')
+    parser.add_argument(
+        '--quantization-param-path',
+        type=str,
+        default=None,
+        help='Path to the JSON file containing the KV cache scaling factors. '
+        'This should generally be supplied, when KV cache dtype is FP8. '
+        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
+        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
+        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
+        'instead supported for common inference criteria.')
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda", "cpu"],
+        help='device type for Aphrodite execution, supporting CUDA and CPU.')
     parser.add_argument(
-        "--enable-chunked-prefill",
-        action="store_true",
-        help="enable chunked prefill for the Aphrodite backend")
-    parser.add_argument("--max-num-batched-tokens",
+        "--enable-prefix-caching",
+        action='store_true',
+        help="enable automatic prefix caching for Aphrodite backend.")
+    parser.add_argument("--enable-chunked-prefill",
+                        action='store_true',
+                        help="enable chunked prefill for Aphrodite backend.")
+    parser.add_argument('--max-num-batched-tokens',
                         type=int,
-                        help="maximum number of batched tokens for the "
-                        "Aphrodite backend")
-    parser.add_argument("--speculative-model",
+                        default=None,
+                        help='maximum number of batched tokens per '
+                        'iteration')
+    parser.add_argument('--download-dir',
                         type=str,
-                        help="speculative model for the Aphrodite backend")
-    parser.add_argument("--num-speculative-tokens",
-                        type=int,
-                        help="number of speculative tokens for the "
-                        "Aphrodite backend")
-    parser.add_argument("--use-v2-block-manager",
-                        action="store_true",
-                        help="use v2 block manager for the Aphrodite backend")
+                        default=None,
+                        help='directory to download and load the weights, '
+                        'default to the default cache dir of huggingface')
     args = parser.parse_args()
+    if args.tokenizer is None:
+        args.tokenizer = args.model
+    if args.dataset is None:
+        assert args.input_len is not None
+        assert args.output_len is not None
+    else:
+        assert args.input_len is None
 
     if args.backend == "aphrodite":
         if args.hf_max_batch_size is not None:
@@ -300,8 +371,19 @@ if __name__ == "__main__":
         if args.hf_max_batch_size is None:
             raise ValueError("HF max batch size is required for HF backend.")
         if args.quantization is not None:
-            raise ValueError("Quantization is only for aphrodite backend.")
-    if args.tokenizer is None:
-        args.tokenizer = args.model
-
+            raise ValueError("Quantization is only for Aphrodite backend.")
+    elif args.backend == "mii":
+        if args.dtype != "auto":
+            raise ValueError("dtype must be auto for MII backend.")
+        if args.n != 1:
+            raise ValueError("n must be 1 for MII backend.")
+        if args.use_beam_search:
+            raise ValueError("Beam search is not supported for MII backend.")
+        if args.quantization is not None:
+            raise ValueError("Quantization is only for Aphrodite backend.")
+        if args.hf_max_batch_size is not None:
+            raise ValueError("HF max batch size is only for HF backend.")
+        if args.tokenizer != args.model:
+            raise ValueError("Tokenizer must be the same as the model for MII "
+                             "backend.")
     main(args)
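
Note on the generation path: this diff drops the internal _add_request/_run_engine calls in favor of a single batched call to the public LLM.generate API, passing a list of prompts and a matching list of SamplingParams. A minimal sketch of that pattern, assuming an aphrodite build that exposes the LLM and SamplingParams interfaces used in the diff (the model name is illustrative only):

from aphrodite import LLM, SamplingParams

# Model name is illustrative; substitute the model under benchmark.
llm = LLM(model="facebook/opt-125m")

prompts = ["Hello, my name is", "The future of AI is"]
# One SamplingParams per request; ignore_eos plus max_tokens pins every
# request to a fixed output length, mirroring the benchmark loop.
sampling_params = [
    SamplingParams(n=1,
                   temperature=1.0,
                   top_p=1.0,
                   ignore_eos=True,
                   max_tokens=16) for _ in prompts
]

# One batched call through the public API replaces the removed
# _add_request/_run_engine internals.
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)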