# prefix_caching_block.py
  1. """Token blocks."""
  2. from itertools import takewhile
  3. from os.path import commonprefix
  4. from typing import Dict, Iterable, List, Optional
  5. from aphrodite.processing.block.common import (CopyOnWriteTracker,
  6. get_all_blocks_recursively)
  7. from aphrodite.processing.block.interfaces import Block, BlockAllocator
  8. from aphrodite.processing.block.naive_block import (NaiveBlock,
  9. NaiveBlockAllocator)
  10. PrefixHash = int
  11. BlockId = int
  12. class PrefixCachingBlockAllocator(BlockAllocator):
  13. """A block allocator that implements prefix caching.
  14. The PrefixCachingBlockAllocator maintains a cache of blocks based on their
  15. content hash. It reuses blocks with the same content hash to avoid redundant
  16. memory allocation. The allocator also supports copy-on-write operations.
  17. Args:
  18. num_blocks (int): The total number of blocks to manage.
  19. block_size (int): The size of each block in tokens.
  20. block_ids(Optional[Iterable[int]], optional): An optional iterable of
  21. block IDs. If not provided, block IDs will be assigned sequentially
  22. from 0 to num_blocks - 1.
  23. """
  24. # TODO last access time / evictor integration
  25. def __init__(
  26. self,
  27. num_blocks: int,
  28. block_size: int,
  29. block_ids: Optional[Iterable[int]] = None,
  30. ):
  31. # A mapping of prefix hash to block index. All blocks which have a
  32. # prefix hash will be in this dict, even if they have refcount 0.
  33. self._cached_blocks: Dict[PrefixHash, BlockId] = {}
  34. # A mapping of prefix hash to block index. All blocks which have a
  35. # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
  36. # of self._cached_blocks.
  37. self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}
  38. # An allocator for blocks that do not have prefix hashes.
  39. self._hashless_allocator = NaiveBlockAllocator(
  40. create_block=self._create_block,
  41. num_blocks=num_blocks,
  42. block_size=block_size,
  43. block_ids=block_ids,
  44. )
  45. self._block_size = block_size
  46. # We share the refcounter between allocators. This allows us to promote
  47. # blocks originally allocated in the hashless allocator to immutable
  48. # blocks.
  49. self._refcounter = self._hashless_allocator.refcounter
  50. self._cow_tracker = CopyOnWriteTracker(
  51. refcounter=self._refcounter.as_readonly(),
  52. allocator=self,
  53. )
  54. # Implements Block.Factory.
  55. def _create_block(
  56. self,
  57. prev_block: Optional[Block],
  58. token_ids: List[int],
  59. block_size: int,
  60. allocator: BlockAllocator,
  61. block_id: Optional[int] = None,
  62. ) -> Block:
  63. # Bind block to self.
  64. allocator = self
  65. return PrefixCachingBlock(
  66. prev_block=prev_block,
  67. token_ids=token_ids,
  68. block_size=block_size,
  69. block_id=block_id,
  70. prefix_caching_allocator=allocator,
  71. )
  72. def allocate_immutable(self, prev_block: Optional[Block],
  73. token_ids: List[int]) -> Block:
  74. """Allocates an immutable block with the given token IDs, reusing cached
  75. blocks if possible.
  76. Args:
  77. prev_block (Optional[Block]): The previous block in the sequence.
  78. token_ids (List[int]): The token IDs to be stored in the block.
  79. Returns:
  80. Block: The allocated immutable block.
  81. """
  82. assert_prefix_caching_block_or_none(prev_block)
  83. block = self._create_block(
  84. prev_block=prev_block,
  85. token_ids=token_ids,
  86. block_size=self._block_size,
  87. allocator=self,
  88. )
  89. assert block.content_hash is not None
  90. cached_block_id = self._cached_blocks.get(block.content_hash, None)
  91. if cached_block_id is not None:
  92. block.block_id = cached_block_id
  93. self._incr_refcount_cached_block(block.content_hash,
  94. block.block_id)
  95. return block
  96. block = self.allocate_mutable(prev_block)
  97. block.append_token_ids(token_ids)
  98. assert block.content_hash is not None
  99. # TODO computed bit
  100. return block
  101. def allocate_mutable(self, prev_block: Block) -> Block:
  102. """Allocates a mutable block. If there are no free blocks, this will
  103. evict unused cached blocks.
  104. Args:
  105. prev_block (Block): The previous block in the sequence.
  106. Returns:
  107. Block: The allocated mutable block.
  108. """
  109. assert_prefix_caching_block_or_none(prev_block)
  110. try:
  111. return self._hashless_allocator.allocate_mutable(
  112. prev_block=prev_block)
  113. except BlockAllocator.NoFreeBlocksError:
  114. # We must check the unused cached blocks before raising OOM.
  115. pass
  116. if self._unused_cached_blocks:
  117. # TODO policy for selecting block to remove
  118. content_hash_to_evict = next(iter(self._unused_cached_blocks))
  119. # Clear content hash mapping; the block will be overwritten.
  120. del self._cached_blocks[content_hash_to_evict]
  121. block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
  122. refcount = self._refcounter.incr(block_id)
  123. assert refcount == 1
  124. block = self._create_block(
  125. prev_block=prev_block,
  126. token_ids=[],
  127. block_size=self._block_size,
  128. allocator=self,
  129. block_id=block_id,
  130. )
  131. assert block.content_hash is None
  132. return block
  133. # No block available in hashless allocator, nor in unused cache blocks.
  134. raise BlockAllocator.NoFreeBlocksError()
  135. def _incr_refcount_cached_block(self, content_hash: int,
  136. block_id: BlockId) -> None:
  137. refcount = self._refcounter.incr(block_id)
  138. if refcount == 1:
  139. assert content_hash in self._unused_cached_blocks
  140. del self._unused_cached_blocks[content_hash]
  141. def free(self, block: Block) -> None:
  142. """Decrement the refcount of the block. If the decremented refcount is
  143. zero, store the block in the freelist.
  144. If the block has a content hash (meaning it is immutable), then we will
  145. keep the block around in case future allocations require it.
  146. """
  147. assert (block.block_id
  148. is not None), "freeing unallocated block is undefined"
  149. self._free_block_id_for_block(block.block_id, block)
  150. block.block_id = None
  151. def _free_block_id_for_block(self, block_id: BlockId,
  152. block: Block) -> None:
  153. assert isinstance(block, PrefixCachingBlock)
  154. if block.content_hash is None:
  155. return self._hashless_allocator.free(block)
  156. refcount = self._refcounter.decr(block_id)
  157. # If no longer used, add the block to the unused cached blocks.
  158. if refcount == 0:
  159. assert block.content_hash not in self._unused_cached_blocks
  160. assert block.content_hash in self._cached_blocks
  161. self._unused_cached_blocks[block.content_hash] = block_id
  162. def fork(self, last_block: Block) -> List[Block]:
  163. """Creates a new sequence of blocks that shares the same underlying
  164. memory as the original sequence.
  165. Args:
  166. last_block (Block): The last block in the original sequence.
  167. Returns:
  168. List[Block]: The new sequence of blocks that shares the same memory
  169. as the original sequence.
  170. """
  171. source_blocks = get_all_blocks_recursively(last_block)
  172. forked_blocks = []
  173. prev_block = None
  174. for block in source_blocks:
  175. refcount = self._refcounter.incr(block.block_id)
  176. assert refcount != 1, "can't fork free'd block"
  177. forked_blocks.append(
  178. self._create_block(
  179. prev_block=prev_block,
  180. token_ids=block.token_ids,
  181. block_id=block.block_id,
  182. block_size=self._block_size,
  183. allocator=self,
  184. ))
  185. prev_block = forked_blocks[-1]
  186. return forked_blocks
  187. def get_num_free_blocks(self) -> int:
  188. # The number of free blocks is the number of hashless free blocks
  189. # plus the number of hashful blocks that are unused.
  190. return self._hashless_allocator.get_num_free_blocks() + len(
  191. self._unused_cached_blocks)
  192. @property
  193. def all_block_ids(self) -> frozenset[int]:
  194. return self._hashless_allocator.all_block_ids
  195. def promote_to_immutable_block(self,
  196. block: "PrefixCachingBlock") -> BlockId:
  197. """Once a mutable block is full, it can be promoted to an immutable
  198. block. This means that its content can be referenced by future blocks
  199. having the same prefix.
  200. Note that if we already have a cached block with the same content, we
  201. will replace the newly-promoted block's mapping with the existing cached
  202. block.
  203. Args:
  204. block (PrefixCachingBlock): The mutable block to be promoted.
  205. Returns:
  206. BlockId: Either the original block index, or the block index of
  207. the previously cached block matching the same content.
  208. """
  209. assert block.content_hash is not None
  210. assert block.block_id is not None
  211. assert self._refcounter.get(block.block_id) > 0
  212. # If the content hash does not have a corresponding cached block,
  213. # set this block as the cached block.
  214. if block.content_hash not in self._cached_blocks:
  215. self._cached_blocks[block.content_hash] = block.block_id
  216. else:
  217. self._free_block_id_for_block(block.block_id, block)
  218. self._incr_refcount_cached_block(
  219. block.content_hash, self._cached_blocks[block.content_hash])
  220. return self._cached_blocks[block.content_hash]
  221. def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
  222. """Performs a copy-on-write operation on the given block if it is not
  223. appendable.
  224. Args:
  225. block (Block): The block to check for copy-on-write.
  226. Returns:
  227. Optional[BlockId]: The block index of the new block if a copy-on
  228. -write operation was performed, or the original block index if
  229. no copy-on-write was necessary.
  230. """
  231. return self._cow_tracker.cow_block_if_not_appendable(block)
  232. def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
  233. """Returns the copy-on-write source->destination mapping and clears it.
  234. Returns:
  235. Dict[BlockId, List[BlockId]]: A dictionary mapping source
  236. block indices to lists of destination block indices.
  237. """
  238. return self._cow_tracker.clear_cows()
  239. def mark_blocks_as_computed(self) -> None:
  240. """Mark blocks as computed, used in prefix caching."""
  241. # TODO Track computed blocks.
  242. pass
  243. def get_common_computed_block_ids(
  244. self, seq_block_ids: List[List[int]]) -> List[int]:
  245. """Return the block ids that are common for a given sequence group.
  246. Used in prefill (can skip prefill of some blocks).
  247. """
  248. # TODO: Track computed blocks.
  249. computed = lambda block_id: False
  250. # NOTE We exclude the last block to avoid the case where the entire
  251. # prompt is cached. This would cause erroneous behavior in model
  252. # runner.
  253. ids_list = [
  254. takewhile(lambda block_id: computed(block_id), seq[:-1])
  255. for seq in seq_block_ids
  256. ]
  257. return commonprefix([ids for ids in ids_list if ids != []])
  258. class PrefixCachingBlock(Block):
  259. """A block implementation that supports prefix caching.
  260. The PrefixCachingBlock class represents a block of token IDs with prefix
  261. caching capabilities. It wraps a NaiveBlock internally and provides
  262. additional functionality for content hashing and promoting immutable blocks
  263. with the prefix caching allocator.
  264. Args:
  265. prev_block (Optional[PrefixCachingBlock]): The previous block in the
  266. sequence.
  267. token_ids (List[int]): The initial token IDs to be stored in the block.
  268. block_size (int): The maximum number of token IDs that can be stored in
  269. the block.
  270. prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix
  271. caching block allocator associated with this block.
  272. block_id (Optional[int], optional): The physical block index
  273. of this block. Defaults to None.
  274. """
  275. def __init__(
  276. self,
  277. prev_block: Optional["PrefixCachingBlock"],
  278. token_ids: List[int],
  279. block_size: int,
  280. prefix_caching_allocator: PrefixCachingBlockAllocator,
  281. block_id: Optional[int] = None,
  282. ):
  283. assert_prefix_caching_block_or_none(prev_block)
  284. self._prev_block = prev_block
  285. self._cached_content_hash: Optional[int] = None
  286. self._prefix_caching_allocator = prefix_caching_allocator
  287. self._block = NaiveBlock(
  288. prev_block=prev_block,
  289. token_ids=token_ids,
  290. block_size=block_size,
  291. block_id=block_id,
  292. allocator=prefix_caching_allocator,
  293. _cow_target=self,
  294. )
  295. def append_token_ids(self, token_ids: List[int]) -> None:
  296. """Appends the given token IDs to the block and registers the block as
  297. immutable if the block becomes full.
  298. Internally, the naive block handles CoW.
  299. Args:
  300. token_ids (List[int]): The token IDs to be appended to the block.
  301. """
  302. assert token_ids
  303. # naive block handles CoW.
  304. self._block.append_token_ids(token_ids)
  305. # If the content hash is present, then the block can be made immutable.
  306. # Register ourselves with the allocator, potentially replacing the
  307. # physical block index.
  308. if self.content_hash is not None:
  309. self.block_id = (self._prefix_caching_allocator.
  310. promote_to_immutable_block(self))
  311. @property
  312. def block_id(self) -> Optional[int]:
  313. return self._block.block_id
  314. @block_id.setter
  315. def block_id(self, value) -> None:
  316. self._block.block_id = value
  317. @property
  318. def is_full(self) -> bool:
  319. return self._block.is_full
  320. @property
  321. def num_empty_slots(self) -> int:
  322. return self._block.num_empty_slots
  323. @property
  324. def block_size(self) -> int:
  325. return self._block.block_size
  326. @property
  327. def token_ids(self) -> List[int]:
  328. return self._block.token_ids
  329. @property
  330. def prev_block(self) -> Optional[Block]:
  331. return self._prev_block
  332. @property
  333. def content_hash(self) -> Optional[int]:
  334. """Return the content-based hash of the current block, or None if it is
  335. not yet defined.
  336. For the content-based hash to be defined, the current block must be
  337. full.
  338. """
  339. # If the hash is already computed, return it.
  340. if self._cached_content_hash is not None:
  341. return self._cached_content_hash
  342. # We cannot compute a hash for the current block because it is not full.
  343. if not self.is_full:
  344. return None
  345. is_first_block = self._prev_block is None
  346. prev_block_hash = (None if is_first_block else
  347. self._prev_block.content_hash)
  348. # Previous block exists but does not yet have a hash.
  349. # Return no hash in this case.
  350. if prev_block_hash is None and not is_first_block:
  351. return None
  352. self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
  353. is_first_block,
  354. prev_block_hash,
  355. cur_block_token_ids=self.token_ids)
  356. return self._cached_content_hash
  357. @staticmethod
  358. def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
  359. cur_block_token_ids: List[int]) -> int:
  360. """Computes a hash value corresponding to the contents of a block and
  361. the contents of the preceding block(s). The hash value is used for
  362. prefix caching.
  363. NOTE: Content-based hashing does not yet support LoRA.
  364. Parameters:
  365. - is_first_block (bool): A flag indicating if the block is the first in
  366. the sequence.
  367. - prev_block_hash (Optional[int]): The hash of the previous block. None
  368. if this is the first block.
  369. - cur_block_token_ids (List[int]): A list of token ids in the current
  370. block. The current block is assumed to be full.
  371. Returns:
  372. - int: The computed hash value for the block.
  373. """
  374. assert (prev_block_hash is None) == is_first_block
  375. return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
  376. def assert_prefix_caching_block_or_none(block: Optional[Block]):
  377. if block is None:
  378. return
  379. assert isinstance(block, PrefixCachingBlock)