test_cache.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import random
  2. import pytest
  3. import torch
  4. from aphrodite._C import cache_ops
  5. DTYPES = [torch.half, torch.bfloat16, torch.float]
  6. NUM_TOKENS = [83] # Arbitrary values for testing
  7. NUM_LAYERS = [1] # Arbitrary values for testing
  8. NUM_HEADS = [8] # Arbitrary values for testing
  9. HEAD_SIZES = [64, 80, 96, 112, 128, 256]
  10. BLOCK_SIZES = [8, 16, 32]
  11. NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing
  12. NUM_MAPPINGS = [256] # Arbitrary values for testing
  13. SEEDS = [0]
  14. DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
  15. @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
  16. @pytest.mark.parametrize("num_layers", NUM_LAYERS)
  17. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  18. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  19. @pytest.mark.parametrize("block_size", BLOCK_SIZES)
  20. @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
  21. @pytest.mark.parametrize("dtype", DTYPES)
  22. @pytest.mark.parametrize("seed", SEEDS)
  23. @pytest.mark.parametrize("device", DEVICES)
  24. @torch.inference_mode()
  25. def test_copy_blocks(
  26. kv_cache_factory,
  27. num_mappings: int,
  28. num_layers: int,
  29. num_heads: int,
  30. head_size: int,
  31. block_size: int,
  32. num_blocks: int,
  33. dtype: torch.dtype,
  34. seed: int,
  35. device: int,
  36. ) -> None:
  37. random.seed(seed)
  38. torch.random.manual_seed(seed)
  39. torch.cuda.manual_seed(seed)
  40. gpu_id = f"cuda:{device}"
  41. # Generate random block mappings where each source block is mapped to two
  42. # destination blocks.
  43. assert 2 * num_mappings <= num_blocks
  44. src_blocks = random.sample(range(num_blocks), num_mappings)
  45. remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
  46. dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
  47. copy_src = []
  48. copy_dst = []
  49. for i in range(num_mappings):
  50. copy_src.append(src_blocks[i])
  51. copy_dst.append(dst_blocks[2 * i])
  52. copy_src.append(src_blocks[i])
  53. copy_dst.append(dst_blocks[2 * i + 1])
  54. # Create the KV caches.
  55. key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
  56. num_layers, num_heads,
  57. head_size, dtype, seed, gpu_id)
  58. # Clone the KV caches.
  59. cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
  60. cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
  61. # Call the copy blocks kernel.
  62. cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst)
  63. # Run the reference implementation.
  64. for src, dst in zip(copy_src, copy_dst):
  65. for cloned_key_cache in cloned_key_caches:
  66. cloned_key_cache[dst].copy_(cloned_key_cache[src])
  67. for cloned_value_cache in cloned_value_caches:
  68. cloned_value_cache[dst].copy_(cloned_value_cache[src])
  69. # Compare the results.
  70. for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
  71. assert torch.allclose(key_cache, cloned_key_cache)
  72. for value_cache, cloned_value_cache in zip(value_caches,
  73. cloned_value_caches):
  74. assert torch.allclose(value_cache, cloned_value_cache)
  75. @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
  76. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  77. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  78. @pytest.mark.parametrize("block_size", BLOCK_SIZES)
  79. @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
  80. @pytest.mark.parametrize("dtype", DTYPES)
  81. @pytest.mark.parametrize("seed", SEEDS)
  82. @pytest.mark.parametrize("device", DEVICES)
  83. @torch.inference_mode()
  84. def test_reshape_and_cache(
  85. kv_cache_factory,
  86. num_tokens: int,
  87. num_heads: int,
  88. head_size: int,
  89. block_size: int,
  90. num_blocks: int,
  91. dtype: torch.dtype,
  92. seed: int,
  93. device: int,
  94. ) -> None:
  95. random.seed(seed)
  96. torch.random.manual_seed(seed)
  97. torch.cuda.manual_seed(seed)
  98. gpu_id = f"cuda:{device}"
  99. # Create a random slot mapping.
  100. num_slots = block_size * num_blocks
  101. slot_mapping = random.sample(range(num_slots), num_tokens)
  102. slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)
  103. qkv = torch.randn(num_tokens,
  104. 3,
  105. num_heads,
  106. head_size,
  107. dtype=dtype,
  108. device=gpu_id)
  109. _, key, value = qkv.unbind(dim=1)
  110. # Create the KV caches.
  111. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
  112. num_heads, head_size, dtype,
  113. seed, gpu_id)
  114. key_cache, value_cache = key_caches[0], value_caches[0]
  115. # Clone the KV caches.
  116. cloned_key_cache = key_cache.clone()
  117. cloned_value_cache = value_cache.clone()
  118. # Call the reshape_and_cache kernel.
  119. cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
  120. slot_mapping)
  121. # Run the reference implementation.
  122. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
  123. block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
  124. block_indicies = block_indicies.cpu().tolist()
  125. block_offsets = slot_mapping % block_size
  126. block_offsets = block_offsets.cpu().tolist()
  127. for i in range(num_tokens):
  128. block_idx = block_indicies[i]
  129. block_offset = block_offsets[i]
  130. cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
  131. cloned_value_cache[block_idx, :, :, block_offset] = value[i]
  132. assert torch.allclose(key_cache, cloned_key_cache)
  133. assert torch.allclose(value_cache, cloned_value_cache)