|
@@ -1,18 +1,18 @@
|
|
|
"""Token blocks."""
|
|
|
from itertools import takewhile
|
|
|
from os.path import commonprefix
|
|
|
-from typing import Dict, Iterable, List, Optional
|
|
|
+from typing import Dict, FrozenSet, Iterable, List, Optional
|
|
|
|
|
|
from aphrodite.processing.block.common import (CopyOnWriteTracker,
|
|
|
get_all_blocks_recursively)
|
|
|
-from aphrodite.processing.block.interfaces import Block, BlockAllocator
|
|
|
+from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
|
|
|
+ BlockId, Device)
|
|
|
from aphrodite.processing.block.naive_block import (NaiveBlock,
|
|
|
NaiveBlockAllocator)
|
|
|
from aphrodite.processing.evictor_v2 import (EvictionPolicy, Evictor,
|
|
|
make_evictor)
|
|
|
|
|
|
PrefixHash = int
|
|
|
-BlockId = int
|
|
|
|
|
|
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
|
|
|
# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME,
|
|
@@ -40,7 +40,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
num_blocks: int,
|
|
|
block_size: int,
|
|
|
block_ids: Optional[Iterable[int]] = None,
|
|
|
- eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU,
|
|
|
+ eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
|
|
|
):
|
|
|
# A mapping of prefix hash to block index. All blocks which have a
|
|
|
# prefix hash will be in this dict, even if they have refcount 0.
|
|
@@ -51,7 +51,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
|
|
|
# An allocator for blocks that do not have prefix hashes.
|
|
|
self._hashless_allocator = NaiveBlockAllocator(
|
|
|
- create_block=self._create_block,
|
|
|
+ create_block=self._create_block, # type: ignore
|
|
|
num_blocks=num_blocks,
|
|
|
block_size=block_size,
|
|
|
block_ids=block_ids,
|
|
@@ -81,7 +81,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
block_size: int,
|
|
|
allocator: BlockAllocator,
|
|
|
block_id: Optional[int] = None,
|
|
|
- computed: Optional[bool] = False,
|
|
|
+ computed: bool = False,
|
|
|
) -> Block:
|
|
|
# Bind block to self.
|
|
|
allocator = self
|
|
@@ -95,8 +95,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
computed=computed,
|
|
|
)
|
|
|
|
|
|
- def allocate_immutable(self, prev_block: Optional[Block],
|
|
|
- token_ids: List[int]) -> Block:
|
|
|
+ def allocate_immutable(self,
|
|
|
+ prev_block: Optional[Block],
|
|
|
+ token_ids: List[int],
|
|
|
+ device: Optional[Device] = None) -> Block:
|
|
|
"""Allocates an immutable block with the given token IDs, reusing cached
|
|
|
blocks if possible.
|
|
|
|
|
@@ -107,6 +109,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
Returns:
|
|
|
Block: The allocated immutable block.
|
|
|
"""
|
|
|
+ assert device is None
|
|
|
assert_prefix_caching_block_or_none(prev_block)
|
|
|
|
|
|
block = self._create_block(
|
|
@@ -129,16 +132,20 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
|
|
|
return block
|
|
|
|
|
|
- def allocate_mutable(self, prev_block: Block) -> Block:
|
|
|
+ def allocate_mutable(self,
|
|
|
+ prev_block: Optional[Block],
|
|
|
+ device: Optional[Device] = None) -> Block:
|
|
|
"""Allocates a mutable block. If there are no free blocks, this will
|
|
|
evict unused cached blocks.
|
|
|
|
|
|
Args:
|
|
|
prev_block (Block): The previous block in the sequence.
|
|
|
+ None is not allowed, unlike in the superclass.
|
|
|
|
|
|
Returns:
|
|
|
Block: The allocated mutable block.
|
|
|
"""
|
|
|
+ assert device is None
|
|
|
assert_prefix_caching_block_or_none(prev_block)
|
|
|
|
|
|
try:
|
|
@@ -146,6 +153,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
prev_block=prev_block)
|
|
|
|
|
|
assert block.block_id not in self._blocks
|
|
|
+ assert block.block_id is not None
|
|
|
self._blocks[block.block_id] = block
|
|
|
return block
|
|
|
except BlockAllocator.NoFreeBlocksError:
|
|
@@ -185,6 +193,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
assert block.content_hash is None
|
|
|
|
|
|
assert block.block_id not in self._blocks
|
|
|
+ assert block.block_id is not None
|
|
|
self._blocks[block.block_id] = block
|
|
|
return block
|
|
|
|
|
@@ -215,6 +224,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
is not None), "freeing unallocated block is undefined"
|
|
|
|
|
|
self._free_block_id_for_block(block.block_id, block)
|
|
|
+
|
|
|
block.block_id = None
|
|
|
|
|
|
def _free_block_id_for_block(self, block_id: BlockId,
|
|
@@ -226,6 +236,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
# We have fork case where block would get more than one ref,
|
|
|
# so we cannot free it from tracking if ref cnt large than 1
|
|
|
if refcount <= 1:
|
|
|
+ assert block.block_id is not None
|
|
|
del self._blocks[block.block_id]
|
|
|
return self._hashless_allocator.free(block)
|
|
|
|
|
@@ -234,6 +245,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
# If no longer used, add the block to the evictor.
|
|
|
if refcount == 0:
|
|
|
assert block.content_hash in self._cached_blocks
|
|
|
+ assert block.block_id is not None
|
|
|
del self._blocks[block.block_id]
|
|
|
self.evictor.add(block.block_id, block.content_hash,
|
|
|
block.num_tokens_total, block.last_accessed)
|
|
@@ -269,18 +281,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
|
|
|
return forked_blocks
|
|
|
|
|
|
- def get_num_free_blocks(self) -> int:
|
|
|
+ def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
|
|
|
+ assert device is None
|
|
|
# The number of free blocks is the number of hashless free blocks
|
|
|
# plus the number of blocks evictor could free from its list.
|
|
|
return self._hashless_allocator.get_num_free_blocks(
|
|
|
) + self.evictor.num_blocks
|
|
|
|
|
|
+ def get_num_total_blocks(self) -> int:
|
|
|
+ return self._hashless_allocator.get_num_total_blocks()
|
|
|
+
|
|
|
@property
|
|
|
- def all_block_ids(self) -> frozenset[int]:
|
|
|
+ def all_block_ids(self) -> FrozenSet[int]:
|
|
|
return self._hashless_allocator.all_block_ids
|
|
|
|
|
|
- def promote_to_immutable_block(self,
|
|
|
- block: "PrefixCachingBlock") -> BlockId:
|
|
|
+ def promote_to_immutable_block(self, block: Block) -> BlockId:
|
|
|
"""Once a mutable block is full, it can be promoted to an immutable
|
|
|
block. This means that its content can be referenced by future blocks
|
|
|
having the same prefix.
|
|
@@ -290,7 +305,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
block.
|
|
|
|
|
|
Args:
|
|
|
- block (PrefixCachingBlock): The mutable block to be promoted.
|
|
|
+ block: The mutable block to be promoted.
|
|
|
|
|
|
Returns:
|
|
|
BlockId: Either the original block index, or the block index of
|
|
@@ -337,6 +352,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
def mark_blocks_as_accessed(self, block_ids: List[int],
|
|
|
now: float) -> None:
|
|
|
"""Mark blocks as accessed, used in prefix caching.
|
|
|
+
|
|
|
If the block is added into evictor, we need to update corresponding
|
|
|
info in evictor's metadata.
|
|
|
"""
|
|
@@ -371,6 +387,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
def get_common_computed_block_ids(
|
|
|
self, seq_block_ids: List[List[int]]) -> List[int]:
|
|
|
"""Return the block ids that are common for a given sequence group.
|
|
|
+
|
|
|
Only those blocks that are immutable and already be marked
|
|
|
compyted would be taken consideration.
|
|
|
"""
|
|
@@ -384,8 +401,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|
|
takewhile(lambda block_id: self.block_is_computed(block_id),
|
|
|
seq[:-1])) for seq in seq_block_ids
|
|
|
]
|
|
|
- res = commonprefix([ids for ids in ids_list if ids != []])
|
|
|
- return res
|
|
|
+ # It returns a list of ints, although the type annotation says a list of strings.
|
|
|
+ return commonprefix([
|
|
|
+ ids for ids in ids_list # type: ignore
|
|
|
+ if ids != []
|
|
|
+ ])
|
|
|
|
|
|
|
|
|
class PrefixCachingBlock(Block):
|
|
@@ -402,27 +422,33 @@ class PrefixCachingBlock(Block):
|
|
|
token_ids (List[int]): The initial token IDs to be stored in the block.
|
|
|
block_size (int): The maximum number of token IDs that can be stored in
|
|
|
the block.
|
|
|
- prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix
|
|
|
+ prefix_caching_allocator (BlockAllocator): The prefix
|
|
|
caching block allocator associated with this block.
|
|
|
block_id (Optional[int], optional): The physical block index
|
|
|
of this block. Defaults to None.
|
|
|
"""
|
|
|
|
|
|
- def __init__(self,
|
|
|
- prev_block: Optional["PrefixCachingBlock"],
|
|
|
- token_ids: List[int],
|
|
|
- block_size: int,
|
|
|
- prefix_caching_allocator: PrefixCachingBlockAllocator,
|
|
|
- block_id: Optional[int] = None,
|
|
|
- computed: Optional[bool] = False):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ prev_block: Optional[Block],
|
|
|
+ token_ids: List[int],
|
|
|
+ block_size: int,
|
|
|
+ prefix_caching_allocator: BlockAllocator,
|
|
|
+ block_id: Optional[int] = None,
|
|
|
+ computed: bool = False,
|
|
|
+ ):
|
|
|
+ assert isinstance(prefix_caching_allocator,
|
|
|
+ PrefixCachingBlockAllocator), (
|
|
|
+ "Currently this class is only tested with "
|
|
|
+ "PrefixCachingBlockAllocator.")
|
|
|
assert_prefix_caching_block_or_none(prev_block)
|
|
|
|
|
|
self._prev_block = prev_block
|
|
|
self._cached_content_hash: Optional[int] = None
|
|
|
self._cached_num_tokens_total: Optional[int] = None
|
|
|
self._prefix_caching_allocator = prefix_caching_allocator
|
|
|
- self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME
|
|
|
- self.computed = computed
|
|
|
+ self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
|
|
|
+ self._computed = computed
|
|
|
|
|
|
self._block = NaiveBlock(
|
|
|
prev_block=prev_block,
|
|
@@ -433,6 +459,22 @@ class PrefixCachingBlock(Block):
|
|
|
_cow_target=self,
|
|
|
)
|
|
|
|
|
|
+ @property
|
|
|
+ def computed(self) -> bool:
|
|
|
+ return self._computed
|
|
|
+
|
|
|
+ @computed.setter
|
|
|
+ def computed(self, value) -> None:
|
|
|
+ self._computed = value
|
|
|
+
|
|
|
+ @property
|
|
|
+ def last_accessed(self) -> float:
|
|
|
+ return self._last_accessed
|
|
|
+
|
|
|
+ @last_accessed.setter
|
|
|
+ def last_accessed(self, last_accessed_ts: float):
|
|
|
+ self._last_accessed = last_accessed_ts
|
|
|
+
|
|
|
def append_token_ids(self, token_ids: List[int]) -> None:
|
|
|
"""Appends the given token IDs to the block and registers the block as
|
|
|
immutable if the block becomes full.
|
|
@@ -473,17 +515,18 @@ class PrefixCachingBlock(Block):
|
|
|
@property
|
|
|
def num_tokens_total(self) -> int:
|
|
|
"""return the total tokens so far.
|
|
|
+
|
|
|
Here we iterate the block chain till to the first block, while
|
|
|
cache the result in local to prevent repeated computations.
|
|
|
"""
|
|
|
if self._cached_num_tokens_total is not None:
|
|
|
return self._cached_num_tokens_total
|
|
|
|
|
|
- _block = self
|
|
|
+ _block: Optional[Block] = self
|
|
|
self._cached_num_tokens_total = 0
|
|
|
|
|
|
- # TODO: current implementation here is O(N^2), we expect future
|
|
|
- # ones to be O(1)
|
|
|
+ # TODO: the current implementation here is O(N^2); we expect a future
|
|
|
+ # version to be O(1).
|
|
|
while _block is not None:
|
|
|
self._cached_num_tokens_total += len(_block.token_ids)
|
|
|
_block = _block.prev_block
|
|
@@ -520,8 +563,10 @@ class PrefixCachingBlock(Block):
|
|
|
return None
|
|
|
|
|
|
is_first_block = self._prev_block is None
|
|
|
- prev_block_hash = (None if is_first_block else
|
|
|
- self._prev_block.content_hash)
|
|
|
+ prev_block_hash = (
|
|
|
+ None if is_first_block else
|
|
|
+ self._prev_block.content_hash # type: ignore
|
|
|
+ )
|
|
|
|
|
|
# Previous block exists but does not yet have a hash.
|
|
|
# Return no hash in this case.
|