Browse Source

feat: add metrics for prefix cache hit rate (#829)

AlpinDale 3 months ago
parent
commit
3d83e64f8e

+ 12 - 1
aphrodite/engine/aphrodite_engine.py

@@ -23,7 +23,7 @@ from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
                                        ExecuteModelRequest, PoolerOutput,
                                        ExecuteModelRequest, PoolerOutput,
                                        SamplerOutput, Sequence, SequenceGroup,
                                        SamplerOutput, Sequence, SequenceGroup,
                                        SequenceGroupMetadata, SequenceStatus)
                                        SequenceGroupMetadata, SequenceStatus)
-from aphrodite.common.utils import Counter
+from aphrodite.common.utils import Counter, Device
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics_types import StatLoggerBase, Stats
 from aphrodite.engine.metrics_types import StatLoggerBase, Stats
 from aphrodite.engine.output_processor.interfaces import (
 from aphrodite.engine.output_processor.interfaces import (
@@ -1290,6 +1290,13 @@ class AphroditeEngine:
             else:
             else:
                 cpu_cache_usage_sys = 0.0
                 cpu_cache_usage_sys = 0.0
 
 
+        # Prefix Cache Hit Rate. Note that we always use
+        # the cache hit rate of the first virtual engine.
+        cpu_prefix_cache_hit_rate = self.scheduler[
+            0].get_prefix_cache_hit_rate(Device.CPU)
+        gpu_prefix_cache_hit_rate = self.scheduler[
+            0].get_prefix_cache_hit_rate(Device.GPU)
+
         # Iteration stats
         # Iteration stats
         num_prompt_tokens_iter = 0
         num_prompt_tokens_iter = 0
         num_generation_tokens_iter = 0
         num_generation_tokens_iter = 0
@@ -1400,6 +1407,10 @@ class AphroditeEngine:
             gpu_cache_usage_sys=gpu_cache_usage_sys,
             gpu_cache_usage_sys=gpu_cache_usage_sys,
             cpu_cache_usage_sys=cpu_cache_usage_sys,
             cpu_cache_usage_sys=cpu_cache_usage_sys,
 
 
+            #   Prefix Cache Hit Rate
+            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
+            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
+
             # Iteration stats
             # Iteration stats
             num_prompt_tokens_iter=num_prompt_tokens_iter,
             num_prompt_tokens_iter=num_prompt_tokens_iter,
             num_generation_tokens_iter=num_generation_tokens_iter,
             num_generation_tokens_iter=num_generation_tokens_iter,

+ 23 - 0
aphrodite/engine/metrics.py

@@ -70,6 +70,18 @@ class Metrics:
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames,
             labelnames=labelnames,
             multiprocess_mode="sum")
             multiprocess_mode="sum")
+        
+        #   Prefix caching block hit rate
+        self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
+            name="aphrodite:cpu_prefix_cache_hit_rate",
+            documentation="CPU prefix cache block hit rate.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+        self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
+            name="aphrodite:gpu_prefix_cache_hit_rate",
+            documentation="GPU prefix cache block hit rate.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
 
 
         # Iteration stats
         # Iteration stats
         self.counter_num_preemption = self._counter_cls(
         self.counter_num_preemption = self._counter_cls(
@@ -347,6 +359,13 @@ class LoggingStatLogger(StatLoggerBase):
                 f"CPU KV cache usage: {stats.cpu_cache_usage_sys * 100:.1f}%."
                 f"CPU KV cache usage: {stats.cpu_cache_usage_sys * 100:.1f}%."
             )
             )
 
 
+            if (stats.cpu_prefix_cache_hit_rate >= 0
+                    or stats.gpu_prefix_cache_hit_rate >= 0):
+                logger.info(
+                    "Prefix cache hit rate: "
+                    f"GPU: {stats.gpu_prefix_cache_hit_rate * 100:.2f}%, "
+                    f"CPU: {stats.cpu_prefix_cache_hit_rate * 100:.2f}%")
+
             if self.spec_decode_metrics is not None:
             if self.spec_decode_metrics is not None:
                 logger.info(
                 logger.info(
                     self._format_spec_decode_metrics_str(
                     self._format_spec_decode_metrics_str(
@@ -418,6 +437,10 @@ class PrometheusStatLogger(StatLoggerBase):
                         stats.gpu_cache_usage_sys)
                         stats.gpu_cache_usage_sys)
         self._log_gauge(self.metrics.gauge_cpu_cache_usage,
         self._log_gauge(self.metrics.gauge_cpu_cache_usage,
                         stats.cpu_cache_usage_sys)
                         stats.cpu_cache_usage_sys)
+        self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
+                        stats.cpu_prefix_cache_hit_rate)
+        self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
+                        stats.gpu_prefix_cache_hit_rate)
 
 
         # Iteration level data
         # Iteration level data
         self._log_counter(self.metrics.counter_num_preemption,
         self._log_counter(self.metrics.counter_num_preemption,

+ 3 - 0
aphrodite/engine/metrics_types.py

@@ -28,6 +28,9 @@ class Stats:
     #   KV Cache Usage in %
     #   KV Cache Usage in %
     gpu_cache_usage_sys: float
     gpu_cache_usage_sys: float
     cpu_cache_usage_sys: float
     cpu_cache_usage_sys: float
+    #   Prefix caching block hit rate
+    cpu_prefix_cache_hit_rate: float
+    gpu_prefix_cache_hit_rate: float
     # Iteration stats (should have _iter suffix)
     # Iteration stats (should have _iter suffix)
     num_prompt_tokens_iter: int
     num_prompt_tokens_iter: int
     num_generation_tokens_iter: int
     num_generation_tokens_iter: int

+ 49 - 0
aphrodite/processing/block/common.py

@@ -1,4 +1,5 @@
 from collections import deque
 from collections import deque
+from dataclasses import dataclass
 from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
 from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
 
 
 from aphrodite.processing.block.interfaces import Block, BlockAllocator
 from aphrodite.processing.block.interfaces import Block, BlockAllocator
@@ -282,6 +283,54 @@ class BlockList:
         return self._block_ids
         return self._block_ids
 
 
 
 
+@dataclass
+class CacheMetricData:
+    """A utility dataclass to maintain cache metric.
+    To avoid overflow, we maintain the hit rate in block granularity, so that
+    we can maintain a single hit rate for n_completed_block x block_size,
+    and calculate the real time hit rate by the following:
+    BS = The number of queries per block.
+    nB = The number of completed blocks.
+    HR = hit rate of (nB x BS) queries.
+    Q = current number of queries (< BS).
+    H = current number of hits (< BS).
+    hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)
+    """
+    num_completed_blocks: int = 0
+    completed_block_cache_hit_rate: float = 0.0
+    num_incompleted_block_queries: int = 0
+    num_incompleted_block_hit: int = 0
+    block_size: int = 1000
+    def query(self, hit: bool):
+        self.num_incompleted_block_queries += 1
+        self.num_incompleted_block_hit += 1 if hit else 0
+        # When a block is completed, update the cache hit rate
+        # and reset the incomplete numbers.
+        if self.num_incompleted_block_queries == self.block_size:
+            hit_rate = (self.num_incompleted_block_hit /
+                        self.num_incompleted_block_queries)
+            self.completed_block_cache_hit_rate = (
+                self.completed_block_cache_hit_rate * self.num_completed_blocks
+                + hit_rate) / (self.num_completed_blocks + 1)
+            self.num_incompleted_block_queries = 0
+            self.num_incompleted_block_hit = 0
+            self.num_completed_blocks += 1
+    def get_hit_rate(self):
+        incomplete_ratio = self.num_incompleted_block_queries / self.block_size
+        total_blocks = self.num_completed_blocks + incomplete_ratio
+        if total_blocks == 0:
+            return 0.0
+        completed_block_hit, incompleted_block_hit = 0.0, 0.0
+        if self.num_completed_blocks > 0:
+            completed_block_hit = (self.completed_block_cache_hit_rate *
+                                   self.num_completed_blocks)
+        if self.num_incompleted_block_queries > 0:
+            incompleted_hit_rate = (self.num_incompleted_block_hit /
+                                    self.num_incompleted_block_queries)
+            incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio)
+        return (completed_block_hit + incompleted_block_hit) / total_blocks
+
+
 def get_all_blocks_recursively(last_block: Block) -> List[Block]:
 def get_all_blocks_recursively(last_block: Block) -> List[Block]:
     """Retrieves all the blocks in a sequence starting from the last block.
     """Retrieves all the blocks in a sequence starting from the last block.
 
 

+ 5 - 0
aphrodite/processing/block/cpu_gpu_block_allocator.py

@@ -326,6 +326,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
     def all_block_ids(self) -> FrozenSet[int]:
     def all_block_ids(self) -> FrozenSet[int]:
         return frozenset(self._block_ids_to_allocator.keys())
         return frozenset(self._block_ids_to_allocator.keys())
 
 
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        assert device in self._allocators
+        return self._allocators[device].get_prefix_cache_hit_rate()
+
     def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
     def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
         """Returns and clears the mapping of source to destination block IDs.
         """Returns and clears the mapping of source to destination block IDs.
         Will be called after every swapping operations for now, and after every
         Will be called after every swapping operations for now, and after every

+ 10 - 0
aphrodite/processing/block/interfaces.py

@@ -186,6 +186,11 @@ class BlockAllocator(ABC):
                                num_lookahead_slots: int = 0) -> int:
                                num_lookahead_slots: int = 0) -> int:
         pass
         pass
 
 
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass
+
     class NoFreeBlocksError(ValueError):
     class NoFreeBlocksError(ValueError):
         pass
         pass
 
 
@@ -278,3 +283,8 @@ class DeviceAwareBlockAllocator(ABC):
         There is at most one null block per allocator.
         There is at most one null block per allocator.
         """
         """
         pass
         pass
+
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass

+ 3 - 0
aphrodite/processing/block/naive_block.py

@@ -343,6 +343,9 @@ class NaiveBlockAllocator(BlockAllocator):
 
 
             block.block_id = block_id  # Assign block_id
             block.block_id = block_id  # Assign block_id
 
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return -1
+
 
 
 class NaiveBlock(Block):
 class NaiveBlock(Block):
     """An implementation of the Block class that does not support prefix
     """An implementation of the Block class that does not support prefix

+ 9 - 1
aphrodite/processing/block/prefix_caching_block.py

@@ -4,7 +4,8 @@ from os.path import commonprefix
 from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
 from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
 
 
 from aphrodite.common.utils import cdiv
 from aphrodite.common.utils import cdiv
-from aphrodite.processing.block.common import (CopyOnWriteTracker,
+from aphrodite.processing.block.common import (CacheMetricData,
+                                               CopyOnWriteTracker,
                                                get_all_blocks_recursively)
                                                get_all_blocks_recursively)
 from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
 from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
                                                    BlockId, Device)
                                                    BlockId, Device)
@@ -109,6 +110,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         self._cow_tracker = CopyOnWriteTracker(
         self._cow_tracker = CopyOnWriteTracker(
             refcounter=self._refcounter.as_readonly())
             refcounter=self._refcounter.as_readonly())
 
 
+        self.metric_data = CacheMetricData()
+
     # Implements Block.Factory.
     # Implements Block.Factory.
     def _create_block(
     def _create_block(
         self,
         self,
@@ -157,9 +160,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
 
 
         cached_block_id = self._cached_blocks.get(block.content_hash, None)
         cached_block_id = self._cached_blocks.get(block.content_hash, None)
         if cached_block_id is not None:
         if cached_block_id is not None:
+            self.metric_data.query(hit=True)
             block.block_id = cached_block_id
             block.block_id = cached_block_id
             self._incr_refcount_cached_block(block)
             self._incr_refcount_cached_block(block)
             return block
             return block
+        self.metric_data.query(hit=False)
         self._block_pool.free_block(block)
         self._block_pool.free_block(block)
 
 
         # No cached block => Allocate a new block
         # No cached block => Allocate a new block
@@ -406,6 +411,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
     def all_block_ids(self) -> FrozenSet[int]:
     def all_block_ids(self) -> FrozenSet[int]:
         return self._hashless_allocator.all_block_ids
         return self._hashless_allocator.all_block_ids
 
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return self.metric_data.get_hit_rate()
+
     def is_block_cached(self, block: Block) -> bool:
     def is_block_cached(self, block: Block) -> bool:
         assert block.content_hash is not None
         assert block.content_hash is not None
         if block.content_hash in self._cached_blocks:
         if block.content_hash in self._cached_blocks:

+ 26 - 4
aphrodite/processing/block_manager_v1.py

@@ -12,6 +12,7 @@ from loguru import logger
 from aphrodite.common.block import BlockTable, PhysicalTokenBlock
 from aphrodite.common.block import BlockTable, PhysicalTokenBlock
 from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
 from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
 from aphrodite.common.utils import Device
 from aphrodite.common.utils import Device
+from aphrodite.processing.block.common import CacheMetricData
 from aphrodite.processing.block.utils import (
 from aphrodite.processing.block.utils import (
     check_no_caching_or_swa_for_blockmgr_encdec)
     check_no_caching_or_swa_for_blockmgr_encdec)
 from aphrodite.processing.evictor_v1 import (EvictionPolicy, Evictor,
 from aphrodite.processing.evictor_v1 import (EvictionPolicy, Evictor,
@@ -62,6 +63,12 @@ class BlockAllocatorBase(ABC):
         pass
         pass
 
 
 
 
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass
+
+
 class CachedBlockAllocator(BlockAllocatorBase):
 class CachedBlockAllocator(BlockAllocatorBase):
     """Manages free physical token blocks for a device.
     """Manages free physical token blocks for a device.
 
 
@@ -86,6 +93,8 @@ class CachedBlockAllocator(BlockAllocatorBase):
 
 
         self.default_hash_ctr = count()
         self.default_hash_ctr = count()
 
 
+        self.cache_metric_data = CacheMetricData()
+
     def allocate_block(self, block_hash: int,
     def allocate_block(self, block_hash: int,
                        num_hashed_tokens: int) -> PhysicalTokenBlock:
                        num_hashed_tokens: int) -> PhysicalTokenBlock:
         if self.current_num_blocks == self.num_blocks:
         if self.current_num_blocks == self.num_blocks:
@@ -111,10 +120,10 @@ class CachedBlockAllocator(BlockAllocatorBase):
             block = self.evictor.remove(block_hash)
             block = self.evictor.remove(block_hash)
             assert block.ref_count == 0
             assert block.ref_count == 0
             self.cached_blocks[block_hash] = block
             self.cached_blocks[block_hash] = block
-            block.ref_count += 1
-            assert block.block_hash == block_hash
-            return block
-        if block_hash not in self.cached_blocks:
+        if block_hash in self.cached_blocks:
+            self.cache_metric_data.query(hit=True)
+        else:
+            self.cache_metric_data.query(hit=False)
             self.cached_blocks[block_hash] = self.allocate_block(
             self.cached_blocks[block_hash] = self.allocate_block(
                 block_hash, num_hashed_tokens)
                 block_hash, num_hashed_tokens)
         block = self.cached_blocks[block_hash]
         block = self.cached_blocks[block_hash]
@@ -151,6 +160,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
         del self.cached_blocks[old_hash]
         del self.cached_blocks[old_hash]
         self.cached_blocks[block_hash] = block
         self.cached_blocks[block_hash] = block
 
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return self.cache_metric_data.get_hit_rate()
+
 
 
 class UncachedBlockAllocator(BlockAllocatorBase):
 class UncachedBlockAllocator(BlockAllocatorBase):
     """Manages free physical token blocks for a device.
     """Manages free physical token blocks for a device.
@@ -210,6 +222,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
         raise NotImplementedError(
         raise NotImplementedError(
             "Invalid codepath for uncached block allocator.")
             "Invalid codepath for uncached block allocator.")
 
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return -1
+
 
 
 class BlockSpaceManagerV1(BlockSpaceManager):
 class BlockSpaceManagerV1(BlockSpaceManager):
     """Manages the mapping between logical and physical token blocks."""
     """Manages the mapping between logical and physical token blocks."""
@@ -706,3 +721,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         if self.enable_caching:
         if self.enable_caching:
             for seq in seq_group.get_seqs():
             for seq in seq_group.get_seqs():
                 self.compute_full_blocks_in_seq(seq)
                 self.compute_full_blocks_in_seq(seq)
+
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        if device == Device.GPU:
+            return self.gpu_allocator.get_prefix_cache_hit_rate()
+        if device == Device.CPU:
+            return self.cpu_allocator.get_prefix_cache_hit_rate()
+        raise ValueError(f"Invalid device: {device}")

+ 3 - 0
aphrodite/processing/block_manager_v2.py

@@ -439,6 +439,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
     def get_num_free_cpu_blocks(self) -> int:
     def get_num_free_cpu_blocks(self) -> int:
         return self.block_allocator.get_num_free_blocks(Device.CPU)
         return self.block_allocator.get_num_free_blocks(Device.CPU)
 
 
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        return self.block_allocator.get_prefix_cache_hit_rate(device)
+
     def _can_swap(self,
     def _can_swap(self,
                   seq_group: SequenceGroup,
                   seq_group: SequenceGroup,
                   device: Device,
                   device: Device,

+ 9 - 7
aphrodite/processing/evictor_v2.py

@@ -85,18 +85,21 @@ class LRUEvictor(Evictor):
         if len(self.free_table) == 0:
         if len(self.free_table) == 0:
             raise ValueError("No usable cache memory left")
             raise ValueError("No usable cache memory left")
 
 
-        evicted_block = next(iter(self.free_table.values()))
-        evicted_block_id = next(iter(self.free_table.keys()))
+        evicted_block, evicted_block_id = None, None
         # The blocks with the lowest timestamps should be placed consecutively
         # The blocks with the lowest timestamps should be placed consecutively
         # at the start of OrderedDict. Loop through all these blocks to
         # at the start of OrderedDict. Loop through all these blocks to
         # find the one with maximum number of hashed tokens.
         # find the one with maximum number of hashed tokens.
         for _id, block in self.free_table.items():
         for _id, block in self.free_table.items():
+            if evicted_block is None:
+                evicted_block, evicted_block_id = block, _id
+                continue
             if evicted_block.last_accessed < block.last_accessed:
             if evicted_block.last_accessed < block.last_accessed:
                 break
                 break
-            if (evicted_block.last_accessed == block.last_accessed and
-                    evicted_block.num_hashed_tokens < block.num_hashed_tokens):
-                evicted_block = block
-                evicted_block_id = _id
+            if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
+                evicted_block, evicted_block_id = block, _id
+
+        assert evicted_block is not None
+        assert evicted_block_id is not None
 
 
         self.free_table.pop(evicted_block_id)
         self.free_table.pop(evicted_block_id)
 
 
@@ -110,7 +113,6 @@ class LRUEvictor(Evictor):
 
 
     def update(self, block_id: int, last_accessed: float):
     def update(self, block_id: int, last_accessed: float):
         self.free_table[block_id].last_accessed = last_accessed
         self.free_table[block_id].last_accessed = last_accessed
-        self.free_table.move_to_end(block_id)
 
 
     def remove(self, block_id: int):
     def remove(self, block_id: int):
         if block_id not in self.free_table:
         if block_id not in self.free_table:

+ 6 - 0
aphrodite/processing/interfaces.py

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
 from typing import Tuple
 from typing import Tuple
 
 
 from aphrodite.common.sequence import Sequence, SequenceGroup
 from aphrodite.common.sequence import Sequence, SequenceGroup
+from aphrodite.common.utils import Device
 
 
 
 
 class AllocStatus(enum.Enum):
 class AllocStatus(enum.Enum):
@@ -118,3 +119,8 @@ class BlockSpaceManager(ABC):
     @abstractmethod
     @abstractmethod
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         pass
         pass
+
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass

+ 4 - 0
aphrodite/processing/placeholder_block_space_manager.py

@@ -1,6 +1,7 @@
 from typing import List, Tuple
 from typing import List, Tuple
 
 
 from aphrodite.common.sequence import Sequence, SequenceGroup
 from aphrodite.common.sequence import Sequence, SequenceGroup
+from aphrodite.common.utils import Device
 from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
 from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
 
 
 
 
@@ -81,3 +82,6 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
 
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         pass
         pass
+
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        return -1

+ 4 - 1
aphrodite/processing/scheduler.py

@@ -13,7 +13,7 @@ from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
                                        SequenceGroupMetadata,
                                        SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta,
                                        SequenceGroupMetadataDelta,
                                        SequenceStatus)
                                        SequenceStatus)
-from aphrodite.common.utils import PyObjectCache
+from aphrodite.common.utils import Device, PyObjectCache
 from aphrodite.lora.request import LoRARequest
 from aphrodite.lora.request import LoRARequest
 from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
 from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
@@ -457,6 +457,9 @@ class Scheduler:
         return len(self.waiting) != 0 or len(self.running) != 0 or len(
         return len(self.waiting) != 0 or len(self.running) != 0 or len(
             self.swapped) != 0
             self.swapped) != 0
 
 
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        return self.block_manager.get_prefix_cache_hit_rate(device)
+
     def get_num_unfinished_seq_groups(self) -> int:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
         return len(self.waiting) + len(self.running) + len(self.swapped)
 
 

+ 23 - 0
tests/core/block/test_prefix_caching_block.py

@@ -682,6 +682,29 @@ class TestPrefixCachingBlockAllocator:
 
 
         assert new_block[0].block_id == last_block_id
         assert new_block[0].block_id == last_block_id
 
 
+    # Test case for cache metrics
+    @staticmethod
+    def test_metric():
+        block_size = 16
+        allocator = PrefixCachingBlockAllocator(num_blocks=4,
+                                                block_size=block_size)
+        # Test when no query (0/0)
+        assert allocator.get_prefix_cache_hit_rate() == 0.0
+        token_ids = list(range(block_size))
+        allocator.allocate_immutable_block(prev_block=None,
+                                           token_ids=token_ids)
+        # Test 0/1 hit rate
+        assert allocator.get_prefix_cache_hit_rate() == 0.0
+        allocator.allocate_immutable_block(prev_block=None,
+                                           token_ids=token_ids)
+        # Test 1/2 hit rate
+        assert allocator.get_prefix_cache_hit_rate() == 0.5
+        # Test more than one block
+        for _ in range(2, 1005):
+            allocator.allocate_immutable_block(prev_block=None,
+                                               token_ids=token_ids)
+        assert allocator.get_prefix_cache_hit_rate() > 0.99
+
     @staticmethod
     @staticmethod
     def create_immutable_chain(
     def create_immutable_chain(
         block_size: int,
         block_size: int,

+ 7 - 0
tests/prefix_caching/test_prefix_caching.py

@@ -34,6 +34,9 @@ def test_block_allocator(
     assert (first_block == second_block)
     assert (first_block == second_block)
     assert (second_block.ref_count == 2)
     assert (second_block.ref_count == 2)
 
 
+    # Check metric: 1 hit of 2 queries
+    assert block_allocator.get_prefix_cache_hit_rate() == 0.5
+
     # Free the first_block and confirm that the ref_count is correctly
     # Free the first_block and confirm that the ref_count is correctly
     # decremented on the second block
     # decremented on the second block
     block_allocator.free(first_block)
     block_allocator.free(first_block)
@@ -48,6 +51,10 @@ def test_block_allocator(
     assert (first_block == second_block)
     assert (first_block == second_block)
     assert (first_block.block_hash == block_hash)
     assert (first_block.block_hash == block_hash)
 
 
+    # Allocate one more time to get 3/4 hit rate for easy checking
+    block_allocator.allocate(block_hash, 0)
+    assert block_allocator.get_prefix_cache_hit_rate() == 0.75
+
 
 
 @pytest.mark.parametrize("num_blocks", [16])
 @pytest.mark.parametrize("num_blocks", [16])
 def test_eviction(num_blocks: int, ):
 def test_eviction(num_blocks: int, ):