xqa_attn.py

import torch
import pytest

from aphrodite._custom_ops import xqa_paged_attention


def print_cuda_info():
    print(f"CUDA Device: {torch.cuda.get_device_name()}")
    print(f"CUDA Capability: {torch.cuda.get_device_capability()}")
    print(f"CUDA Version: {torch.version.cuda}")


def reset_cuda():
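    # Release cached allocations and wait for in-flight work so each run
    # starts from a clean CUDA state.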
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.init()


def test_single_config(
    batch_size=1,
    num_heads=32,
    num_kv_heads=4,
    head_size=128,
    block_size=16,
    max_seq_len=128,
):
    """Test a single XQA configuration with proper cleanup"""
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for XQA paged attention")

    print("\nTesting configuration:")
    print(f" batch_size: {batch_size}")
    print(f" num_heads: {num_heads}")
    print(f" num_kv_heads: {num_kv_heads}")
    print(f" head_size: {head_size}")
    print(f" block_size: {block_size}")
    print(f" max_seq_len: {max_seq_len}")

    reset_cuda()
    print(f"CUDA memory after reset: "
          f"{torch.cuda.memory_allocated() / 1024**2:.2f} MB")
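
    # Rotary embedding covers half the head dimension; scale is the
    # standard 1/sqrt(head_size) softmax scaling.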
    rotary_embedding_dim = head_size // 2
    scale = 1.0 / (head_size ** 0.5)
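
    # Decode-phase query: one token per sequence,
    # shape [batch_size, num_heads, head_size].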
    query = torch.randn(batch_size, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
    torch.cuda.synchronize()
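
    # Paged KV cache with enough blocks for max_seq_len tokens per sequence;
    # the size-2 axis separates keys and values.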
    num_blocks = (max_seq_len + block_size - 1) // block_size * batch_size
    kv_cache = torch.randn(num_blocks, num_kv_heads, block_size, 2, head_size,
                           dtype=torch.float16, device="cuda")
    torch.cuda.synchronize()
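
    # Block tables map each sequence to its cache blocks; here the blocks are
    # simply assigned contiguously.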
    max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.arange(num_blocks, dtype=torch.int32, device="cuda")
    block_tables = block_tables.reshape(batch_size, max_blocks_per_seq)
    torch.cuda.synchronize()
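
    # Every sequence is treated as fully populated up to max_seq_len.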
    seq_lens = torch.full((batch_size,), max_seq_len,
                          dtype=torch.int32, device="cuda")
    torch.cuda.synchronize()
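
    # Output buffer with the same shape as the decode query.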
    out = torch.empty_like(query)
    torch.cuda.synchronize()

    print("\nTensor shapes:")
    print(f" query: {query.shape}")
    print(f" kv_cache: {kv_cache.shape}")
    print(f" block_tables: {block_tables.shape}")
    print(f" seq_lens: {seq_lens.shape}")

    try:
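        # Launch the XQA paged-attention kernel; any shape or capability
        # problem surfaces here as a raised exception.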
        xqa_paged_attention(
            out=out,
            query=query,
            kv_cache=kv_cache,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            rotary_embedding_dim=rotary_embedding_dim,
            scale=scale,
            block_tables=block_tables,
            seq_lens=seq_lens,
            block_size=block_size,
            max_seq_len=max_seq_len,
            kv_cache_dtype="auto",
            k_scale=1.0,
            v_scale=1.0,
        )
        torch.cuda.synchronize()
        print("✓ Configuration succeeded")
        return True
    except Exception as e:
        print(f"✗ Configuration failed: {str(e)}")
        return False
    finally:
        reset_cuda()


if __name__ == "__main__":
    import os
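    # Make kernel launches synchronous so any CUDA error is reported at the
    # failing call rather than at a later synchronization point.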
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    # Known working config
    test_single_config(
        batch_size=1,
        num_heads=32,
        num_kv_heads=4,
        head_size=128,
        block_size=16,
        max_seq_len=128,
    )