common.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. from collections import defaultdict
  2. from typing import Dict, Iterable, List, Optional
  3. from aphrodite.processing.block.interfaces import Block, BlockAllocator
  4. BlockId = int
  5. RefCount = int
  6. class RefCounter:
  7. """A class for managing reference counts for a set of block indices.
  8. The RefCounter class maintains a dictionary that maps block indices to their
  9. corresponding reference counts. It provides methods to increment, decrement,
  10. and retrieve the reference count for a given block index.
  11. Args:
  12. all_block_indices (Iterable[BlockId]): An iterable of block indices
  13. to initialize the reference counter with.
  14. """
  15. def __init__(self, all_block_indices: Iterable[BlockId]):
  16. deduped = set(all_block_indices)
  17. self._refcounts: Dict[BlockId,
  18. RefCount] = {index: 0
  19. for index in deduped}
  20. def incr(self, block_id: BlockId) -> RefCount:
  21. assert block_id in self._refcounts
  22. pre_incr_refcount = self._refcounts[block_id]
  23. assert pre_incr_refcount >= 0
  24. post_incr_refcount = pre_incr_refcount + 1
  25. self._refcounts[block_id] = post_incr_refcount
  26. return post_incr_refcount
  27. def decr(self, block_id: BlockId) -> RefCount:
  28. assert block_id in self._refcounts
  29. refcount = self._refcounts[block_id]
  30. assert refcount > 0
  31. refcount -= 1
  32. self._refcounts[block_id] = refcount
  33. return refcount
  34. def get(self, block_id: BlockId) -> RefCount:
  35. assert block_id in self._refcounts
  36. return self._refcounts[block_id]
  37. def as_readonly(self) -> "ReadOnlyRefCounter":
  38. return ReadOnlyRefCounter(self)
  39. class ReadOnlyRefCounter:
  40. """A read-only view of the RefCounter class.
  41. The ReadOnlyRefCounter class provides a read-only interface to access the
  42. reference counts maintained by a RefCounter instance. It does not allow
  43. modifications to the reference counts.
  44. Args:
  45. refcounter (RefCounter): The RefCounter instance to create a read-only
  46. view for.
  47. """
  48. def __init__(self, refcounter: RefCounter):
  49. self._refcounter = refcounter
  50. def incr(self, block_id: BlockId) -> RefCount:
  51. raise ValueError("Incr not allowed")
  52. def decr(self, block_id: BlockId) -> RefCount:
  53. raise ValueError("Decr not allowed")
  54. def get(self, block_id: BlockId) -> RefCount:
  55. return self._refcounter.get(block_id)
  56. class CopyOnWriteTracker:
  57. """A class for tracking and managing copy-on-write operations for blocks.
  58. The CopyOnWriteTracker class maintains a mapping of source block indices to
  59. their corresponding copy-on-write destination block indices. It works in
  60. conjunction with a RefCounter and a BlockAllocator to handle reference
  61. counting and block allocation.
  62. Args:
  63. refcounter (RefCounter): The reference counter used to track block
  64. reference counts.
  65. allocator (BlockAllocator): The block allocator used to allocate and
  66. free blocks.
  67. """
  68. def __init__(
  69. self,
  70. refcounter: RefCounter,
  71. allocator: BlockAllocator,
  72. ):
  73. self._copy_on_writes = defaultdict(list)
  74. self._refcounter = refcounter
  75. self._allocator = allocator
  76. def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
  77. """Performs a copy-on-write operation on the given block if it is not
  78. appendable.
  79. This method checks the reference count of the given block. If the
  80. reference count is greater than 1, indicating that the block is shared,
  81. a copy-on-write operation is performed. The original block is freed,
  82. and a new block is allocated with the same content. The new block index
  83. is returned.
  84. Args:
  85. block (Block): The block to check for copy-on-write.
  86. Returns:
  87. Optional[BlockId]: The block index of the new block if a copy-on
  88. -write operation was performed, or the original block index if
  89. no copy-on-write was necessary.
  90. """
  91. block_id = block.block_id
  92. if block_id is None:
  93. return block_id
  94. refcount = self._refcounter.get(block_id)
  95. assert refcount != 0
  96. if refcount > 1:
  97. src_block_id = block_id
  98. # Decrement refcount of the old block.
  99. self._allocator.free(block)
  100. # Allocate a fresh new block.
  101. block_id = self._allocator.allocate_mutable(
  102. prev_block=block.prev_block).block_id
  103. # Track src/dst copy.
  104. self._copy_on_writes[src_block_id].append(block_id)
  105. return block_id
  106. def clear_cows(self) -> Dict[BlockId, List[BlockId]]:
  107. """Clears the copy-on-write tracking information and returns the current
  108. state.
  109. This method returns a dictionary mapping source block indices to lists
  110. of destination block indices for the current copy-on-write operations.
  111. It then clears the internal tracking information.
  112. Returns:
  113. Dict[BlockId, List[BlockId]]: A dictionary mapping source
  114. block indices to lists of destination block indices for the
  115. current copy-on-write operations.
  116. """
  117. cows = dict(self._copy_on_writes)
  118. self._copy_on_writes.clear()
  119. return cows
  120. def get_all_blocks_recursively(last_block: Block) -> List[Block]:
  121. """Retrieves all the blocks in a sequence starting from the last block.
  122. This function recursively traverses the sequence of blocks in reverse order,
  123. starting from the given last block, and returns a list of all the blocks in
  124. the sequence.
  125. Args:
  126. last_block (Block): The last block in the sequence.
  127. Returns:
  128. List[Block]: A list of all the blocks in the sequence, in the order they
  129. appear.
  130. """
  131. def recurse(block: Block, lst: List[Block]) -> None:
  132. if block.prev_block is not None:
  133. recurse(block.prev_block, lst)
  134. lst.append(block)
  135. all_blocks = []
  136. recurse(last_block, all_blocks)
  137. return all_blocks