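"""Benchmark batched vs. non-batched rotary embedding (RoPE) kernels.

Simulates serving multiple LoRA adapters at once: the batched kernel applies a
different linear scaling factor per token in a single launch, while the
non-batched path runs one RotaryEmbedding instance per scaling factor. Both
paths are wrapped in NVTX ranges, so results are meant to be inspected with an
NVIDIA profiler such as Nsight Systems.
"""
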
import argparse
from itertools import accumulate
from typing import List, Optional

import nvtx
import torch

from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.modeling.layers.rotary_embedding import (RotaryEmbedding,
                                                        get_rope)


def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    # simulate serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(head_size, rotary_dim, max_position, base,
                            is_neox_style, {
                                "type": "linear",
                                "factor": tuple(scaling_factors)
                            })
    # non-batched RoPE takes only one scaling factor; we create multiple
    # instances to simulate the same behavior
    non_batched_ropes: List[RotaryEmbedding] = []
    for scaling_factor in scaling_factors:
        non_batched_ropes.append(
            get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                     {
                         "type": "linear",
                         "factor": (scaling_factor, )
                     }))
    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)
    # create query offsets for batched RoPE: we concatenate multiple KV caches
    # and each query needs to find the right KV cache for its type
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
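    # e.g. with the defaults (max_position=8192, scaling_factors=[1, 2, 4, 8]),
    # this accumulates [0, 16384, 32768, 65536] into [0, 16384, 49152, 114688]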
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()
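    # flatten_offsets has shape (batch_size * seq_len,): one offset per token,
    # matching the flattened layout the batched kernel expects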
    # group queries of the same type together for non-batched RoPE
    queries = [query[query_types == i] for i in range(len(scaling_factors))]
    keys = [key[query_types == i] for i in range(len(scaling_factors))]
    packed_qkr = zip(queries, keys, non_batched_ropes)
    # synchronize before starting timing
    torch.cuda.synchronize()
    with nvtx.annotate("non-batched", color="yellow"):
        # select the positions belonging to each type so that positions stay
        # aligned with the grouped tokens (passing the full `positions` tensor
        # would mismatch the shape of the per-type queries)
        for i, (q, k, r) in enumerate(packed_qkr):
            r.forward(positions[query_types == i], q, k)
    torch.cuda.synchronize()
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
    torch.cuda.synchronize()


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels.")
    # `type=bool` would parse any non-empty string (including "False") as
    # True, so expose an explicit on/off flag instead
    parser.add_argument("--is-neox-style",
                        action=argparse.BooleanOptionalAction,
                        default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 192, 256],
                        default=128)
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["bfloat16", "float"],
                        default="float")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device",
                        type=str,
                        choices=["cuda:0", "cuda:1"],
                        default="cuda:0")
    args = parser.parse_args()
    print(args)
    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    )
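
# The NVTX ranges above only become visible under a profiler. A typical
# (hypothetical) invocation with Nsight Systems might look like:
#   nsys profile -o rope_bench python rope.py --batch-size 16 --seq-len 512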