# paged_attention.py: benchmark for the paged attention kernel.


  1. import random
  2. import time
  3. from typing import List, Optional
  4. import torch
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.common.utils import (STR_DTYPE_TO_TORCH_DTYPE,
  7. FlexibleArgumentParser,
  8. create_kv_caches_with_random)
  # Number of physical KV-cache blocks allocated for the benchmark; random
  # block-table entries are drawn from the range [0, NUM_BLOCKS).
  NUM_BLOCKS: int = 1024
  # Partition length used to size the v2 kernel's scratch buffers
  # (tmp_output / exp_sums / max_logits). Presumably this must agree with the
  # partition size the v2 kernel itself uses -- confirm against ops.
  PARTITION_SIZE: int = 512
  @torch.inference_mode()
  def main(
      version: str,
      num_seqs: int,
      seq_len: int,
      num_query_heads: int,
      num_kv_heads: int,
      head_size: int,
      use_alibi: bool,
      block_size: int,
      dtype: torch.dtype,
      seed: int,
      do_profile: bool,
      device: str = "cuda",
      kv_cache_dtype: Optional[str] = None,
  ) -> None:
      """Benchmark one paged-attention kernel variant and print its latency.

      Builds a random query batch, random block tables and a randomly
      initialised paged KV cache, then repeatedly launches
      ``ops.paged_attention_v1`` or ``ops.paged_attention_v2``. It prints the
      mean wall-clock time per launch in microseconds.

      Args:
          version: Kernel variant to time, ``"v1"`` or ``"v2"``.
          num_seqs: Number of sequences in the batch.
          seq_len: Context length used for every sequence.
          num_query_heads: Number of query heads.
          num_kv_heads: Number of KV heads; must divide ``num_query_heads``.
          head_size: Per-head size; the softmax scale is ``1/sqrt(head_size)``.
          use_alibi: If True, pass random per-query-head ALiBi slopes.
          block_size: Number of tokens per KV-cache block.
          dtype: Dtype of the query/output tensors; also passed to
              ``create_kv_caches_with_random``.
          seed: Seed for Python's ``random`` module and torch's RNGs.
          do_profile: If True, time a single launch bracketed by
              ``cudaProfilerStart``/``cudaProfilerStop``; otherwise time 100
              launches.
          device: Device on which all tensors are allocated.
          kv_cache_dtype: KV-cache storage dtype string (e.g. ``"auto"``,
              ``"fp8"``). ``None`` is treated as ``"auto"``.

      Raises:
          ValueError: If ``version`` is neither ``"v1"`` nor ``"v2"``.
          AssertionError: If ``num_query_heads`` is not a multiple of
              ``num_kv_heads``.
      """
      # Fail fast, before any tensors are allocated.
      if version not in ("v1", "v2"):
          raise ValueError(f"Invalid version: {version}")
      if kv_cache_dtype is None:
          # Match the CLI default instead of forwarding None to the cache
          # factory and the kernels.
          kv_cache_dtype = "auto"

      random.seed(seed)
      torch.random.manual_seed(seed)
      if torch.cuda.is_available():
          torch.cuda.manual_seed(seed)

      scale = float(1.0 / (head_size**0.5))
      query = torch.empty(num_seqs,
                          num_query_heads,
                          head_size,
                          dtype=dtype,
                          device=device)
      query.uniform_(-scale, scale)

      assert num_query_heads % num_kv_heads == 0
      alibi_slopes = None
      if use_alibi:
          alibi_slopes = torch.randn(num_query_heads,
                                     dtype=torch.float,
                                     device=device)

      # Every sequence has the same length.
      seq_lens_lst = [seq_len] * num_seqs
      max_seq_len = max(seq_lens_lst)
      seq_lens = torch.tensor(seq_lens_lst, dtype=torch.int, device=device)

      # Block tables: each sequence's logical blocks map to physical blocks
      # drawn independently (repeats possible) from [0, NUM_BLOCKS).
      max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
      block_tables_lst: List[List[int]] = [
          [random.randint(0, NUM_BLOCKS - 1)
           for _ in range(max_num_blocks_per_seq)]
          for _ in range(num_seqs)
      ]
      block_tables = torch.tensor(block_tables_lst,
                                  dtype=torch.int,
                                  device=device)

      # Randomly initialised KV cache. The third argument is 1, and only
      # element 0 of each returned list is used (presumably a single layer;
      # confirm against create_kv_caches_with_random).
      key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
                                                              block_size,
                                                              1,
                                                              num_kv_heads,
                                                              head_size,
                                                              kv_cache_dtype,
                                                              dtype,
                                                              device=device)
      key_cache, value_cache = key_caches[0], value_caches[0]

      output = torch.empty_like(query)
      if version == "v2":
          # v2 splits each sequence into PARTITION_SIZE-token partitions and
          # needs per-partition scratch buffers.
          num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
          tmp_output = torch.empty(
              size=(num_seqs, num_query_heads, num_partitions, head_size),
              dtype=output.dtype,
              device=output.device,
          )
          exp_sums = torch.empty(
              size=(num_seqs, num_query_heads, num_partitions),
              dtype=torch.float32,
              device=output.device,
          )
          max_logits = torch.empty_like(exp_sums)

      # Default (unit) K/V scales are passed to the kernel.
      k_scale = v_scale = 1.0

      def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
          """Launch the kernel ``num_iters`` times; return mean seconds per launch.

          If ``profile`` is True, the timed region is bracketed by CUDA
          profiler start/stop calls. The stop call runs even if a launch
          raises.
          """
          torch.cuda.synchronize()
          if profile:
              torch.cuda.cudart().cudaProfilerStart()
          try:
              start_time = time.perf_counter()
              for _ in range(num_iters):
                  if version == "v1":
                      ops.paged_attention_v1(
                          output,
                          query,
                          key_cache,
                          value_cache,
                          num_kv_heads,
                          scale,
                          block_tables,
                          seq_lens,
                          block_size,
                          max_seq_len,
                          alibi_slopes,
                          kv_cache_dtype,
                          k_scale,
                          v_scale,
                      )
                  else:  # "v2"; validated at the top of main().
                      ops.paged_attention_v2(
                          output,
                          exp_sums,
                          max_logits,
                          tmp_output,
                          query,
                          key_cache,
                          value_cache,
                          num_kv_heads,
                          scale,
                          block_tables,
                          seq_lens,
                          block_size,
                          max_seq_len,
                          alibi_slopes,
                          kv_cache_dtype,
                          k_scale,
                          v_scale,
                      )
              # Wait for all queued GPU work before stopping the clock.
              torch.cuda.synchronize()
              end_time = time.perf_counter()
          finally:
              if profile:
                  # Close the profiled range opened above.
                  torch.cuda.cudart().cudaProfilerStop()
          return (end_time - start_time) / num_iters

      print("Warming up...")
      run_cuda_benchmark(num_iters=3, profile=False)

      if do_profile:
          latency = run_cuda_benchmark(num_iters=1, profile=True)
      else:
          latency = run_cuda_benchmark(num_iters=100, profile=False)
      print(f"Kernel running time: {latency * 1000000:.3f} us")
  if __name__ == "__main__":
      # Script entry point: parse the benchmark configuration from the
      # command line, sanity-check the head layout, then run main().
      parser = FlexibleArgumentParser(
          description="Benchmark the paged attention kernel.")
      parser.add_argument("--version",
                          type=str,
                          choices=["v1", "v2"],
                          default="v2")
      # Plain integer options with no restricted choices.
      for flag, default_value in (
          ("--batch-size", 8),
          ("--seq-len", 4096),
          ("--num-query-heads", 64),
          ("--num-kv-heads", 8),
      ):
          parser.add_argument(flag, type=int, default=default_value)
      parser.add_argument("--head-size",
                          type=int,
                          choices=[64, 80, 96, 112, 128, 192, 256],
                          default=128)
      parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
      parser.add_argument("--use-alibi", action="store_true")
      parser.add_argument("--dtype",
                          type=str,
                          choices=["half", "bfloat16", "float"],
                          default="half")
      parser.add_argument("--seed", type=int, default=0)
      parser.add_argument("--profile", action="store_true")
      kv_cache_dtype_help = (
          "Data type for kv cache storage. If 'auto', will use model "
          "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
          "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
      parser.add_argument("--kv-cache-dtype",
                          type=str,
                          choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
                          default="auto",
                          help=kv_cache_dtype_help)
      args = parser.parse_args()
      print(args)

      # Each KV head must serve a whole number of query heads.
      if args.num_query_heads % args.num_kv_heads:
          raise ValueError("num_query_heads must be divisible by num_kv_heads")

      main(
          version=args.version,
          num_seqs=args.batch_size,
          seq_len=args.seq_len,
          num_query_heads=args.num_query_heads,
          num_kv_heads=args.num_kv_heads,
          head_size=args.head_size,
          block_size=args.block_size,
          use_alibi=args.use_alibi,
          dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
          seed=args.seed,
          do_profile=args.profile,
          kv_cache_dtype=args.kv_cache_dtype,
      )