  1. """Compare the with and without prefix caching.
  2. Run `pytest tests/prefix_caching/test_prefix_caching.py`.
  3. """
  4. from typing import List
  5. import pytest
  6. from aphrodite.common.block import PhysicalTokenBlock
  7. from aphrodite.common.utils import Device
  8. from aphrodite.processing.block_manager_v1 import CachedBlockAllocator
  9. from tests.kernels.utils import override_backend_env_variable
  10. from ..models.utils import check_outputs_equal
  11. MODELS = [
  12. "facebook/opt-125m",
  13. ]


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
    block_size: int,
    num_blocks: int,
):
    block_hash = 1
    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
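
    # CachedBlockAllocator keys physical blocks by a content hash and
    # reference-counts them; the assertions below check that repeated
    # allocations of the same hash share one PhysicalTokenBlock.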

    # Allocate two PhysicalTokenBlocks with the same hash and check
    # that they are the same PhysicalTokenBlock
    first_block = block_allocator.allocate(block_hash, 0)
    second_block = block_allocator.allocate(block_hash, 0)
    assert (first_block == second_block)
    assert (second_block.ref_count == 2)

    # Free the first_block and confirm that the ref_count is correctly
    # decremented on the second block
    block_allocator.free(first_block)
    assert (second_block.ref_count == 1)

    # Free the second block
    block_allocator.free(second_block)

    # Reallocate the first block and confirm that, even after the block
    # had its ref_count go to 0, we still get the same block back
    first_block = block_allocator.allocate(block_hash, 0)
    assert (first_block == second_block)
    assert (first_block.block_hash == block_hash)


@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
    block_size = 16
    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
    blocks: List[PhysicalTokenBlock] = []

    for i in range(num_blocks):
        # use i as the block_hash
        blocks.append(block_allocator.allocate(i, 0))
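
    # The allocator is now at capacity: all num_blocks physical blocks are in
    # use, each tagged with a distinct hash. Freeing them below should keep
    # them cached while making them candidates for LRU eviction.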

    # Free all blocks
    for block in blocks:
        block_allocator.free(block)

    # Allocate a new block and confirm that it's the first block freed,
    # i.e. the least recently used block
    new_block_hash = block_size
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert (new_block == blocks[0])
    assert (new_block.block_hash == new_block_hash)

    # Reallocate the second block in `blocks` to remove it from the free list
    realloc_block_hash = 1
    realloc_block = block_allocator.allocate(realloc_block_hash, 0)
    assert (realloc_block == blocks[realloc_block_hash])
    assert (realloc_block.block_hash == realloc_block_hash)

    # Allocate a new block and confirm that it's not the realloc_block,
    # since the realloc_block shouldn't be in the free list
    new_block_hash = block_size + 1
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert (realloc_block != new_block)
    assert (new_block.block_hash == new_block_hash)
    # blocks[0] and blocks[1] are in use again, so the least recently used
    # free block is blocks[2]
    assert (new_block.block_number == 2)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
def test_mixed_requests(
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    backend: str,
    dtype: str,
    max_tokens: int,
    cached_position: int,
    use_v2_block_manager: bool,
    monkeypatch,
) -> None:
    """
    Test the case where some sequences have a prefix cache hit and the
    others don't. cached_position determines where the cached sequence
    sits within the batch of prefills.
    """
    override_backend_env_variable(monkeypatch, backend)

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    cached_prompt = example_prompts[cached_position]
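    # cached_position controls which prompt in the batch hits the prefix
    # cache after the warm-up generation below; the other prompts are
    # expected to miss.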
    with aphrodite_runner(
            model,
            dtype=dtype,
            enable_prefix_caching=True,
            use_v2_block_manager=use_v2_block_manager,
    ) as aphrodite_model:
        # Run the first prompt so the cache is populated
        aphrodite_outputs = aphrodite_model.generate_greedy([cached_prompt],
                                                            max_tokens)

        # Run all the prompts
        aphrodite_outputs = aphrodite_model.generate_greedy(
            example_prompts, max_tokens)
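
    # With greedy decoding, the outputs should match the HF baseline exactly,
    # whether or not a prompt's prefill was served from the prefix cache.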
    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=aphrodite_outputs,
        name_0="hf",
        name_1="aphrodite",
    )