Răsfoiți Sursa

enable prefix caching with v2 block manager for spec decoding

AlpinDale 8 luni în urmă
părinte
comite
6f6bf568e5

+ 10 - 2
aphrodite/processing/block/cpu_gpu_block_allocator.py

@@ -192,10 +192,18 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         device = Device.GPU
         return self._allocators[device].clear_copy_on_writes()
 
-    def mark_blocks_as_computed(self) -> None:
+    def mark_blocks_as_accessed(self, block_ids: List[int],
+                                now: float) -> None:
+        """Mark blocks as accessed, only use for prefix caching."""
         # Prefix caching only supported on GPU.
         device = Device.GPU
-        return self._allocators[device].mark_blocks_as_computed()
+        return self._allocators[device].mark_blocks_as_accessed(block_ids, now)
+
+    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
+        """Mark blocks as accessed, only use for prefix caching."""
+        # Prefix caching only supported on GPU.
+        device = Device.GPU
+        return self._allocators[device].mark_blocks_as_computed(block_ids)
 
     def get_common_computed_block_ids(
             self, seq_block_ids: List[List[int]]) -> List[int]:

+ 4 - 0
aphrodite/processing/block/interfaces.py

@@ -75,6 +75,10 @@ class BlockAllocator(ABC):
     def clear_copy_on_writes(self) -> Dict[int, List[int]]:
         pass
 
+    @abstractmethod
+    def mark_blocks_as_accessed(self) -> None:
+        pass
+
     @abstractmethod
     def mark_blocks_as_computed(self) -> None:
         pass

+ 9 - 1
aphrodite/processing/block/naive_block.py

@@ -174,7 +174,15 @@ class NaiveBlockAllocator(BlockAllocator):
         """
         return self._cow_tracker.clear_cows()
 
-    def mark_blocks_as_computed(self) -> None:
+    def mark_blocks_as_accessed(self, block_ids: List[int],
+                                now: float) -> None:
+        """Mark blocks as accessed, used in prefix caching.
+        Since the naive allocator does not implement prefix caching, we do
+        nothing.
+        """
+        pass
+
+    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
         """Mark blocks as computed, used in prefix caching.
 
         Since the naive allocator does not implement prefix caching, we do

+ 138 - 47
aphrodite/processing/block/prefix_caching_block.py

@@ -8,10 +8,17 @@ from aphrodite.processing.block.common import (CopyOnWriteTracker,
 from aphrodite.processing.block.interfaces import Block, BlockAllocator
 from aphrodite.processing.block.naive_block import (NaiveBlock,
                                                     NaiveBlockAllocator)
+from aphrodite.processing.evictor_v2 import (EvictionPolicy, Evictor,
+                                             make_evictor)
 
 PrefixHash = int
 BlockId = int
 
+# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
+# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME,
+# then we know this block hasn't been accessed yet.
+_DEFAULT_LAST_ACCESSED_TIME = -1
+
 
 class PrefixCachingBlockAllocator(BlockAllocator):
     """A block allocator that implements prefix caching.
@@ -28,22 +35,19 @@ class PrefixCachingBlockAllocator(BlockAllocator):
             from 0 to num_blocks - 1.
     """
 
-    # TODO last access time / evictor integration
-
     def __init__(
         self,
         num_blocks: int,
         block_size: int,
         block_ids: Optional[Iterable[int]] = None,
+        eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU,
     ):
         # A mapping of prefix hash to block index. All blocks which have a
         # prefix hash will be in this dict, even if they have refcount 0.
         self._cached_blocks: Dict[PrefixHash, BlockId] = {}
 
-        # A mapping of prefix hash to block index. All blocks which have a
-        # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
-        # of self._cached_blocks.
-        self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}
+        # A mapping of blockId to Block to track those cached blocks
+        self._blocks: Dict[BlockId, Block] = {}
 
         # An allocator for blocks that do not have prefix hashes.
         self._hashless_allocator = NaiveBlockAllocator(
@@ -55,6 +59,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
 
         self._block_size = block_size
 
+        # Evitor used to maintain how we want to handle those computed blocks
+        # if we find memory pressure is high.
+        self.evictor: Evictor = make_evictor(eviction_policy)
+
         # We share the refcounter between allocators. This allows us to promote
         # blocks originally allocated in the hashless allocator to immutable
         # blocks.
@@ -73,6 +81,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         block_size: int,
         allocator: BlockAllocator,
         block_id: Optional[int] = None,
+        computed: Optional[bool] = False,
     ) -> Block:
         # Bind block to self.
         allocator = self
@@ -83,6 +92,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
             block_size=block_size,
             block_id=block_id,
             prefix_caching_allocator=allocator,
+            computed=computed,
         )
 
     def allocate_immutable(self, prev_block: Optional[Block],
@@ -110,14 +120,12 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         cached_block_id = self._cached_blocks.get(block.content_hash, None)
         if cached_block_id is not None:
             block.block_id = cached_block_id
-            self._incr_refcount_cached_block(block.content_hash,
-                                             block.block_id)
+            self._incr_refcount_cached_block(block, block.block_id)
             return block
 
         block = self.allocate_mutable(prev_block)
         block.append_token_ids(token_ids)
         assert block.content_hash is not None
-        # TODO computed bit
 
         return block
 
@@ -134,41 +142,67 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         assert_prefix_caching_block_or_none(prev_block)
 
         try:
-            return self._hashless_allocator.allocate_mutable(
+            block = self._hashless_allocator.allocate_mutable(
                 prev_block=prev_block)
+
+            assert block.block_id not in self._blocks
+            self._blocks[block.block_id] = block
+            return block
         except BlockAllocator.NoFreeBlocksError:
             # We must check the unused cached blocks before raising OOM.
             pass
 
-        if self._unused_cached_blocks:
-            # TODO policy for selecting block to remove
-            content_hash_to_evict = next(iter(self._unused_cached_blocks))
+        # If the evictor has blocks available for eviction, evict a block
+        # and return it.
+        if self.evictor.num_blocks > 0:
+            block_id, content_hash_to_evict = self.evictor.evict()
+
+            # Here we may have scenario that several blocks have
+            # the same content hash, but due to the latter coming block
+            # is coming from mutable to immutable path, their physical
+            # block is added into evictor.
+            # However in this case, we shall not pop the _cached_blocks,
+            # as the same content is still used by others, which means
+            # we need to check ref before decide to pop the list.
+
+            _block_id = self._cached_blocks[content_hash_to_evict]
+            refcount = self._refcounter.get(_block_id)
+            if refcount == 1:
+                self._cached_blocks.pop(content_hash_to_evict)
+                assert _block_id == block_id
 
-            # Clear content hash mapping; the block will be overwritten.
-            del self._cached_blocks[content_hash_to_evict]
+            self._refcounter.incr(block_id)
 
-            block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
-            refcount = self._refcounter.incr(block_id)
-            assert refcount == 1
+            # the block comes from evictor already contain computed result
             block = self._create_block(
                 prev_block=prev_block,
                 token_ids=[],
                 block_size=self._block_size,
                 allocator=self,
                 block_id=block_id,
+                computed=True,
             )
             assert block.content_hash is None
+
+            assert block.block_id not in self._blocks
+            self._blocks[block.block_id] = block
             return block
 
         # No block available in hashless allocator, nor in unused cache blocks.
         raise BlockAllocator.NoFreeBlocksError()
 
-    def _incr_refcount_cached_block(self, content_hash: int,
+    def _incr_refcount_cached_block(self, block: Block,
                                     block_id: BlockId) -> None:
+        # since block is already computed, mark it
+        block.computed = True
+
         refcount = self._refcounter.incr(block_id)
         if refcount == 1:
-            assert content_hash in self._unused_cached_blocks
-            del self._unused_cached_blocks[content_hash]
+            # if block get referred, then it shall not be in evictor
+            # and put it into _blocks for tracking
+            if block_id in self.evictor:
+                self.evictor.remove(block_id)
+            self._blocks[block_id] = block
 
     def free(self, block: Block) -> None:
         """Decrement the refcount of the block. If the decremented refcount is
@@ -188,15 +222,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         assert isinstance(block, PrefixCachingBlock)
 
         if block.content_hash is None:
+            refcount = self._refcounter.get(block_id)
+            # We have fork case where block would get more than one ref,
+            # so we cannot free it from tracking if ref cnt large than 1
+            if refcount <= 1:
+                del self._blocks[block.block_id]
             return self._hashless_allocator.free(block)
 
         refcount = self._refcounter.decr(block_id)
 
-        # If no longer used, add the block to the unused cached blocks.
+        # If no longer used, add the block to the evictor.
         if refcount == 0:
-            assert block.content_hash not in self._unused_cached_blocks
             assert block.content_hash in self._cached_blocks
-            self._unused_cached_blocks[block.content_hash] = block_id
+            del self._blocks[block.block_id]
+            self.evictor.add(block.block_id, block.content_hash,
+                             block.num_tokens_total, block.last_accessed)
 
     def fork(self, last_block: Block) -> List[Block]:
         """Creates a new sequence of blocks that shares the same underlying
@@ -231,9 +271,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
 
     def get_num_free_blocks(self) -> int:
         # The number of free blocks is the number of hashless free blocks
-        # plus the number of hashful blocks that are unused.
-        return self._hashless_allocator.get_num_free_blocks() + len(
-            self._unused_cached_blocks)
+        # plus the number of blocks evictor could free from its list.
+        return self._hashless_allocator.get_num_free_blocks(
+        ) + self.evictor.num_blocks
 
     @property
     def all_block_ids(self) -> frozenset[int]:
@@ -267,7 +307,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         else:
             self._free_block_id_for_block(block.block_id, block)
             self._incr_refcount_cached_block(
-                block.content_hash, self._cached_blocks[block.content_hash])
+                block, self._cached_blocks[block.content_hash])
 
         return self._cached_blocks[block.content_hash]
 
@@ -294,29 +334,58 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         """
         return self._cow_tracker.clear_cows()
 
-    def mark_blocks_as_computed(self) -> None:
+    def mark_blocks_as_accessed(self, block_ids: List[int],
+                                now: float) -> None:
+        """Mark blocks as accessed, used in prefix caching.
+        If the block is added into evictor, we need to update corresponding
+        info in evictor's metadata.
+        """
+
+        for block_id in block_ids:
+            if block_id in self._blocks:
+                self._blocks[block_id].last_accessed = now
+            elif block_id in self.evictor:
+                self.evictor.update(block_id, now)
+            else:
+                raise ValueError(
+                    "Mark block as accessed which is not belonged to GPU")
+
+    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
         """Mark blocks as computed, used in prefix caching."""
-        # TODO Track computed blocks.
-        pass
+
+        for block_id in block_ids:
+            if block_id in self._blocks:
+                # only those full block is valid for prefix caching
+                if self._blocks[block_id].is_full:
+                    self._blocks[block_id].computed = True
+            elif block_id not in self.evictor:
+                raise ValueError(f"Mark {block_id=} as computed which "
+                                 "is not belonged to GPU")
+
+    def block_is_computed(self, block_id: int) -> bool:
+        if block_id in self._blocks:
+            return self._blocks[block_id].computed
+        else:
+            return block_id in self.evictor
 
     def get_common_computed_block_ids(
             self, seq_block_ids: List[List[int]]) -> List[int]:
         """Return the block ids that are common for a given sequence group.
-
-        Used in prefill (can skip prefill of some blocks).
+        Only those blocks that are immutable and already be marked
+        compyted would be taken consideration.
         """
 
-        # TODO: Track computed blocks.
-        computed = lambda block_id: False
-
         # NOTE We exclude the last block to avoid the case where the entire
         # prompt is cached. This would cause erroneous behavior in model
         # runner.
+
         ids_list = [
-            takewhile(lambda block_id: computed(block_id), seq[:-1])
-            for seq in seq_block_ids
+            list(
+                takewhile(lambda block_id: self.block_is_computed(block_id),
+                          seq[:-1])) for seq in seq_block_ids
         ]
-        return commonprefix([ids for ids in ids_list if ids != []])
+        res = commonprefix([ids for ids in ids_list if ids != []])
+        return res
 
 
 class PrefixCachingBlock(Block):
@@ -339,19 +408,21 @@ class PrefixCachingBlock(Block):
             of this block. Defaults to None.
     """
 
-    def __init__(
-        self,
-        prev_block: Optional["PrefixCachingBlock"],
-        token_ids: List[int],
-        block_size: int,
-        prefix_caching_allocator: PrefixCachingBlockAllocator,
-        block_id: Optional[int] = None,
-    ):
+    def __init__(self,
+                 prev_block: Optional["PrefixCachingBlock"],
+                 token_ids: List[int],
+                 block_size: int,
+                 prefix_caching_allocator: PrefixCachingBlockAllocator,
+                 block_id: Optional[int] = None,
+                 computed: Optional[bool] = False):
         assert_prefix_caching_block_or_none(prev_block)
 
         self._prev_block = prev_block
         self._cached_content_hash: Optional[int] = None
+        self._cached_num_tokens_total: Optional[int] = None
         self._prefix_caching_allocator = prefix_caching_allocator
+        self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME
+        self.computed = computed
 
         self._block = NaiveBlock(
             prev_block=prev_block,
@@ -399,6 +470,26 @@ class PrefixCachingBlock(Block):
     def num_empty_slots(self) -> int:
         return self._block.num_empty_slots
 
+    @property
+    def num_tokens_total(self) -> int:
+        """return the total tokens so far.
+        Here we iterate the block chain till to the first block, while
+        cache the result in local to prevent repeated computations.
+        """
+        if self._cached_num_tokens_total is not None:
+            return self._cached_num_tokens_total
+
+        _block = self
+        self._cached_num_tokens_total = 0
+
+        # TODO: current implementation here is O(N^2), we expect future
+        # ones to be O(1)
+        while _block is not None:
+            self._cached_num_tokens_total += len(_block.token_ids)
+            _block = _block.prev_block
+
+        return self._cached_num_tokens_total
+
     @property
     def block_size(self) -> int:
         return self._block.block_size

+ 2 - 1
aphrodite/processing/block_manager_v1.py

@@ -11,7 +11,8 @@ from loguru import logger
 from aphrodite.common.block import BlockTable, PhysicalTokenBlock
 from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
 from aphrodite.common.utils import Device
-from aphrodite.processing.evictor import EvictionPolicy, Evictor, make_evictor
+from aphrodite.processing.evictor_v1 import (EvictionPolicy, Evictor,
+                                             make_evictor)
 from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
 
 

+ 19 - 11
aphrodite/processing/block_manager_v2.py

@@ -69,14 +69,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         self.watermark = watermark
         assert watermark >= 0.0
 
-        assert not enable_caching, "Prefix caching not yet supported"
         self.enable_caching = enable_caching
 
         self.watermark_blocks = int(watermark * num_gpu_blocks)
 
         self.block_allocator = CpuGpuBlockAllocator.create(
-            # Currently, only naive blocks are supported (no prefix caching).
-            allocator_type="naive",
+            allocator_type="prefix_caching" if enable_caching else "naive",
             num_gpu_blocks=num_gpu_blocks,
             num_cpu_blocks=num_cpu_blocks,
             block_size=block_size,
@@ -191,16 +189,26 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         assert all(b is not None for b in block_ids)
         return block_ids
 
-    def access_all_blocks_in_seq(self, seq, now):
-        # TODO add prefix caching support.
-        pass
+    def access_all_blocks_in_seq(self, seq: Sequence, now: float):
+        # Update the last accessed time of all the blocks accessed
+        # in this step.
+        # And the accessed time is only useful for prefix caching now,
+        # as it support internal evictor policy for which cached
+        # block could be refilled, to keep cached content could be reused
+        # at max extend.
+        if self.enable_caching:
+            block_table = self.block_tables[seq.seq_id]
+            block_ids = []
+            for block_id in block_table.physical_block_ids:
+                block_ids.append(block_id)
+            self.block_allocator.mark_blocks_as_accessed(block_ids, now)
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # We ignore the sequence group as its not necessary. After the batch is
-        # formed by the scheduler, we do not need to mark blocks from individual
-        # sequence groups as computed -- all blocks in the batch can be marked
-        # as computed.
-        self.block_allocator.mark_blocks_as_computed()
+        # The only need for mark block as computed is for prefix caching,
+        # while currently we could determine whether one block is computed
+        # or not by check whether it has content hash.
+        # So this function is useless for block_v2.
+        pass
 
     def get_common_computed_block_ids(
             self, seqs: List[Sequence]) -> GenericSequence[int]:

+ 0 - 0
aphrodite/processing/evictor.py → aphrodite/processing/evictor_v1.py


+ 121 - 0
aphrodite/processing/evictor_v2.py

@@ -0,0 +1,121 @@
+import enum
+from abc import ABC, abstractmethod, abstractproperty
+from typing import OrderedDict, Tuple
+
+
+class EvictionPolicy(enum.Enum):
+    """Enum for eviction policy used by make_evictor to instantiate the correct
+       Evictor subclass.
+    """
+    LRU = enum.auto()
+
+
+class Evictor(ABC):
+    """The Evictor subclasses should be used by the BlockAllocator class to
+    handle eviction of freed PhysicalTokenBlocks.
+    """
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def __contains__(self, block_id: int) -> bool:
+        pass
+
+    @abstractmethod
+    def evict(self) -> Tuple[int, int]:
+        """Runs the eviction algorithm and returns the evicted block's
+        content hash along with physical block id along with physical block id
+        """
+        pass
+
+    @abstractmethod
+    def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
+            last_accessed: int):
+        """Adds block to the evictor, making it a candidate for eviction"""
+        pass
+
+    @abstractmethod
+    def update(self, block_id: int, last_accessed: int):
+        """Update corresponding block's access time in metadata"""
+        pass
+
+    @abstractproperty
+    def num_blocks(self) -> int:
+        pass
+
+
+class BlockMetaData():
+    """Data structure for storing key data describe cached block, so that
+    evitor could use to make its decision which one to choose for eviction
+    Here we use physical block id as the dict key, as there maybe several
+    blocks with the same content hash, but their physical id is unique.
+    """
+
+    def __init__(self, content_hash: int, num_hashed_tokens: int,
+                 last_accessed: int):
+        self.content_hash = content_hash
+        self.num_hashed_tokens = num_hashed_tokens
+        self.last_accessed = last_accessed
+
+
+class LRUEvictor(Evictor):
+    """Evicts in a least-recently-used order using the last_accessed timestamp
+    that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
+    the same last_accessed time, then the one with the largest num_hashed_tokens
+    will be evicted. If two blocks each have the lowest last_accessed time and
+    highest num_hashed_tokens value, then one will be chose arbitrarily
+    """
+
+    def __init__(self):
+        self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict()
+
+    def __contains__(self, block_id: int) -> bool:
+        return block_id in self.free_table
+
+    def evict(self) -> Tuple[int, int]:
+        if len(self.free_table) == 0:
+            raise ValueError("No usable cache memory left")
+
+        evicted_block = next(iter(self.free_table.values()))
+        evicted_block_id = next(iter(self.free_table.keys()))
+        # The blocks with the lowest timestamps should be placed consecutively
+        # at the start of OrderedDict. Loop through all these blocks to
+        # find the one with maximum number of hashed tokens.
+        for _id, block in self.free_table.items():
+            if evicted_block.last_accessed > block.last_accessed or (
+                    evicted_block.last_accessed == block.last_accessed and
+                    evicted_block.num_hashed_tokens < block.num_hashed_tokens):
+                evicted_block = block
+                evicted_block_id = _id
+
+        self.free_table.pop(evicted_block_id)
+
+        return evicted_block_id, evicted_block.content_hash
+
+    def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
+            last_accessed: int):
+        self.free_table[block_id] = BlockMetaData(content_hash,
+                                                  num_hashed_tokens,
+                                                  last_accessed)
+
+    def update(self, block_id: int, last_accessed: int):
+        self.free_table[block_id].last_accessed = last_accessed
+
+    def remove(self, block_id: int):
+        if block_id not in self.free_table:
+            raise ValueError(
+                "Attempting to remove block that's not in the evictor")
+        self.free_table.pop(block_id)
+
+    @property
+    def num_blocks(self) -> int:
+        return len(self.free_table)
+
+
+def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
+    if eviction_policy == EvictionPolicy.LRU:
+        return LRUEvictor()
+    else:
+        raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")