import argparse
import random
import time

import torch

from aphrodite import _custom_ops as attention_ops
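
# NUM_BLOCKS is the size of the physical block pool that the block tables
# below index into; PARTITION_SIZE is the number of tokens each partition
# covers in the v2 kernel, which splits long contexts across thread blocks.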
NUM_BLOCKS = 1024
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
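
    # Standard attention scaling, 1/sqrt(head_size); it is passed to the
    # kernel and also reused as the range for the random inputs below.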
    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)
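
    # Grouped-query attention: each KV head serves a contiguous group of
    # num_queries_per_kv query heads, and head_mapping tells the kernel
    # which KV head each query head should read from.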
    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
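
    # Optional ALiBi: one slope per query head for the linear positional
    # bias. Real ALiBi slopes follow a geometric sequence, but random ones
    # are fine here since only kernel latency is measured.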
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
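    # Each sequence gets a table mapping its logical blocks to random
    # physical blocks in the pool. Physical blocks may be shared across
    # sequences, which is harmless for a latency benchmark.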
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV cache.
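    # The key cache splits the head dimension so that x elements (16 bytes'
    # worth) sit contiguously in the last axis, letting the kernel issue
    # 16-byte vectorized loads; the value cache keeps the head dimension
    # whole with block_size as the innermost axis.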
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
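        # v2 processes the context in PARTITION_SIZE-token chunks, each on
        # its own thread block. Every partition writes a partial output
        # plus its softmax statistics (exp_sums, max_logits), which a
        # reduction step then combines into the final output.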
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
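
    # Time num_iters kernel launches between two full device syncs and
    # report the mean. The optional cudaProfilerStart/Stop calls scope the
    # region so external profilers capture only the benchmarked kernels.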
    def run_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                attention_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            elif version == "v2":
                attention_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters
    # Warmup.
    print("Warming up...")
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
    )