from itertools import accumulate
from typing import List, Optional

import nvtx
import torch

from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.modeling.layers.rotary_embedding import (RotaryEmbedding,
                                                        get_rope)


def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    # simulating serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(head_size, rotary_dim, max_position, base,
                            is_neox_style, {
                                "type": "linear",
                                "factor": tuple(scaling_factors)
                            })
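    # Passing a tuple of factors builds a single module that covers all four
    # factors at once; the per-token offsets computed below select the portion
    # of its cache that corresponds to each token's scaling factor.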
    # non-batched RoPE takes only one scaling factor, so we create multiple
    # instances to simulate the same behavior
    non_batched_ropes: List[RotaryEmbedding] = []
    for scaling_factor in scaling_factors:
        non_batched_ropes.append(
            get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                     {
                         "type": "linear",
                         "factor": (scaling_factor, )
                     }))

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)
    # create query offsets for batched RoPE: we concatenate multiple KV caches
    # together, and each query needs to find the right KV cache for its type
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
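    # For illustration: with the default max_position=8192 and
    # scaling_factors=[1, 2, 4, 8], this evaluates to
    # offset_map = [0, 16384, 49152, 114688].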
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()
    # group queries of the same type together for the non-batched RoPE
    queries = [query[query_types == i] for i in range(len(scaling_factors))]
    keys = [key[query_types == i] for i in range(len(scaling_factors))]
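    # Note: boolean masking flattens the batch/seq dims, so queries[i] and
    # keys[i] are 2-D tensors holding only the tokens assigned to factor i.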
    packed_qkr = zip(queries, keys, non_batched_ropes)
    # synchronize before starting timing
    torch.cuda.synchronize()
    with nvtx.annotate("non-batched", color="yellow"):
        for q, k, r in packed_qkr:
            r.forward(positions, q, k)
    torch.cuda.synchronize()
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
    torch.cuda.synchronize()
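    # The NVTX ranges above do not print timings by themselves; view them in a
    # profiler trace (e.g. Nsight Systems) to compare the two code paths.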


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels.")
    # NOTE: argparse's `type=bool` treats any non-empty string (even "False")
    # as True, so parse the flag value explicitly.
    parser.add_argument("--is-neox-style",
                        type=lambda s: s.lower() in ("true", "1"),
                        default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 192, 256],
                        default=128)
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["bfloat16", "float"],
                        default="float")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device",
                        type=str,
                        choices=["cuda:0", "cuda:1"],
                        default="cuda:0")
    args = parser.parse_args()
    print(args)
    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    )
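
# Example invocation (the file name is an assumption; adjust to wherever this
# script lives in the repo):
#   python benchmark_rope.py --batch-size 32 --seq-len 1024 --dtype bfloat16
# Run it under a profiler to see the NVTX ranges, e.g.:
#   nsys profile python benchmark_rope.py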