Quellcode durchsuchen

feat: add metrics for prefix cache hit rate (#829)

AlpinDale vor 3 Monaten
Ursprung
Commit
3d83e64f8e

+ 12 - 1
aphrodite/engine/aphrodite_engine.py

@@ -23,7 +23,7 @@ from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
                                        ExecuteModelRequest, PoolerOutput,
                                        SamplerOutput, Sequence, SequenceGroup,
                                        SequenceGroupMetadata, SequenceStatus)
-from aphrodite.common.utils import Counter
+from aphrodite.common.utils import Counter, Device
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics_types import StatLoggerBase, Stats
 from aphrodite.engine.output_processor.interfaces import (
@@ -1290,6 +1290,13 @@ class AphroditeEngine:
             else:
                 cpu_cache_usage_sys = 0.0
 
+        # Prefix Cache Hit Rate. Note that we always use
+        # the cache hit rate of the first virtual engine.
+        cpu_prefix_cache_hit_rate = self.scheduler[
+            0].get_prefix_cache_hit_rate(Device.CPU)
+        gpu_prefix_cache_hit_rate = self.scheduler[
+            0].get_prefix_cache_hit_rate(Device.GPU)
+
         # Iteration stats
         num_prompt_tokens_iter = 0
         num_generation_tokens_iter = 0
@@ -1400,6 +1407,10 @@ class AphroditeEngine:
             gpu_cache_usage_sys=gpu_cache_usage_sys,
             cpu_cache_usage_sys=cpu_cache_usage_sys,
 
+            #   Prefix Cache Hit Rate
+            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
+            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
+
             # Iteration stats
             num_prompt_tokens_iter=num_prompt_tokens_iter,
             num_generation_tokens_iter=num_generation_tokens_iter,

+ 23 - 0
aphrodite/engine/metrics.py

@@ -70,6 +70,18 @@ class Metrics:
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames,
             multiprocess_mode="sum")
+        
+        #   Prefix caching block hit rate
+        self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
+            name="aphrodite:cpu_prefix_cache_hit_rate",
+            documentation="CPU prefix cache block hit rate.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+        self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
+            name="aphrodite:gpu_prefix_cache_hit_rate",
+            documentation="GPU prefix cache block hit rate.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
 
         # Iteration stats
         self.counter_num_preemption = self._counter_cls(
@@ -347,6 +359,13 @@ class LoggingStatLogger(StatLoggerBase):
                 f"CPU KV cache usage: {stats.cpu_cache_usage_sys * 100:.1f}%."
             )
 
+            if (stats.cpu_prefix_cache_hit_rate >= 0
+                    or stats.gpu_prefix_cache_hit_rate >= 0):
+                logger.info(
+                    "Prefix cache hit rate: "
+                    f"GPU: {stats.gpu_prefix_cache_hit_rate * 100:.2f}%, "
+                    f"CPU: {stats.cpu_prefix_cache_hit_rate * 100:.2f}%")
+
             if self.spec_decode_metrics is not None:
                 logger.info(
                     self._format_spec_decode_metrics_str(
@@ -418,6 +437,10 @@ class PrometheusStatLogger(StatLoggerBase):
                         stats.gpu_cache_usage_sys)
         self._log_gauge(self.metrics.gauge_cpu_cache_usage,
                         stats.cpu_cache_usage_sys)
+        self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
+                        stats.cpu_prefix_cache_hit_rate)
+        self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
+                        stats.gpu_prefix_cache_hit_rate)
 
         # Iteration level data
         self._log_counter(self.metrics.counter_num_preemption,

+ 3 - 0
aphrodite/engine/metrics_types.py

@@ -28,6 +28,9 @@ class Stats:
     #   KV Cache Usage in %
     gpu_cache_usage_sys: float
     cpu_cache_usage_sys: float
+    #   Prefix caching block hit rate
+    cpu_prefix_cache_hit_rate: float
+    gpu_prefix_cache_hit_rate: float
     # Iteration stats (should have _iter suffix)
     num_prompt_tokens_iter: int
     num_generation_tokens_iter: int

+ 49 - 0
aphrodite/processing/block/common.py

@@ -1,4 +1,5 @@
 from collections import deque
+from dataclasses import dataclass
 from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
 
 from aphrodite.processing.block.interfaces import Block, BlockAllocator
@@ -282,6 +283,54 @@ class BlockList:
         return self._block_ids
 
 
+@dataclass
+class CacheMetricData:
+    """A utility dataclass to maintain cache metric.
+    To avoid overflow, we maintain the hit rate in block granularity, so that
+    we can maintain a single hit rate for n_completed_block x block_size,
+    and calculate the real time hit rate by the following:
+    BS = The number of queries per block.
+    nB = The number of completed blocks.
+    HR = hit rate of (nB x BS) queries.
+    Q = current number of queries (< BS).
+    H = current number of hits (< BS).
+    hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)
+    """
+    num_completed_blocks: int = 0
+    completed_block_cache_hit_rate: float = 0.0
+    num_incompleted_block_queries: int = 0
+    num_incompleted_block_hit: int = 0
+    block_size: int = 1000
+    def query(self, hit: bool):
+        self.num_incompleted_block_queries += 1
+        self.num_incompleted_block_hit += 1 if hit else 0
+        # When a block is completed, update the cache hit rate
+        # and reset the incomplete numbers.
+        if self.num_incompleted_block_queries == self.block_size:
+            hit_rate = (self.num_incompleted_block_hit /
+                        self.num_incompleted_block_queries)
+            self.completed_block_cache_hit_rate = (
+                self.completed_block_cache_hit_rate * self.num_completed_blocks
+                + hit_rate) / (self.num_completed_blocks + 1)
+            self.num_incompleted_block_queries = 0
+            self.num_incompleted_block_hit = 0
+            self.num_completed_blocks += 1
+    def get_hit_rate(self):
+        incomplete_ratio = self.num_incompleted_block_queries / self.block_size
+        total_blocks = self.num_completed_blocks + incomplete_ratio
+        if total_blocks == 0:
+            return 0.0
+        completed_block_hit, incompleted_block_hit = 0.0, 0.0
+        if self.num_completed_blocks > 0:
+            completed_block_hit = (self.completed_block_cache_hit_rate *
+                                   self.num_completed_blocks)
+        if self.num_incompleted_block_queries > 0:
+            incompleted_hit_rate = (self.num_incompleted_block_hit /
+                                    self.num_incompleted_block_queries)
+            incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio)
+        return (completed_block_hit + incompleted_block_hit) / total_blocks
+
+
 def get_all_blocks_recursively(last_block: Block) -> List[Block]:
     """Retrieves all the blocks in a sequence starting from the last block.
 

+ 5 - 0
aphrodite/processing/block/cpu_gpu_block_allocator.py

@@ -326,6 +326,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
     def all_block_ids(self) -> FrozenSet[int]:
         return frozenset(self._block_ids_to_allocator.keys())
 
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        assert device in self._allocators
+        return self._allocators[device].get_prefix_cache_hit_rate()
+
     def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
         """Returns and clears the mapping of source to destination block IDs.
         Will be called after every swapping operations for now, and after every

+ 10 - 0
aphrodite/processing/block/interfaces.py

@@ -186,6 +186,11 @@ class BlockAllocator(ABC):
                                num_lookahead_slots: int = 0) -> int:
         pass
 
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass
+
     class NoFreeBlocksError(ValueError):
         pass
 
@@ -278,3 +283,8 @@ class DeviceAwareBlockAllocator(ABC):
         There is at most one null block per allocator.
         """
         pass
+
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass

+ 3 - 0
aphrodite/processing/block/naive_block.py

@@ -343,6 +343,9 @@ class NaiveBlockAllocator(BlockAllocator):
 
             block.block_id = block_id  # Assign block_id
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return -1
+
 
 class NaiveBlock(Block):
     """An implementation of the Block class that does not support prefix

+ 9 - 1
aphrodite/processing/block/prefix_caching_block.py

@@ -4,7 +4,8 @@ from os.path import commonprefix
 from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
 
 from aphrodite.common.utils import cdiv
-from aphrodite.processing.block.common import (CopyOnWriteTracker,
+from aphrodite.processing.block.common import (CacheMetricData,
+                                               CopyOnWriteTracker,
                                                get_all_blocks_recursively)
 from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
                                                    BlockId, Device)
@@ -109,6 +110,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         self._cow_tracker = CopyOnWriteTracker(
             refcounter=self._refcounter.as_readonly())
 
+        self.metric_data = CacheMetricData()
+
     # Implements Block.Factory.
     def _create_block(
         self,
@@ -157,9 +160,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
 
         cached_block_id = self._cached_blocks.get(block.content_hash, None)
         if cached_block_id is not None:
+            self.metric_data.query(hit=True)
             block.block_id = cached_block_id
             self._incr_refcount_cached_block(block)
             return block
+        self.metric_data.query(hit=False)
         self._block_pool.free_block(block)
 
         # No cached block => Allocate a new block
@@ -406,6 +411,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
     def all_block_ids(self) -> FrozenSet[int]:
         return self._hashless_allocator.all_block_ids
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return self.metric_data.get_hit_rate()
+
     def is_block_cached(self, block: Block) -> bool:
         assert block.content_hash is not None
         if block.content_hash in self._cached_blocks:

+ 26 - 4
aphrodite/processing/block_manager_v1.py

@@ -12,6 +12,7 @@ from loguru import logger
 from aphrodite.common.block import BlockTable, PhysicalTokenBlock
 from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
 from aphrodite.common.utils import Device
+from aphrodite.processing.block.common import CacheMetricData
 from aphrodite.processing.block.utils import (
     check_no_caching_or_swa_for_blockmgr_encdec)
 from aphrodite.processing.evictor_v1 import (EvictionPolicy, Evictor,
@@ -62,6 +63,12 @@ class BlockAllocatorBase(ABC):
         pass
 
 
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass
+
+
 class CachedBlockAllocator(BlockAllocatorBase):
     """Manages free physical token blocks for a device.
 
@@ -86,6 +93,8 @@ class CachedBlockAllocator(BlockAllocatorBase):
 
         self.default_hash_ctr = count()
 
+        self.cache_metric_data = CacheMetricData()
+
     def allocate_block(self, block_hash: int,
                        num_hashed_tokens: int) -> PhysicalTokenBlock:
         if self.current_num_blocks == self.num_blocks:
@@ -111,10 +120,10 @@ class CachedBlockAllocator(BlockAllocatorBase):
             block = self.evictor.remove(block_hash)
             assert block.ref_count == 0
             self.cached_blocks[block_hash] = block
-            block.ref_count += 1
-            assert block.block_hash == block_hash
-            return block
-        if block_hash not in self.cached_blocks:
+        if block_hash in self.cached_blocks:
+            self.cache_metric_data.query(hit=True)
+        else:
+            self.cache_metric_data.query(hit=False)
             self.cached_blocks[block_hash] = self.allocate_block(
                 block_hash, num_hashed_tokens)
         block = self.cached_blocks[block_hash]
@@ -151,6 +160,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
         del self.cached_blocks[old_hash]
         self.cached_blocks[block_hash] = block
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return self.cache_metric_data.get_hit_rate()
+
 
 class UncachedBlockAllocator(BlockAllocatorBase):
     """Manages free physical token blocks for a device.
@@ -210,6 +222,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
         raise NotImplementedError(
             "Invalid codepath for uncached block allocator.")
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return -1
+
 
 class BlockSpaceManagerV1(BlockSpaceManager):
     """Manages the mapping between logical and physical token blocks."""
@@ -706,3 +721,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         if self.enable_caching:
             for seq in seq_group.get_seqs():
                 self.compute_full_blocks_in_seq(seq)
+
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        if device == Device.GPU:
+            return self.gpu_allocator.get_prefix_cache_hit_rate()
+        if device == Device.CPU:
+            return self.cpu_allocator.get_prefix_cache_hit_rate()
+        raise ValueError(f"Invalid device: {device}")

+ 3 - 0
aphrodite/processing/block_manager_v2.py

@@ -439,6 +439,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
     def get_num_free_cpu_blocks(self) -> int:
         return self.block_allocator.get_num_free_blocks(Device.CPU)
 
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        return self.block_allocator.get_prefix_cache_hit_rate(device)
+
     def _can_swap(self,
                   seq_group: SequenceGroup,
                   device: Device,

+ 9 - 7
aphrodite/processing/evictor_v2.py

@@ -85,18 +85,21 @@ class LRUEvictor(Evictor):
         if len(self.free_table) == 0:
             raise ValueError("No usable cache memory left")
 
-        evicted_block = next(iter(self.free_table.values()))
-        evicted_block_id = next(iter(self.free_table.keys()))
+        evicted_block, evicted_block_id = None, None
         # The blocks with the lowest timestamps should be placed consecutively
         # at the start of OrderedDict. Loop through all these blocks to
         # find the one with maximum number of hashed tokens.
         for _id, block in self.free_table.items():
+            if evicted_block is None:
+                evicted_block, evicted_block_id = block, _id
+                continue
             if evicted_block.last_accessed < block.last_accessed:
                 break
-            if (evicted_block.last_accessed == block.last_accessed and
-                    evicted_block.num_hashed_tokens < block.num_hashed_tokens):
-                evicted_block = block
-                evicted_block_id = _id
+            if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
+                evicted_block, evicted_block_id = block, _id
+
+        assert evicted_block is not None
+        assert evicted_block_id is not None
 
         self.free_table.pop(evicted_block_id)
 
@@ -110,7 +113,6 @@ class LRUEvictor(Evictor):
 
     def update(self, block_id: int, last_accessed: float):
         self.free_table[block_id].last_accessed = last_accessed
-        self.free_table.move_to_end(block_id)
 
     def remove(self, block_id: int):
         if block_id not in self.free_table:

+ 6 - 0
aphrodite/processing/interfaces.py

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
 from typing import Tuple
 
 from aphrodite.common.sequence import Sequence, SequenceGroup
+from aphrodite.common.utils import Device
 
 
 class AllocStatus(enum.Enum):
@@ -118,3 +119,8 @@ class BlockSpaceManager(ABC):
     @abstractmethod
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         pass
+
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass

+ 4 - 0
aphrodite/processing/placeholder_block_space_manager.py

@@ -1,6 +1,7 @@
 from typing import List, Tuple
 
 from aphrodite.common.sequence import Sequence, SequenceGroup
+from aphrodite.common.utils import Device
 from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
 
 
@@ -81,3 +82,6 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         pass
+
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        return -1

+ 4 - 1
aphrodite/processing/scheduler.py

@@ -13,7 +13,7 @@ from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
                                        SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta,
                                        SequenceStatus)
-from aphrodite.common.utils import PyObjectCache
+from aphrodite.common.utils import Device, PyObjectCache
 from aphrodite.lora.request import LoRARequest
 from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
@@ -457,6 +457,9 @@ class Scheduler:
         return len(self.waiting) != 0 or len(self.running) != 0 or len(
             self.swapped) != 0
 
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        return self.block_manager.get_prefix_cache_hit_rate(device)
+
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
 

+ 23 - 0
tests/core/block/test_prefix_caching_block.py

@@ -682,6 +682,29 @@ class TestPrefixCachingBlockAllocator:
 
         assert new_block[0].block_id == last_block_id
 
+    # Test case for cache mertics
+    @staticmethod
+    def test_metric():
+        block_size = 16
+        allocator = PrefixCachingBlockAllocator(num_blocks=4,
+                                                block_size=block_size)
+        # Test when no query (0/0)
+        assert allocator.get_prefix_cache_hit_rate() == 0.0
+        token_ids = list(range(block_size))
+        allocator.allocate_immutable_block(prev_block=None,
+                                           token_ids=token_ids)
+        # Test 0/1 hit rate
+        assert allocator.get_prefix_cache_hit_rate() == 0.0
+        allocator.allocate_immutable_block(prev_block=None,
+                                           token_ids=token_ids)
+        # Test 1/2 hit rate
+        assert allocator.get_prefix_cache_hit_rate() == 0.5
+        # Test more than one block
+        for _ in range(2, 1005):
+            allocator.allocate_immutable_block(prev_block=None,
+                                               token_ids=token_ids)
+        assert allocator.get_prefix_cache_hit_rate() > 0.99
+
     @staticmethod
     def create_immutable_chain(
         block_size: int,

+ 7 - 0
tests/prefix_caching/test_prefix_caching.py

@@ -34,6 +34,9 @@ def test_block_allocator(
     assert (first_block == second_block)
     assert (second_block.ref_count == 2)
 
+    # Check metric: 1 hit of 2 queries
+    assert block_allocator.get_prefix_cache_hit_rate() == 0.5
+
     # Free the first_block and confirm that the ref_count is correctly
     # decremented on the second block
     block_allocator.free(first_block)
@@ -48,6 +51,10 @@ def test_block_allocator(
     assert (first_block == second_block)
     assert (first_block.block_hash == block_hash)
 
+    # Allocate one more time to get 3/4 hit rate for easy checking
+    block_allocator.allocate(block_hash, 0)
+    assert block_allocator.get_prefix_cache_hit_rate() == 0.75
+
 
 @pytest.mark.parametrize("num_blocks", [16])
 def test_eviction(num_blocks: int, ):