123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- from collections import defaultdict
- from typing import Dict, Iterable, List, Optional
- from aphrodite.processing.block.interfaces import Block, BlockAllocator
# Integer index identifying a block.
BlockId = int
# Integer count of references held on a block.
RefCount = int
class RefCounter:
    """Tracks how many references exist to each block index.

    Every index supplied at construction starts with a count of zero;
    duplicate indices in the input collapse into a single entry. Counts move
    one step at a time through ``incr`` and ``decr`` and are read with
    ``get``. Each of those methods asserts that the index was registered at
    construction time.

    Args:
        all_block_indices (Iterable[BlockId]): The block indices to track.
    """

    def __init__(self, all_block_indices: Iterable[BlockId]):
        unique_indices = set(all_block_indices)
        # Every known block starts out unreferenced.
        self._refcounts: Dict[BlockId,
                              RefCount] = dict.fromkeys(unique_indices, 0)

    def incr(self, block_id: BlockId) -> RefCount:
        """Add one reference to ``block_id`` and return the new count."""
        assert block_id in self._refcounts
        current = self._refcounts[block_id]
        # A negative count means the bookkeeping is already corrupted.
        assert current >= 0

        updated = current + 1
        self._refcounts[block_id] = updated
        return updated

    def decr(self, block_id: BlockId) -> RefCount:
        """Drop one reference from ``block_id`` and return the new count.

        Asserts that the block currently has at least one reference.
        """
        assert block_id in self._refcounts
        current = self._refcounts[block_id]
        assert current > 0

        updated = current - 1
        self._refcounts[block_id] = updated
        return updated

    def get(self, block_id: BlockId) -> RefCount:
        """Return the current reference count of ``block_id``."""
        assert block_id in self._refcounts
        return self._refcounts[block_id]

    def as_readonly(self) -> "ReadOnlyRefCounter":
        """Return a ``ReadOnlyRefCounter`` wrapping this counter."""
        return ReadOnlyRefCounter(self)
class ReadOnlyRefCounter:
    """Read-only wrapper around a reference counter.

    Lookups through ``get`` are forwarded to the wrapped counter, while
    ``incr`` and ``decr`` always raise ``ValueError`` so the counts cannot be
    changed through this object.

    Args:
        refcounter (RefCounter): The counter whose counts are exposed.
    """

    def __init__(self, refcounter: RefCounter):
        self._refcounter = refcounter

    def get(self, block_id: BlockId) -> RefCount:
        """Return the wrapped counter's reference count for ``block_id``."""
        wrapped = self._refcounter
        return wrapped.get(block_id)

    def incr(self, block_id: BlockId) -> RefCount:
        """Always raise ``ValueError``; increments are not permitted."""
        raise ValueError("Incr not allowed")

    def decr(self, block_id: BlockId) -> RefCount:
        """Always raise ``ValueError``; decrements are not permitted."""
        raise ValueError("Decr not allowed")
class CopyOnWriteTracker:
    """Records which blocks were replaced because they were shared.

    For every shared block that gets swapped for a freshly allocated one, the
    tracker remembers the (source block index -> destination block indices)
    relation. It reads reference counts from a RefCounter and uses a
    BlockAllocator to free the shared block and allocate its replacement.

    Args:
        refcounter (RefCounter): Supplies the reference count of a block.
        allocator (BlockAllocator): Frees shared blocks and allocates their
            replacements.
    """

    def __init__(
        self,
        refcounter: RefCounter,
        allocator: BlockAllocator,
    ):
        self._allocator = allocator
        self._refcounter = refcounter
        # source block index -> destination block indices recorded so far.
        self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Swap the block for a fresh one when it is shared.

        A block whose reference count is greater than 1 is freed through the
        allocator, a new mutable block is allocated with the same
        ``prev_block``, and the source/destination pair is recorded for a
        later ``clear_cows`` call. This method only records the pair; it does
        not copy any block contents itself.

        Args:
            block (Block): The block to check.

        Returns:
            Optional[BlockId]: ``None`` when the block has no block index,
                the original block index when the reference count is exactly
                1, or the index of the newly allocated block otherwise.
        """
        source_id = block.block_id
        if source_id is None:
            return None

        holders = self._refcounter.get(source_id)
        assert holders != 0

        if holders <= 1:
            # Sole holder: nothing to replace.
            return source_id

        # Hand the shared block back to the allocator. The index was read
        # beforehand, so it stays available for the bookkeeping below.
        self._allocator.free(block)

        # Obtain a fresh mutable block linked to the same predecessor.
        replacement = self._allocator.allocate_mutable(
            prev_block=block.prev_block)
        destination_id = replacement.block_id
        assert destination_id is not None

        # Remember the source -> destination pair.
        self._copy_on_writes[source_id].append(destination_id)
        return destination_id

    def clear_cows(self) -> Dict[BlockId, List[BlockId]]:
        """Return the recorded copy-on-write pairs and reset the tracker.

        Returns:
            Dict[BlockId, List[BlockId]]: A plain dict mapping each source
                block index to the destination block indices recorded for it
                since the previous call.
        """
        recorded = {
            source: destinations
            for source, destinations in self._copy_on_writes.items()
        }
        self._copy_on_writes.clear()
        return recorded
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
    """Retrieves all the blocks in a sequence starting from the last block.

    The chain is followed backwards through each block's ``prev_block``
    attribute until a block whose ``prev_block`` is ``None`` is reached.

    The walk is iterative rather than recursive, so the length of the chain
    is not bounded by the interpreter's recursion limit. The function name is
    kept for compatibility with existing callers.

    Args:
        last_block (Block): The last block in the sequence.

    Returns:
        List[Block]: All blocks in the sequence, ordered from the first
            block (the one without a predecessor) to ``last_block``.
    """
    chain: List[Block] = []
    current: Optional[Block] = last_block
    while current is not None:
        chain.append(current)
        current = current.prev_block

    # Blocks were collected newest-first; callers expect oldest-first.
    chain.reverse()
    return chain
|