common.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. from collections import deque
  2. from dataclasses import dataclass
  3. from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
  4. from aphrodite.processing.block.interfaces import Block, BlockAllocator
  5. BlockId = int
  6. RefCount = int
  7. class RefCounterProtocol(Protocol):
  8. def incr(self, block_id: BlockId) -> RefCount:
  9. raise NotImplementedError
  10. def decr(self, block_id: BlockId) -> RefCount:
  11. raise NotImplementedError
  12. def get(self, block_id: BlockId) -> RefCount:
  13. raise NotImplementedError
  14. class RefCounter(RefCounterProtocol):
  15. """A class for managing reference counts for a set of block indices.
  16. The RefCounter class maintains a dictionary that maps block indices to their
  17. corresponding reference counts. It provides methods to increment, decrement,
  18. and retrieve the reference count for a given block index.
  19. Args:
  20. all_block_indices (Iterable[BlockId]): An iterable of block indices
  21. to initialize the reference counter with.
  22. """
  23. def __init__(self, all_block_indices: Iterable[BlockId]):
  24. deduped = set(all_block_indices)
  25. self._refcounts: Dict[BlockId,
  26. RefCount] = {index: 0
  27. for index in deduped}
  28. def incr(self, block_id: BlockId) -> RefCount:
  29. assert block_id in self._refcounts
  30. pre_incr_refcount = self._refcounts[block_id]
  31. assert pre_incr_refcount >= 0
  32. post_incr_refcount = pre_incr_refcount + 1
  33. self._refcounts[block_id] = post_incr_refcount
  34. return post_incr_refcount
  35. def decr(self, block_id: BlockId) -> RefCount:
  36. assert block_id in self._refcounts
  37. refcount = self._refcounts[block_id]
  38. assert refcount > 0
  39. refcount -= 1
  40. self._refcounts[block_id] = refcount
  41. return refcount
  42. def get(self, block_id: BlockId) -> RefCount:
  43. assert block_id in self._refcounts
  44. return self._refcounts[block_id]
  45. def as_readonly(self) -> "ReadOnlyRefCounter":
  46. return ReadOnlyRefCounter(self)
  47. class ReadOnlyRefCounter(RefCounterProtocol):
  48. """A read-only view of the RefCounter class.
  49. The ReadOnlyRefCounter class provides a read-only interface to access the
  50. reference counts maintained by a RefCounter instance. It does not allow
  51. modifications to the reference counts.
  52. Args:
  53. refcounter (RefCounter): The RefCounter instance to create a read-only
  54. view for.
  55. """
  56. def __init__(self, refcounter: RefCounter):
  57. self._refcounter = refcounter
  58. def incr(self, block_id: BlockId) -> RefCount:
  59. raise ValueError("Incr not allowed")
  60. def decr(self, block_id: BlockId) -> RefCount:
  61. raise ValueError("Decr not allowed")
  62. def get(self, block_id: BlockId) -> RefCount:
  63. return self._refcounter.get(block_id)
  64. class CopyOnWriteTracker:
  65. """A class for tracking and managing copy-on-write operations for blocks.
  66. The CopyOnWriteTracker class maintains a mapping of source block indices to
  67. their corresponding copy-on-write destination block indices. It works in
  68. conjunction with a RefCounter.
  69. Args:
  70. refcounter (RefCounter): The reference counter used to track block
  71. reference counts.
  72. """
  73. def __init__(self, refcounter: RefCounterProtocol):
  74. self._copy_on_writes: List[Tuple[BlockId, BlockId]] = []
  75. self._refcounter = refcounter
  76. def is_appendable(self, block: Block) -> bool:
  77. """Checks if the block is shared or not. If shared, then it cannot
  78. be appended and needs to be duplicated via copy-on-write
  79. """
  80. block_id = block.block_id
  81. if block_id is None:
  82. return True
  83. refcount = self._refcounter.get(block_id)
  84. return refcount <= 1
  85. def record_cow(self, src_block_id: Optional[BlockId],
  86. trg_block_id: Optional[BlockId]) -> None:
  87. """Records a copy-on-write operation from source to target block id
  88. Args:
  89. src_block_id (BlockId): The source block id from which to copy
  90. the data
  91. trg_block_id (BlockId): The target block id to which the data
  92. is copied
  93. """
  94. assert src_block_id is not None
  95. assert trg_block_id is not None
  96. self._copy_on_writes.append((src_block_id, trg_block_id))
  97. def clear_cows(self) -> List[Tuple[BlockId, BlockId]]:
  98. """Clears the copy-on-write tracking information and returns the current
  99. state.
  100. This method returns a list mapping source block indices to
  101. destination block indices for the current copy-on-write operations.
  102. It then clears the internal tracking information.
  103. Returns:
  104. List[Tuple[BlockId, BlockId]]: A list mapping source
  105. block indices to destination block indices for the
  106. current copy-on-write operations.
  107. """
  108. cows = self._copy_on_writes
  109. self._copy_on_writes = []
  110. return cows
  111. class BlockPool:
  112. """Used to pre-allocate block objects, in order to avoid excessive python
  113. object allocations/deallocations.
  114. The pool starts from "pool_size" objects and will increase to more objects
  115. if necessary
  116. Note that multiple block objects may point to the same physical block id,
  117. which is why this pool is needed, so that it will be easier to support
  118. prefix caching and more complicated sharing of physical blocks.
  119. """
  120. def __init__(self, block_size: int, create_block: Block.Factory,
  121. allocator: BlockAllocator, pool_size: int):
  122. self._block_size = block_size
  123. self._create_block = create_block
  124. self._allocator = allocator
  125. self._pool_size = pool_size
  126. assert self._pool_size >= 0
  127. self._free_ids: Deque[int] = deque(range(self._pool_size))
  128. self._pool = []
  129. for i in range(self._pool_size):
  130. self._pool.append(
  131. self._create_block(prev_block=None,
  132. token_ids=[],
  133. block_size=self._block_size,
  134. allocator=self._allocator,
  135. block_id=None))
  136. def increase_pool(self):
  137. """Doubles the internal pool size
  138. """
  139. cur_pool_size = self._pool_size
  140. new_pool_size = cur_pool_size * 2
  141. self._pool_size = new_pool_size
  142. self._free_ids += deque(range(cur_pool_size, new_pool_size))
  143. for i in range(cur_pool_size, new_pool_size):
  144. self._pool.append(
  145. self._create_block(prev_block=None,
  146. token_ids=[],
  147. block_size=self._block_size,
  148. allocator=self._allocator,
  149. block_id=None))
  150. def init_block(self, prev_block: Optional[Block], token_ids: List[int],
  151. block_size: int, physical_block_id: Optional[int]) -> Block:
  152. if len(self._free_ids) == 0:
  153. self.increase_pool()
  154. assert len(self._free_ids) > 0
  155. pool_id = self._free_ids.popleft()
  156. block = self._pool[pool_id]
  157. block.__init__( # type: ignore[misc]
  158. prev_block=prev_block,
  159. token_ids=token_ids,
  160. block_size=block_size,
  161. allocator=block._allocator, # type: ignore[attr-defined]
  162. block_id=physical_block_id)
  163. block.pool_id = pool_id # type: ignore[attr-defined]
  164. return block
  165. def free_block(self, block: Block) -> None:
  166. self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined]
  167. class BlockList:
  168. """This class is an optimization to allow fast-access to physical
  169. block ids. It maintains a block id list that is updated with the
  170. block list and this avoids the need to reconstruct the block id
  171. list on every iteration of the block manager
  172. """
  173. def __init__(self, blocks: List[Block]):
  174. self._blocks: List[Block] = []
  175. self._block_ids: List[int] = []
  176. self.update(blocks)
  177. def _add_block_id(self, block_id: Optional[BlockId]) -> None:
  178. assert block_id is not None
  179. self._block_ids.append(block_id)
  180. def _update_block_id(self, block_index: int,
  181. new_block_id: Optional[BlockId]) -> None:
  182. assert new_block_id is not None
  183. self._block_ids[block_index] = new_block_id
  184. def update(self, blocks: List[Block]):
  185. self._blocks = blocks
  186. # Cache block ids for fast query
  187. self._block_ids = []
  188. for block in self._blocks:
  189. self._add_block_id(block.block_id)
  190. def append_token_ids(self, block_index: int, token_ids: List[int]) -> None:
  191. block = self._blocks[block_index]
  192. prev_block_id = block.block_id
  193. block.append_token_ids(token_ids)
  194. # CoW or promotion may update the internal block_id
  195. if prev_block_id != block.block_id:
  196. self._update_block_id(block_index, block.block_id)
  197. def append(self, new_block: Block):
  198. self._blocks.append(new_block)
  199. self._add_block_id(new_block.block_id)
  200. def __len__(self) -> int:
  201. return len(self._blocks)
  202. def __getitem__(self, block_index: int) -> Block:
  203. return self._blocks[block_index]
  204. def __setitem__(self, block_index: int, new_block: Block) -> None:
  205. self._blocks[block_index] = new_block
  206. self._update_block_id(block_index, new_block.block_id)
  207. def reset(self):
  208. self._blocks = []
  209. self._block_ids = []
  210. def list(self) -> List[Block]:
  211. return self._blocks
  212. def ids(self) -> List[int]:
  213. return self._block_ids
  214. @dataclass
  215. class CacheMetricData:
  216. """A utility dataclass to maintain cache metric.
  217. To avoid overflow, we maintain the hit rate in block granularity, so that
  218. we can maintain a single hit rate for n_completed_block x block_size,
  219. and calculate the real time hit rate by the following:
  220. BS = The number of queries per block.
  221. nB = The number of completed blocks.
  222. HR = hit rate of (nB x BS) queries.
  223. Q = current number of queries (< BS).
  224. H = current number of hits (< BS).
  225. hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)
  226. """
  227. num_completed_blocks: int = 0
  228. completed_block_cache_hit_rate: float = 0.0
  229. num_incompleted_block_queries: int = 0
  230. num_incompleted_block_hit: int = 0
  231. block_size: int = 1000
  232. def query(self, hit: bool):
  233. self.num_incompleted_block_queries += 1
  234. self.num_incompleted_block_hit += 1 if hit else 0
  235. # When a block is completed, update the cache hit rate
  236. # and reset the incomplete numbers.
  237. if self.num_incompleted_block_queries == self.block_size:
  238. hit_rate = (self.num_incompleted_block_hit /
  239. self.num_incompleted_block_queries)
  240. self.completed_block_cache_hit_rate = (
  241. self.completed_block_cache_hit_rate * self.num_completed_blocks
  242. + hit_rate) / (self.num_completed_blocks + 1)
  243. self.num_incompleted_block_queries = 0
  244. self.num_incompleted_block_hit = 0
  245. self.num_completed_blocks += 1
  246. def get_hit_rate(self):
  247. incomplete_ratio = self.num_incompleted_block_queries / self.block_size
  248. total_blocks = self.num_completed_blocks + incomplete_ratio
  249. if total_blocks == 0:
  250. return 0.0
  251. completed_block_hit, incompleted_block_hit = 0.0, 0.0
  252. if self.num_completed_blocks > 0:
  253. completed_block_hit = (self.completed_block_cache_hit_rate *
  254. self.num_completed_blocks)
  255. if self.num_incompleted_block_queries > 0:
  256. incompleted_hit_rate = (self.num_incompleted_block_hit /
  257. self.num_incompleted_block_queries)
  258. incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio)
  259. return (completed_block_hit + incompleted_block_hit) / total_blocks
  260. def get_all_blocks_recursively(last_block: Block) -> List[Block]:
  261. """Retrieves all the blocks in a sequence starting from the last block.
  262. This function recursively traverses the sequence of blocks in reverse order,
  263. starting from the given last block, and returns a list of all the blocks in
  264. the sequence.
  265. Args:
  266. last_block (Block): The last block in the sequence.
  267. Returns:
  268. List[Block]: A list of all the blocks in the sequence, in the order they
  269. appear.
  270. """
  271. def recurse(block: Block, lst: List[Block]) -> None:
  272. if block.prev_block is not None:
  273. recurse(block.prev_block, lst)
  274. lst.append(block)
  275. all_blocks: List[Block] = []
  276. recurse(last_block, all_blocks)
  277. return all_blocks