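"""Benchmark the Aphrodite paged attention kernel (attention.py).

Example, using the default settings below:
    python attention.py --version v2 --batch-size 8 --context-len 4096
"""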
import argparse
import random
import time

import torch

from aphrodite import _custom_ops as attention_ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
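
    # 1/sqrt(head_size) is the usual attention softmax scale; it also
    # bounds the random input values generated below.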
    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)
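
    # Grouped-query attention: map each query head to the KV head it reads.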
    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
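
    # Optional ALiBi slopes, one per query head.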
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")
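
    # Every sequence in the batch uses the same context length.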
    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
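    # Each table maps a sequence's logical blocks to physical blocks in the
    # KV cache; the physical block indices are chosen at random here.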
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
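    # The key cache packs x consecutive elements of the head dimension into
    # its last axis, so each innermost slice spans 16 bytes.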
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
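    # v2 splits each context into PARTITION_SIZE chunks and reduces the
    # per-partition results, so it needs scratch buffers for the partial
    # outputs and softmax statistics.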
    if version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
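
    # Time num_iters kernel launches, synchronizing around the timed region
    # so all queued GPU work is included in the measurement.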
    def run_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
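            # v1 computes attention over the whole context in one pass; v2
            # writes per-partition results to the scratch buffers and then
            # reduces them.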
            if version == "v1":
                attention_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            elif version == "v2":
                attention_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()
    print(args)
    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
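
    # Map the CLI dtype names onto torch dtypes.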
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
    )