123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- """Compare the outputs with and without prefix caching.
- Run `pytest tests/prefix_caching/test_prefix_caching.py`.
- """
- from typing import List
- import pytest
- from aphrodite.common.block import PhysicalTokenBlock
- from aphrodite.common.utils import Device
- from aphrodite.processing.block_manager_v1 import CachedBlockAllocator
- from tests.kernels.utils import override_backend_env_variable
- from ..models.utils import check_outputs_equal
# Model identifiers that the end-to-end prefix-caching test is parametrized
# over.
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
    block_size: int,
    num_blocks: int,
):
    """Check block sharing, ref counting and hit-rate accounting.

    Every allocation in this test uses the same hash, so each one is
    expected to resolve to the same physical block.
    """
    shared_hash = 1
    allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

    # Two allocations with an identical hash must hand back the same
    # PhysicalTokenBlock, now referenced twice.
    block_a = allocator.allocate(shared_hash, 0)
    block_b = allocator.allocate(shared_hash, 0)
    assert block_a == block_b
    assert block_b.ref_count == 2

    # Metric so far: 1 hit out of 2 queries.
    assert allocator.get_prefix_cache_hit_rate() == 0.5

    # Releasing through the first handle must be visible through the
    # second one, since both refer to the same block.
    allocator.free(block_a)
    assert block_b.ref_count == 1

    # Drop the last reference.
    allocator.free(block_b)

    # Even after the ref_count reached 0, asking for the same hash again
    # must return the very same block.
    block_a = allocator.allocate(shared_hash, 0)
    assert block_a == block_b
    assert block_a.block_hash == shared_hash

    # A fourth query brings the hit rate to an easily checked 3/4.
    allocator.allocate(shared_hash, 0)
    assert allocator.get_prefix_cache_hit_rate() == 0.75
@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int):
    """Check that freed blocks are evicted least-recently-used first.

    All ``num_blocks`` blocks are allocated with hashes
    ``0..num_blocks-1`` and then freed in that order. New hashes are then
    requested to confirm that the oldest freed block is reused first, and
    that a block revived by its original hash is no longer an eviction
    candidate.

    Args:
        num_blocks: Number of blocks the allocator is created with. The
            assertions below index ``blocks[1]`` and expect block number
            2, so this needs to be at least 3.
    """
    block_size = 16
    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

    # Fill the allocator, using the loop index as each block's hash.
    blocks: List[PhysicalTokenBlock] = [
        block_allocator.allocate(block_hash, 0)
        for block_hash in range(num_blocks)
    ]

    # Free all blocks in allocation order, so blocks[0] is freed first.
    for block in blocks:
        block_allocator.free(block)

    # Hashes 0..num_blocks-1 are already in use, so num_blocks is the
    # first hash guaranteed to be unseen. (Deriving it from block_size
    # only worked while block_size happened to equal num_blocks.)
    new_block_hash = num_blocks

    # Allocate a new block and confirm that it's the first block freed,
    # i.e. the least recently used block.
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert new_block == blocks[0]
    assert new_block.block_hash == new_block_hash

    # Reallocate the second block by its original hash to remove it from
    # the free list.
    realloc_block_hash = 1
    realloc_block = block_allocator.allocate(realloc_block_hash, 0)
    assert realloc_block == blocks[realloc_block_hash]
    assert realloc_block.block_hash == realloc_block_hash

    # Allocate another unseen hash and confirm that it's not the
    # realloc_block, since the realloc_block shouldn't be in the free list.
    new_block_hash = num_blocks + 1
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert realloc_block != new_block
    assert new_block.block_hash == new_block_hash
    # blocks[0] was reused above and blocks[1] was revived, so the block
    # handed out here is expected to be block number 2.
    assert new_block.block_number == 2
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
def test_mixed_requests(
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    backend: str,
    dtype: str,
    max_tokens: int,
    cached_position: int,
    use_v2_block_manager: bool,
    monkeypatch,
) -> None:
    """
    Test a batch in which some sequences hit the prefix cache and the
    others don't. ``cached_position`` selects which prompt of the batch
    is generated beforehand, i.e. where the cached sequence sits among
    the prefills. The greedy outputs are compared against the HF outputs.
    """
    override_backend_env_variable(monkeypatch, backend)

    # Reference outputs for the full prompt list.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    warmup_prompt = example_prompts[cached_position]
    engine_kwargs = dict(
        dtype=dtype,
        enable_prefix_caching=True,
        use_v2_block_manager=use_v2_block_manager,
    )
    with aphrodite_runner(model, **engine_kwargs) as aphrodite_model:
        # Run the selected prompt on its own first so the cache is
        # populated; its outputs are not needed.
        aphrodite_model.generate_greedy([warmup_prompt], max_tokens)

        # Run all the prompts.
        aphrodite_outputs = aphrodite_model.generate_greedy(
            example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=aphrodite_outputs,
        name_0="hf",
        name_1="aphrodite",
    )
|