|
@@ -12,6 +12,7 @@ from loguru import logger
|
|
|
from aphrodite.common.block import BlockTable, PhysicalTokenBlock
|
|
|
from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
|
|
|
from aphrodite.common.utils import Device
|
|
|
+from aphrodite.processing.block.common import CacheMetricData
|
|
|
from aphrodite.processing.block.utils import (
|
|
|
check_no_caching_or_swa_for_blockmgr_encdec)
|
|
|
from aphrodite.processing.evictor_v1 import (EvictionPolicy, Evictor,
|
|
@@ -62,6 +63,12 @@ class BlockAllocatorBase(ABC):
|
|
|
pass
|
|
|
|
|
|
|
|
|
+ @abstractmethod
|
|
|
+ def get_prefix_cache_hit_rate(self) -> float:
|
|
|
+ """Prefix cache hit rate. -1 means not supported or disabled."""
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
class CachedBlockAllocator(BlockAllocatorBase):
|
|
|
"""Manages free physical token blocks for a device.
|
|
|
|
|
@@ -86,6 +93,8 @@ class CachedBlockAllocator(BlockAllocatorBase):
|
|
|
|
|
|
self.default_hash_ctr = count()
|
|
|
|
|
|
+ self.cache_metric_data = CacheMetricData()
|
|
|
+
|
|
|
def allocate_block(self, block_hash: int,
|
|
|
num_hashed_tokens: int) -> PhysicalTokenBlock:
|
|
|
if self.current_num_blocks == self.num_blocks:
|
|
@@ -111,10 +120,10 @@ class CachedBlockAllocator(BlockAllocatorBase):
|
|
|
block = self.evictor.remove(block_hash)
|
|
|
assert block.ref_count == 0
|
|
|
self.cached_blocks[block_hash] = block
|
|
|
- block.ref_count += 1
|
|
|
- assert block.block_hash == block_hash
|
|
|
- return block
|
|
|
- if block_hash not in self.cached_blocks:
|
|
|
+ if block_hash in self.cached_blocks:
|
|
|
+ self.cache_metric_data.query(hit=True)
|
|
|
+ else:
|
|
|
+ self.cache_metric_data.query(hit=False)
|
|
|
self.cached_blocks[block_hash] = self.allocate_block(
|
|
|
block_hash, num_hashed_tokens)
|
|
|
block = self.cached_blocks[block_hash]
|
|
@@ -151,6 +160,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
|
|
|
del self.cached_blocks[old_hash]
|
|
|
self.cached_blocks[block_hash] = block
|
|
|
|
|
|
+ def get_prefix_cache_hit_rate(self) -> float:
|
|
|
+ return self.cache_metric_data.get_hit_rate()
|
|
|
+
|
|
|
|
|
|
class UncachedBlockAllocator(BlockAllocatorBase):
|
|
|
"""Manages free physical token blocks for a device.
|
|
@@ -210,6 +222,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
|
|
|
raise NotImplementedError(
|
|
|
"Invalid codepath for uncached block allocator.")
|
|
|
|
|
|
+ def get_prefix_cache_hit_rate(self) -> float:
|
|
|
+ return -1
|
|
|
+
|
|
|
|
|
|
class BlockSpaceManagerV1(BlockSpaceManager):
|
|
|
"""Manages the mapping between logical and physical token blocks."""
|
|
@@ -706,3 +721,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
|
|
|
if self.enable_caching:
|
|
|
for seq in seq_group.get_seqs():
|
|
|
self.compute_full_blocks_in_seq(seq)
|
|
|
+
|
|
|
+ def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
|
|
+ if device == Device.GPU:
|
|
|
+ return self.gpu_allocator.get_prefix_cache_hit_rate()
|
|
|
+ if device == Device.CPU:
|
|
|
+ return self.cpu_allocator.get_prefix_cache_hit_rate()
|
|
|
+ raise ValueError(f"Invalid device: {device}")
|