|
@@ -8,10 +8,17 @@ from aphrodite.processing.block.common import (CopyOnWriteTracker,
|
|
|
from aphrodite.processing.block.interfaces import Block, BlockAllocator
|
|
|
from aphrodite.processing.block.naive_block import (NaiveBlock,
|
|
|
NaiveBlockAllocator)
|
|
|
+from aphrodite.processing.evictor_v2 import (EvictionPolicy, Evictor,
|
|
|
+ make_evictor)
|
|
|
|
|
|
PrefixHash = int
|
|
|
BlockId = int
|
|
|
|
|
|
+# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
|
|
|
+# so that if we find one block still holds _DEFAULT_LAST_ACCESSED_TIME,
|
|
|
+# then we know this block hasn't been accessed yet.
|
|
|
+_DEFAULT_LAST_ACCESSED_TIME = -1
|
|
|
+
|
|
|
|
|
|
class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
"""A block allocator that implements prefix caching.
|
|
@@ -28,22 +35,19 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
from 0 to num_blocks - 1.
|
|
|
"""
|
|
|
|
|
|
- # TODO last access time / evictor integration
|
|
|
-
|
|
|
def __init__(
|
|
|
self,
|
|
|
num_blocks: int,
|
|
|
block_size: int,
|
|
|
block_ids: Optional[Iterable[int]] = None,
|
|
|
+ eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU,
|
|
|
):
|
|
|
# A mapping of prefix hash to block index. All blocks which have a
|
|
|
# prefix hash will be in this dict, even if they have refcount 0.
|
|
|
self._cached_blocks: Dict[PrefixHash, BlockId] = {}
|
|
|
|
|
|
- # A mapping of prefix hash to block index. All blocks which have a
|
|
|
- # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
|
|
|
- # of self._cached_blocks.
|
|
|
- self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}
|
|
|
+ # A mapping of blockId to Block to track those cached blocks
|
|
|
+ self._blocks: Dict[BlockId, Block] = {}
|
|
|
|
|
|
# An allocator for blocks that do not have prefix hashes.
|
|
|
self._hashless_allocator = NaiveBlockAllocator(
|
|
@@ -55,6 +59,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
|
|
|
self._block_size = block_size
|
|
|
|
|
|
+ # Evictor used to decide how we want to handle those computed blocks
|
|
|
+ # if we find memory pressure is high.
|
|
|
+ self.evictor: Evictor = make_evictor(eviction_policy)
|
|
|
+
|
|
|
# We share the refcounter between allocators. This allows us to promote
|
|
|
# blocks originally allocated in the hashless allocator to immutable
|
|
|
# blocks.
|
|
@@ -73,6 +81,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
block_size: int,
|
|
|
allocator: BlockAllocator,
|
|
|
block_id: Optional[int] = None,
|
|
|
+ computed: Optional[bool] = False,
|
|
|
) -> Block:
|
|
|
# Bind block to self.
|
|
|
allocator = self
|
|
@@ -83,6 +92,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
block_size=block_size,
|
|
|
block_id=block_id,
|
|
|
prefix_caching_allocator=allocator,
|
|
|
+ computed=computed,
|
|
|
)
|
|
|
|
|
|
def allocate_immutable(self, prev_block: Optional[Block],
|
|
@@ -110,14 +120,12 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
cached_block_id = self._cached_blocks.get(block.content_hash, None)
|
|
|
if cached_block_id is not None:
|
|
|
block.block_id = cached_block_id
|
|
|
- self._incr_refcount_cached_block(block.content_hash,
|
|
|
- block.block_id)
|
|
|
+ self._incr_refcount_cached_block(block, block.block_id)
|
|
|
return block
|
|
|
|
|
|
block = self.allocate_mutable(prev_block)
|
|
|
block.append_token_ids(token_ids)
|
|
|
assert block.content_hash is not None
|
|
|
- # TODO computed bit
|
|
|
|
|
|
return block
|
|
|
|
|
@@ -134,41 +142,67 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
assert_prefix_caching_block_or_none(prev_block)
|
|
|
|
|
|
try:
|
|
|
- return self._hashless_allocator.allocate_mutable(
|
|
|
+ block = self._hashless_allocator.allocate_mutable(
|
|
|
prev_block=prev_block)
|
|
|
+
|
|
|
+ assert block.block_id not in self._blocks
|
|
|
+ self._blocks[block.block_id] = block
|
|
|
+ return block
|
|
|
except BlockAllocator.NoFreeBlocksError:
|
|
|
# We must check the unused cached blocks before raising OOM.
|
|
|
pass
|
|
|
|
|
|
- if self._unused_cached_blocks:
|
|
|
- # TODO policy for selecting block to remove
|
|
|
- content_hash_to_evict = next(iter(self._unused_cached_blocks))
|
|
|
+ # If the evictor has blocks available for eviction, evict a block
|
|
|
+ # and return it.
|
|
|
+ if self.evictor.num_blocks > 0:
|
|
|
+ block_id, content_hash_to_evict = self.evictor.evict()
|
|
|
+
|
|
|
+ # Several blocks may share the same content hash: a later
|
|
|
+ # block that arrives via the mutable-to-immutable path gets
|
|
|
+ # its own physical block, which is then added into the
|
|
|
+ # evictor.
|
|
|
+ # In that case we must not pop _cached_blocks, because the
|
|
|
+ # same content is still used by others, which means we need
|
|
|
+ # to check the refcount before deciding to pop the entry.
|
|
|
+
|
|
|
+ _block_id = self._cached_blocks[content_hash_to_evict]
|
|
|
+ refcount = self._refcounter.get(_block_id)
|
|
|
+ if refcount == 1:
|
|
|
+ self._cached_blocks.pop(content_hash_to_evict)
|
|
|
+ assert _block_id == block_id
|
|
|
|
|
|
- # Clear content hash mapping; the block will be overwritten.
|
|
|
- del self._cached_blocks[content_hash_to_evict]
|
|
|
+ self._refcounter.incr(block_id)
|
|
|
|
|
|
- block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
|
|
|
- refcount = self._refcounter.incr(block_id)
|
|
|
- assert refcount == 1
|
|
|
+ # the block coming from the evictor already contains a computed result
|
|
|
block = self._create_block(
|
|
|
prev_block=prev_block,
|
|
|
token_ids=[],
|
|
|
block_size=self._block_size,
|
|
|
allocator=self,
|
|
|
block_id=block_id,
|
|
|
+ computed=True,
|
|
|
)
|
|
|
assert block.content_hash is None
|
|
|
+
|
|
|
+ assert block.block_id not in self._blocks
|
|
|
+ self._blocks[block.block_id] = block
|
|
|
return block
|
|
|
|
|
|
# No block available in hashless allocator, nor in unused cache blocks.
|
|
|
raise BlockAllocator.NoFreeBlocksError()
|
|
|
|
|
|
- def _incr_refcount_cached_block(self, content_hash: int,
|
|
|
+ def _incr_refcount_cached_block(self, block: Block,
|
|
|
block_id: BlockId) -> None:
|
|
|
+ # since block is already computed, mark it
|
|
|
+ block.computed = True
|
|
|
+
|
|
|
refcount = self._refcounter.incr(block_id)
|
|
|
if refcount == 1:
|
|
|
- assert content_hash in self._unused_cached_blocks
|
|
|
- del self._unused_cached_blocks[content_hash]
|
|
|
+ # once a block is referenced, it must not stay in the evictor;
|
|
|
+ # move it into _blocks for tracking
|
|
|
+ if block_id in self.evictor:
|
|
|
+ self.evictor.remove(block_id)
|
|
|
+ self._blocks[block_id] = block
|
|
|
|
|
|
def free(self, block: Block) -> None:
|
|
|
"""Decrement the refcount of the block. If the decremented refcount is
|
|
@@ -188,15 +222,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
assert isinstance(block, PrefixCachingBlock)
|
|
|
|
|
|
if block.content_hash is None:
|
|
|
+ refcount = self._refcounter.get(block_id)
|
|
|
+ # A forked block may hold more than one reference, so we must
|
|
|
+ # not drop it from tracking while its refcount is above 1
|
|
|
+ if refcount <= 1:
|
|
|
+ del self._blocks[block.block_id]
|
|
|
return self._hashless_allocator.free(block)
|
|
|
|
|
|
refcount = self._refcounter.decr(block_id)
|
|
|
|
|
|
- # If no longer used, add the block to the unused cached blocks.
|
|
|
+ # If no longer used, add the block to the evictor.
|
|
|
if refcount == 0:
|
|
|
- assert block.content_hash not in self._unused_cached_blocks
|
|
|
assert block.content_hash in self._cached_blocks
|
|
|
- self._unused_cached_blocks[block.content_hash] = block_id
|
|
|
+ del self._blocks[block.block_id]
|
|
|
+ self.evictor.add(block.block_id, block.content_hash,
|
|
|
+ block.num_tokens_total, block.last_accessed)
|
|
|
|
|
|
def fork(self, last_block: Block) -> List[Block]:
|
|
|
"""Creates a new sequence of blocks that shares the same underlying
|
|
@@ -231,9 +271,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
|
|
|
def get_num_free_blocks(self) -> int:
|
|
|
# The number of free blocks is the number of hashless free blocks
|
|
|
- # plus the number of hashful blocks that are unused.
|
|
|
- return self._hashless_allocator.get_num_free_blocks() + len(
|
|
|
- self._unused_cached_blocks)
|
|
|
+ # plus the number of blocks evictor could free from its list.
|
|
|
+ return self._hashless_allocator.get_num_free_blocks(
|
|
|
+ ) + self.evictor.num_blocks
|
|
|
|
|
|
@property
|
|
|
def all_block_ids(self) -> frozenset[int]:
|
|
@@ -267,7 +307,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
else:
|
|
|
self._free_block_id_for_block(block.block_id, block)
|
|
|
self._incr_refcount_cached_block(
|
|
|
- block.content_hash, self._cached_blocks[block.content_hash])
|
|
|
+ block, self._cached_blocks[block.content_hash])
|
|
|
|
|
|
return self._cached_blocks[block.content_hash]
|
|
|
|
|
@@ -294,29 +334,58 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
"""
|
|
|
return self._cow_tracker.clear_cows()
|
|
|
|
|
|
- def mark_blocks_as_computed(self) -> None:
|
|
|
+ def mark_blocks_as_accessed(self, block_ids: List[int],
|
|
|
+ now: float) -> None:
|
|
|
+ """Mark blocks as accessed, used in prefix caching.
|
|
|
+ If the block is in the evictor, we need to update the corresponding
|
|
|
+ info in evictor's metadata.
|
|
|
+ """
|
|
|
+
|
|
|
+ for block_id in block_ids:
|
|
|
+ if block_id in self._blocks:
|
|
|
+ self._blocks[block_id].last_accessed = now
|
|
|
+ elif block_id in self.evictor:
|
|
|
+ self.evictor.update(block_id, now)
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ "Mark block as accessed which is not belonged to GPU")
|
|
|
+
|
|
|
+ def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
|
|
"""Mark blocks as computed, used in prefix caching."""
|
|
|
- # TODO Track computed blocks.
|
|
|
- pass
|
|
|
+
|
|
|
+ for block_id in block_ids:
|
|
|
+ if block_id in self._blocks:
|
|
|
+ # only full blocks are valid for prefix caching
|
|
|
+ if self._blocks[block_id].is_full:
|
|
|
+ self._blocks[block_id].computed = True
|
|
|
+ elif block_id not in self.evictor:
|
|
|
+ raise ValueError(f"Mark {block_id=} as computed which "
|
|
|
+ "is not belonged to GPU")
|
|
|
+
|
|
|
+ def block_is_computed(self, block_id: int) -> bool:
|
|
|
+ if block_id in self._blocks:
|
|
|
+ return self._blocks[block_id].computed
|
|
|
+ else:
|
|
|
+ return block_id in self.evictor
|
|
|
|
|
|
def get_common_computed_block_ids(
|
|
|
self, seq_block_ids: List[List[int]]) -> List[int]:
|
|
|
"""Return the block ids that are common for a given sequence group.
|
|
|
-
|
|
|
- Used in prefill (can skip prefill of some blocks).
|
|
|
+ Only blocks that are immutable and already marked as
|
|
|
+ computed are taken into consideration.
|
|
|
"""
|
|
|
|
|
|
- # TODO: Track computed blocks.
|
|
|
- computed = lambda block_id: False
|
|
|
-
|
|
|
# NOTE We exclude the last block to avoid the case where the entire
|
|
|
# prompt is cached. This would cause erroneous behavior in model
|
|
|
# runner.
|
|
|
+
|
|
|
ids_list = [
|
|
|
- takewhile(lambda block_id: computed(block_id), seq[:-1])
|
|
|
- for seq in seq_block_ids
|
|
|
+ list(
|
|
|
+ takewhile(lambda block_id: self.block_is_computed(block_id),
|
|
|
+ seq[:-1])) for seq in seq_block_ids
|
|
|
]
|
|
|
- return commonprefix([ids for ids in ids_list if ids != []])
|
|
|
+ res = commonprefix([ids for ids in ids_list if ids != []])
|
|
|
+ return res
|
|
|
|
|
|
|
|
|
class PrefixCachingBlock(Block):
|
|
@@ -339,19 +408,21 @@ class PrefixCachingBlock(Block):
|
|
|
of this block. Defaults to None.
|
|
|
"""
|
|
|
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- prev_block: Optional["PrefixCachingBlock"],
|
|
|
- token_ids: List[int],
|
|
|
- block_size: int,
|
|
|
- prefix_caching_allocator: PrefixCachingBlockAllocator,
|
|
|
- block_id: Optional[int] = None,
|
|
|
- ):
|
|
|
+ def __init__(self,
|
|
|
+ prev_block: Optional["PrefixCachingBlock"],
|
|
|
+ token_ids: List[int],
|
|
|
+ block_size: int,
|
|
|
+ prefix_caching_allocator: PrefixCachingBlockAllocator,
|
|
|
+ block_id: Optional[int] = None,
|
|
|
+ computed: Optional[bool] = False):
|
|
|
assert_prefix_caching_block_or_none(prev_block)
|
|
|
|
|
|
self._prev_block = prev_block
|
|
|
self._cached_content_hash: Optional[int] = None
|
|
|
+ self._cached_num_tokens_total: Optional[int] = None
|
|
|
self._prefix_caching_allocator = prefix_caching_allocator
|
|
|
+ self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME
|
|
|
+ self.computed = computed
|
|
|
|
|
|
self._block = NaiveBlock(
|
|
|
prev_block=prev_block,
|
|
@@ -399,6 +470,26 @@ class PrefixCachingBlock(Block):
|
|
|
def num_empty_slots(self) -> int:
|
|
|
return self._block.num_empty_slots
|
|
|
|
|
|
+ @property
|
|
|
+ def num_tokens_total(self) -> int:
|
|
|
+ """return the total tokens so far.
|
|
|
+ Here we iterate the block chain till to the first block, while
|
|
|
+ cache the result in local to prevent repeated computations.
|
|
|
+ """
|
|
|
+ if self._cached_num_tokens_total is not None:
|
|
|
+ return self._cached_num_tokens_total
|
|
|
+
|
|
|
+ _block = self
|
|
|
+ self._cached_num_tokens_total = 0
|
|
|
+
|
|
|
+ # TODO: current implementation here is O(N^2), we expect future
|
|
|
+ # ones to be O(1)
|
|
|
+ while _block is not None:
|
|
|
+ self._cached_num_tokens_total += len(_block.token_ids)
|
|
|
+ _block = _block.prev_block
|
|
|
+
|
|
|
+ return self._cached_num_tokens_total
|
|
|
+
|
|
|
@property
|
|
|
def block_size(self) -> int:
|
|
|
return self._block.block_size
|