test_prefix_caching.py

  1. """Compare the with and without prefix caching.
  2. Run `pytest tests/prefix_caching/test_prefix_caching.py`.
  3. """
  4. from typing import List
  5. import pytest
  6. from aphrodite.common.block import PhysicalTokenBlock
  7. from aphrodite.common.utils import Device
  8. from aphrodite.processing.block_manager_v1 import CachedBlockAllocator
  9. from tests.kernels.utils import override_backend_env_variable
  10. from ..models.utils import check_outputs_equal
  11. MODELS = [
  12. "facebook/opt-125m",
  13. ]
  14. @pytest.mark.parametrize("block_size", [16])
  15. @pytest.mark.parametrize("num_blocks", [16])
  16. def test_block_allocator(
  17. block_size: int,
  18. num_blocks: int,
  19. ):
  20. block_hash = 1
  21. block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
  22. # Allocate two PysicalTokenBlocks with the same hash and check
  23. # that they are the same PhysicalTokenBlock
  24. first_block = block_allocator.allocate(block_hash, 0)
  25. second_block = block_allocator.allocate(block_hash, 0)
  26. assert (first_block == second_block)
  27. assert (second_block.ref_count == 2)
  28. # Check metric: 1 hit of 2 queries
  29. assert block_allocator.get_prefix_cache_hit_rate() == 0.5
  30. # Free the first_block and confirm that the ref_count is correctly
  31. # decremented on the second block
  32. block_allocator.free(first_block)
  33. assert (second_block.ref_count == 1)
  34. # Free the second block
  35. block_allocator.free(second_block)
  36. # Reallocate the first block and confirm that, even after the block
  37. # had its ref_count go to 0, we still get the same block back
  38. first_block = block_allocator.allocate(block_hash, 0)
  39. assert (first_block == second_block)
  40. assert (first_block.block_hash == block_hash)
  41. # Allocate one more time to get 3/4 hit rate for easy checking
  42. block_allocator.allocate(block_hash, 0)
  43. assert block_allocator.get_prefix_cache_hit_rate() == 0.75
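    # Hit-rate accounting for the assertion above (assuming the allocator
    # counts every allocate() call as a query): four allocate() calls were
    # made with the same hash, and all but the first were served from the
    # cache, so 3 hits / 4 queries == 0.75.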
  44. @pytest.mark.parametrize("num_blocks", [16])
  45. def test_eviction(num_blocks: int, ):
  46. block_size = 16
  47. block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
  48. blocks: List[PhysicalTokenBlock] = []
  49. for i in range(num_blocks):
  50. # use i as the block_hash
  51. blocks.append(block_allocator.allocate(i, 0))
  52. #Free all blocks
  53. for block in blocks:
  54. block_allocator.free(block)
  55. # Allocate a new block and confirm that it's the first block freed.
  56. # I.E The Least Recently Used block
  57. new_block_hash = block_size
  58. new_block = block_allocator.allocate(new_block_hash, 0)
  59. assert (new_block == blocks[0])
  60. assert (new_block.block_hash == new_block_hash)
  61. # Reallocate the second in blocks to remove it from the free list
  62. realloc_block_hash = 1
  63. realloc_block = block_allocator.allocate(realloc_block_hash, 0)
  64. assert (realloc_block == blocks[realloc_block_hash])
  65. assert (realloc_block.block_hash == realloc_block_hash)
  66. # Allocate a new block and confirm that it's not the realloc_block,
  67. # since the realloc_block shouldn't be in the free list
  68. new_block_hash = block_size + 1
  69. new_block = block_allocator.allocate(new_block_hash, 0)
  70. assert (realloc_block != new_block)
  71. assert (new_block.block_hash == new_block_hash)
  72. assert (new_block.block_number == 2)
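    # Note on the block_number == 2 check above: assuming eviction reuses
    # freed blocks in LRU order, blocks[0] was consumed by the first new
    # allocation and blocks[1] was re-acquired as realloc_block, so the
    # next eviction candidate is blocks[2], whose block_number is 2.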
  73. @pytest.mark.parametrize("model", MODELS)
  74. @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
  75. @pytest.mark.parametrize("dtype", ["half"])
  76. @pytest.mark.parametrize("max_tokens", [5])
  77. @pytest.mark.parametrize("cached_position", [0, 1])
  78. @pytest.mark.parametrize("use_v2_block_manager", [False, True])
  79. def test_mixed_requests(
  80. hf_runner,
  81. aphrodite_runner,
  82. example_prompts,
  83. model: str,
  84. backend: str,
  85. dtype: str,
  86. max_tokens: int,
  87. cached_position: int,
  88. use_v2_block_manager: bool,
  89. monkeypatch,
  90. ) -> None:
  91. """
  92. Test the case when some sequences have the prefix cache hit
  93. and the others don't. The cached position determines where
  94. the sequence is at among the batch of prefills.
  95. """
    override_backend_env_variable(monkeypatch, backend)

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    cached_prompt = example_prompts[cached_position]
    with aphrodite_runner(
            model,
            dtype=dtype,
            enable_prefix_caching=True,
            use_v2_block_manager=use_v2_block_manager,
    ) as aphrodite_model:
        # Run the first prompt so the cache is populated.
        aphrodite_outputs = aphrodite_model.generate_greedy([cached_prompt],
                                                            max_tokens)

        # Run all the prompts.
        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts,
                                                            max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=aphrodite_outputs,
        name_0="hf",
        name_1="aphrodite",
    )