123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- from collections import deque
- from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
- from aphrodite.processing.block.interfaces import Block, BlockAllocator
# Alias for the integer id that identifies a block (the value of `Block.block_id`).
- BlockId = int
# Alias for the integer number of references held on a block id.
- RefCount = int
class RefCounterProtocol(Protocol):
    """Structural interface for objects that expose per-block reference counts.

    The three methods below are stubs that raise ``NotImplementedError``;
    classes implementing the protocol supply their own versions.
    """

    def incr(self, block_id: BlockId) -> RefCount:
        """Increment hook for ``block_id``; returns a reference count."""
        raise NotImplementedError

    def decr(self, block_id: BlockId) -> RefCount:
        """Decrement hook for ``block_id``; returns a reference count."""
        raise NotImplementedError

    def get(self, block_id: BlockId) -> RefCount:
        """Lookup hook; returns the reference count of ``block_id``."""
        raise NotImplementedError
class RefCounter(RefCounterProtocol):
    """Keeps a reference count for each id in a fixed set of block ids.

    Counts are stored in a dictionary keyed by block id. Every id supplied
    at construction starts at zero. ``incr``, ``decr`` and ``get`` assert
    that the id they are given was part of that initial set.

    Args:
        all_block_indices (Iterable[BlockId]): Block ids to track.
            Duplicate entries are collapsed.
    """

    def __init__(self, all_block_indices: Iterable[BlockId]):
        unique_ids = set(all_block_indices)
        # Every tracked id starts with no references.
        self._refcounts: Dict[BlockId, RefCount] = dict.fromkeys(
            unique_ids, 0)

    def incr(self, block_id: BlockId) -> RefCount:
        """Add one reference to ``block_id`` and return the updated count."""
        counts = self._refcounts
        assert block_id in counts

        before = counts[block_id]
        assert before >= 0

        counts[block_id] = before + 1
        return counts[block_id]

    def decr(self, block_id: BlockId) -> RefCount:
        """Drop one reference from ``block_id`` and return the updated count.

        Asserts that the count is positive before decrementing.
        """
        counts = self._refcounts
        assert block_id in counts

        before = counts[block_id]
        assert before > 0

        counts[block_id] = before - 1
        return counts[block_id]

    def get(self, block_id: BlockId) -> RefCount:
        """Return the current reference count of ``block_id``."""
        counts = self._refcounts
        assert block_id in counts
        return counts[block_id]

    def as_readonly(self) -> "ReadOnlyRefCounter":
        """Return a ``ReadOnlyRefCounter`` wrapping this counter."""
        return ReadOnlyRefCounter(self)
class ReadOnlyRefCounter(RefCounterProtocol):
    """Wrapper around a ``RefCounter`` that only permits reading counts.

    ``get`` is forwarded to the wrapped counter. ``incr`` and ``decr``
    always raise ``ValueError``.

    Args:
        refcounter (RefCounter): The counter whose values are exposed.
    """

    def __init__(self, refcounter: RefCounter):
        self._refcounter = refcounter

    def incr(self, block_id: BlockId) -> RefCount:
        """Always raises ``ValueError``; this view cannot modify counts."""
        raise ValueError("Incr not allowed")

    def decr(self, block_id: BlockId) -> RefCount:
        """Always raises ``ValueError``; this view cannot modify counts."""
        raise ValueError("Decr not allowed")

    def get(self, block_id: BlockId) -> RefCount:
        """Return the wrapped counter's reference count for ``block_id``."""
        wrapped = self._refcounter
        return wrapped.get(block_id)
class CopyOnWriteTracker:
    """Collects pending copy-on-write (source, target) block id pairs.

    Pairs are stored in the order they were recorded and handed back by
    ``clear_cows``. A reference counter is consulted to decide whether a
    block can be appended to directly.

    Args:
        refcounter (RefCounterProtocol): Object whose ``get`` method returns
            the reference count for a block id.
    """

    def __init__(self, refcounter: RefCounterProtocol):
        self._copy_on_writes: List[Tuple[BlockId, BlockId]] = []
        self._refcounter = refcounter

    def is_appendable(self, block: Block) -> bool:
        """Tell whether ``block`` can be appended to without duplication.

        A block with no block id is always appendable. Otherwise the block
        is appendable only while its reference count is at most one; a
        higher count means it is shared and must go through copy-on-write.
        """
        physical_id = block.block_id
        return (physical_id is None
                or self._refcounter.get(physical_id) <= 1)

    def record_cow(self, src_block_id: Optional[BlockId],
                   trg_block_id: Optional[BlockId]) -> None:
        """Remember one copy-on-write from a source to a target block id.

        Both ids are asserted to be not ``None``.

        Args:
            src_block_id (BlockId): Block id the data is copied from.
            trg_block_id (BlockId): Block id the data is copied to.
        """
        assert src_block_id is not None
        assert trg_block_id is not None
        pair = (src_block_id, trg_block_id)
        self._copy_on_writes.append(pair)

    def clear_cows(self) -> List[Tuple[BlockId, BlockId]]:
        """Hand back the recorded copy-on-write pairs and start afresh.

        Returns:
            List[Tuple[BlockId, BlockId]]: The (source, target) block id
                pairs recorded since the previous call, in recording order.
        """
        # Swap in a new empty list; the caller receives the old one.
        pending, self._copy_on_writes = self._copy_on_writes, []
        return pending
class BlockPool:
    """Used to pre-allocate block objects, in order to avoid excessive python
    object allocations/deallocations.

    The pool starts from "pool_size" objects and will increase to more objects
    if necessary. A pool created with ``pool_size == 0`` grows to one object on
    first use.

    Note that multiple block objects may point to the same physical block id,
    which is why this pool is needed, so that it will be easier to support
    prefix caching and more complicated sharing of physical blocks.

    Args:
        block_size (int): Block size passed to ``create_block`` for every
            pre-allocated object.
        create_block (Block.Factory): Callable used to build the pooled
            block objects.
        allocator (BlockAllocator): Allocator passed to ``create_block``.
        pool_size (int): Initial number of pooled objects; must be >= 0.
    """

    def __init__(self, block_size: int, create_block: Block.Factory,
                 allocator: BlockAllocator, pool_size: int):
        self._block_size = block_size
        self._create_block = create_block
        self._allocator = allocator
        self._pool_size = pool_size
        assert self._pool_size >= 0

        # Indices into self._pool that are currently available.
        self._free_ids: Deque[int] = deque(range(self._pool_size))
        self._pool = [self._new_block() for _ in range(self._pool_size)]

    def _new_block(self) -> Block:
        """Build one blank placeholder block for the pool."""
        return self._create_block(prev_block=None,
                                  token_ids=[],
                                  block_size=self._block_size,
                                  allocator=self._allocator,
                                  block_id=None)

    def increase_pool(self):
        """Doubles the internal pool size (an empty pool grows to one slot).
        """
        cur_pool_size = self._pool_size
        # Doubling zero stays zero, which would leave init_block with no
        # slot to hand out; always grow by at least one.
        new_pool_size = max(cur_pool_size * 2, 1)
        self._pool_size = new_pool_size

        self._free_ids.extend(range(cur_pool_size, new_pool_size))
        self._pool.extend(self._new_block()
                          for _ in range(cur_pool_size, new_pool_size))

    def init_block(self, prev_block: Optional[Block], token_ids: List[int],
                   block_size: int, physical_block_id: Optional[int]) -> Block:
        """Take a pooled object and re-initialize it as the requested block.

        The pool is grown first when no free object is left.

        Args:
            prev_block (Optional[Block]): Forwarded to the block's
                ``__init__``.
            token_ids (List[int]): Forwarded to the block's ``__init__``.
            block_size (int): Forwarded to the block's ``__init__``.
            physical_block_id (Optional[int]): Forwarded as ``block_id``.

        Returns:
            Block: The re-initialized pooled object, with ``pool_id`` set to
                its index in the pool.
        """
        if not self._free_ids:
            self.increase_pool()
            assert self._free_ids

        pool_id = self._free_ids.popleft()

        block = self._pool[pool_id]
        # Re-run __init__ on the existing object instead of allocating a
        # new one; the allocator it already holds is kept.
        block.__init__(  # type: ignore[misc]
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            allocator=block._allocator,  # type: ignore[attr-defined]
            block_id=physical_block_id)
        block.pool_id = pool_id  # type: ignore[attr-defined]
        return block

    def free_block(self, block: Block) -> None:
        """Return ``block``'s pool slot to the free list.

        The slot goes to the front of the queue, so it is the next one
        ``init_block`` hands out.
        """
        self._free_ids.appendleft(block.pool_id)  # type: ignore[attr-defined]
class BlockList:
    """Holds a list of blocks together with a parallel list of their ids.

    The id list is kept in step with the block list on every mutation, so
    the physical block ids can be read directly instead of being rebuilt
    from the blocks each time they are needed.
    """

    def __init__(self, blocks: List[Block]):
        self._blocks: List[Block] = []
        self._block_ids: List[int] = []

        self.update(blocks)

    def _add_block_id(self, block_id: Optional[BlockId]) -> None:
        """Append ``block_id`` to the cached ids; it must not be ``None``."""
        assert block_id is not None
        self._block_ids.append(block_id)

    def _update_block_id(self, block_index: int,
                         new_block_id: Optional[BlockId]) -> None:
        """Overwrite the cached id at ``block_index``; ``None`` is rejected."""
        assert new_block_id is not None
        self._block_ids[block_index] = new_block_id

    def update(self, blocks: List[Block]):
        """Adopt ``blocks`` as the block list and rebuild the cached ids.

        The list object passed in is stored as-is, not copied.
        """
        self._blocks, self._block_ids = blocks, []

        # Rebuild the id cache for fast queries.
        for member in blocks:
            self._add_block_id(member.block_id)

    def append_token_ids(self, block_index: int, token_ids: List[int]) -> None:
        """Append ``token_ids`` to the block at ``block_index``.

        If the block's id differs after the append, the cached id is
        refreshed.
        """
        target = self._blocks[block_index]
        id_before = target.block_id

        target.append_token_ids(token_ids)

        # CoW or promotion may have given the block a different id.
        id_after = target.block_id
        if id_before != id_after:
            self._update_block_id(block_index, id_after)

    def append(self, new_block: Block):
        """Add ``new_block`` at the end and cache its id."""
        self._blocks.append(new_block)
        self._add_block_id(new_block.block_id)

    def __len__(self) -> int:
        return len(self._blocks)

    def __getitem__(self, block_index: int) -> Block:
        return self._blocks[block_index]

    def __setitem__(self, block_index: int, new_block: Block) -> None:
        self._blocks[block_index] = new_block
        self._update_block_id(block_index, new_block.block_id)

    def reset(self):
        """Drop all blocks and cached ids."""
        self._blocks, self._block_ids = [], []

    def list(self) -> List[Block]:
        """Return the underlying list of blocks (not a copy)."""
        return self._blocks

    def ids(self) -> List[int]:
        """Return the cached list of block ids (not a copy)."""
        return self._block_ids
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
    """Retrieves all the blocks in a sequence starting from the last block.

    The chain is followed backwards through ``prev_block`` links, starting
    from the given last block, and the collected blocks are returned in
    forward order. The walk is iterative, so the length of the chain is not
    limited by the interpreter's recursion limit.

    Args:
        last_block (Block): The last block in the sequence.

    Returns:
        List[Block]: A list of all the blocks in the sequence, in the order they
            appear (first block first, ``last_block`` last).
    """
    all_blocks: List[Block] = [last_block]

    # Collect from the tail towards the head of the chain.
    prev_block = last_block.prev_block
    while prev_block is not None:
        all_blocks.append(prev_block)
        prev_block = prev_block.prev_block

    # Blocks were gathered last-to-first; flip into sequence order.
    all_blocks.reverse()
    return all_blocks
|