"""Compare the outputs with and without prefix caching.

Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
from typing import List

import pytest

from aphrodite.common.block import PhysicalTokenBlock
from aphrodite.common.utils import Device
from aphrodite.processing.block_manager_v1 import CachedBlockAllocator
from tests.kernels.utils import override_backend_env_variable

from ..models.utils import check_outputs_equal

MODELS = [
    "facebook/opt-125m",
]


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
    block_size: int,
    num_blocks: int,
):
    """Check hash-based block sharing, ref counting and hit-rate metrics.

    Allocating the same hash repeatedly must hand back the same
    PhysicalTokenBlock, each allocate/free must adjust its ref_count, and
    the allocator's prefix cache hit rate must reflect hits over queries.
    """
    block_hash = 1
    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

    # Allocate two PhysicalTokenBlocks with the same hash and check
    # that they are the same PhysicalTokenBlock
    first_block = block_allocator.allocate(block_hash, 0)
    second_block = block_allocator.allocate(block_hash, 0)
    assert first_block == second_block
    assert second_block.ref_count == 2

    # Check metric: 1 hit of 2 queries
    assert block_allocator.get_prefix_cache_hit_rate() == 0.5

    # Free the first_block and confirm that the ref_count is correctly
    # decremented on the second block
    block_allocator.free(first_block)
    assert second_block.ref_count == 1

    # Free the second block
    block_allocator.free(second_block)

    # Reallocate the first block and confirm that, even after the block
    # had its ref_count go to 0, we still get the same block back
    first_block = block_allocator.allocate(block_hash, 0)
    assert first_block == second_block
    assert first_block.block_hash == block_hash

    # Allocate one more time to get 3/4 hit rate for easy checking
    block_allocator.allocate(block_hash, 0)
    assert block_allocator.get_prefix_cache_hit_rate() == 0.75


@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
    """Check which block is reused once every block has been freed.

    A hash that was never allocated must take over the first block that was
    freed, while re-allocating a still-cached hash must return its original
    block and keep that block from being handed out for a later new hash.
    """
    block_size = 16
    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

    # Fill the allocator, using the index as the block_hash, so the hashes
    # in use are exactly 0 .. num_blocks - 1.
    blocks: List[PhysicalTokenBlock] = [
        block_allocator.allocate(block_hash, 0)
        for block_hash in range(num_blocks)
    ]

    # Free all blocks
    for block in blocks:
        block_allocator.free(block)

    # Allocate a new block and confirm that it's the first block freed.
    # I.e. the Least Recently Used block.
    # The hash must not collide with any hash handed out above, so derive it
    # from num_blocks rather than from the unrelated block_size.
    new_block_hash = num_blocks
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert new_block == blocks[0]
    assert new_block.block_hash == new_block_hash

    # Reallocate the second in blocks to remove it from the free list
    realloc_block_hash = 1
    realloc_block = block_allocator.allocate(realloc_block_hash, 0)
    assert realloc_block == blocks[realloc_block_hash]
    assert realloc_block.block_hash == realloc_block_hash

    # Allocate a new block and confirm that it's not the realloc_block,
    # since the realloc_block shouldn't be in the free list
    new_block_hash = num_blocks + 1
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert realloc_block != new_block
    assert new_block.block_hash == new_block_hash
    assert new_block.block_number == 2


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
def test_mixed_requests(
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    backend: str,
    dtype: str,
    max_tokens: int,
    cached_position: int,
    use_v2_block_manager: bool,
    monkeypatch,
) -> None:
    """
    Test the case when some sequences have the prefix cache hit
    and the others don't. The cached position determines where
    the sequence is at among the batch of prefills.
    """
    override_backend_env_variable(monkeypatch, backend)

    # Reference outputs from the HF runner.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    cached_prompt = example_prompts[cached_position]
    with aphrodite_runner(
            model,
            dtype=dtype,
            enable_prefix_caching=True,
            use_v2_block_manager=use_v2_block_manager,
    ) as aphrodite_model:
        # Run the cached prompt on its own so the cache is populated.
        # Only the side effect matters; the result is deliberately discarded.
        aphrodite_model.generate_greedy([cached_prompt], max_tokens)

        # Run all the prompts
        aphrodite_outputs = aphrodite_model.generate_greedy(
            example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=aphrodite_outputs,
        name_0="hf",
        name_1="aphrodite",
    )