1
0

prefix_caching_block.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. """Token blocks."""
from itertools import takewhile
from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional

from aphrodite.processing.block.common import (
    CopyOnWriteTracker,
    get_all_blocks_recursively,
)
from aphrodite.processing.block.interfaces import Block, BlockAllocator
from aphrodite.processing.block.naive_block import (
    NaiveBlock,
    NaiveBlockAllocator,
)
# Type aliases: the content-based hash of a full block's token prefix, and a
# physical block index as used by the underlying allocators.
PrefixHash = int
BlockId = int
  16. class PrefixCachingBlockAllocator(BlockAllocator):
  17. """A block allocator that implements prefix caching.
  18. The PrefixCachingBlockAllocator maintains a cache of blocks based on their
  19. content hash. It reuses blocks with the same content hash to avoid redundant
  20. memory allocation. The allocator also supports copy-on-write operations.
  21. Args:
  22. num_blocks (int): The total number of blocks to manage.
  23. block_size (int): The size of each block in tokens.
  24. block_ids(Optional[Iterable[int]], optional): An optional iterable of
  25. block IDs. If not provided, block IDs will be assigned sequentially
  26. from 0 to num_blocks - 1.
  27. """
  28. # TODO last access time / evictor integration
  29. def __init__(
  30. self,
  31. num_blocks: int,
  32. block_size: int,
  33. block_ids: Optional[Iterable[int]] = None,
  34. ):
  35. # A mapping of prefix hash to block index. All blocks which have a
  36. # prefix hash will be in this dict, even if they have refcount 0.
  37. self._cached_blocks: Dict[PrefixHash, BlockId] = {}
  38. # A mapping of prefix hash to block index. All blocks which have a
  39. # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
  40. # of self._cached_blocks.
  41. self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}
  42. # An allocator for blocks that do not have prefix hashes.
  43. self._hashless_allocator = NaiveBlockAllocator(
  44. create_block=self._create_block,
  45. num_blocks=num_blocks,
  46. block_size=block_size,
  47. block_ids=block_ids,
  48. )
  49. self._block_size = block_size
  50. # We share the refcounter between allocators. This allows us to promote
  51. # blocks originally allocated in the hashless allocator to immutable
  52. # blocks.
  53. self._refcounter = self._hashless_allocator.refcounter
  54. self._cow_tracker = CopyOnWriteTracker(
  55. refcounter=self._refcounter.as_readonly(),
  56. allocator=self,
  57. )
  58. # Implements Block.Factory.
  59. def _create_block(
  60. self,
  61. prev_block: Optional[Block],
  62. token_ids: List[int],
  63. block_size: int,
  64. allocator: BlockAllocator,
  65. block_id: Optional[int] = None,
  66. ) -> Block:
  67. # Bind block to self.
  68. allocator = self
  69. return PrefixCachingBlock(
  70. prev_block=prev_block,
  71. token_ids=token_ids,
  72. block_size=block_size,
  73. block_id=block_id,
  74. prefix_caching_allocator=allocator,
  75. )
  76. def allocate_immutable(self, prev_block: Optional[Block],
  77. token_ids: List[int]) -> Block:
  78. """Allocates an immutable block with the given token IDs, reusing cached
  79. blocks if possible.
  80. Args:
  81. prev_block (Optional[Block]): The previous block in the sequence.
  82. token_ids (List[int]): The token IDs to be stored in the block.
  83. Returns:
  84. Block: The allocated immutable block.
  85. """
  86. assert_prefix_caching_block_or_none(prev_block)
  87. block = self._create_block(
  88. prev_block=prev_block,
  89. token_ids=token_ids,
  90. block_size=self._block_size,
  91. allocator=self,
  92. )
  93. assert block.content_hash is not None
  94. cached_block_id = self._cached_blocks.get(block.content_hash, None)
  95. if cached_block_id is not None:
  96. block.block_id = cached_block_id
  97. self._incr_refcount_cached_block(block.content_hash,
  98. block.block_id)
  99. return block
  100. block = self.allocate_mutable(prev_block)
  101. block.append_token_ids(token_ids)
  102. assert block.content_hash is not None
  103. # TODO computed bit
  104. return block
  105. def allocate_mutable(self, prev_block: Block) -> Block:
  106. """Allocates a mutable block. If there are no free blocks, this will
  107. evict unused cached blocks.
  108. Args:
  109. prev_block (Block): The previous block in the sequence.
  110. Returns:
  111. Block: The allocated mutable block.
  112. """
  113. assert_prefix_caching_block_or_none(prev_block)
  114. try:
  115. return self._hashless_allocator.allocate_mutable(
  116. prev_block=prev_block)
  117. except BlockAllocator.NoFreeBlocksError:
  118. # We must check the unused cached blocks before raising OOM.
  119. pass
  120. if self._unused_cached_blocks:
  121. # TODO policy for selecting block to remove
  122. content_hash_to_evict = next(iter(self._unused_cached_blocks))
  123. # Clear content hash mapping; the block will be overwritten.
  124. del self._cached_blocks[content_hash_to_evict]
  125. block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
  126. refcount = self._refcounter.incr(block_id)
  127. assert refcount == 1
  128. block = self._create_block(
  129. prev_block=prev_block,
  130. token_ids=[],
  131. block_size=self._block_size,
  132. allocator=self,
  133. block_id=block_id,
  134. )
  135. assert block.content_hash is None
  136. return block
  137. # No block available in hashless allocator, nor in unused cache blocks.
  138. raise BlockAllocator.NoFreeBlocksError()
  139. def _incr_refcount_cached_block(self, content_hash: int,
  140. block_id: BlockId) -> None:
  141. refcount = self._refcounter.incr(block_id)
  142. if refcount == 1:
  143. assert content_hash in self._unused_cached_blocks
  144. del self._unused_cached_blocks[content_hash]
  145. def free(self, block: Block) -> None:
  146. """Decrement the refcount of the block. If the decremented refcount is
  147. zero, store the block in the freelist.
  148. If the block has a content hash (meaning it is immutable), then we will
  149. keep the block around in case future allocations require it.
  150. """
  151. assert (block.block_id
  152. is not None), "freeing unallocated block is undefined"
  153. self._free_block_id_for_block(block.block_id, block)
  154. block.block_id = None
  155. def _free_block_id_for_block(self, block_id: BlockId,
  156. block: Block) -> None:
  157. assert isinstance(block, PrefixCachingBlock)
  158. if block.content_hash is None:
  159. return self._hashless_allocator.free(block)
  160. refcount = self._refcounter.decr(block_id)
  161. # If no longer used, add the block to the unused cached blocks.
  162. if refcount == 0:
  163. assert block.content_hash not in self._unused_cached_blocks
  164. assert block.content_hash in self._cached_blocks
  165. self._unused_cached_blocks[block.content_hash] = block_id
  166. def fork(self, last_block: Block) -> List[Block]:
  167. """Creates a new sequence of blocks that shares the same underlying
  168. memory as the original sequence.
  169. Args:
  170. last_block (Block): The last block in the original sequence.
  171. Returns:
  172. List[Block]: The new sequence of blocks that shares the same memory
  173. as the original sequence.
  174. """
  175. source_blocks = get_all_blocks_recursively(last_block)
  176. forked_blocks = []
  177. prev_block = None
  178. for block in source_blocks:
  179. refcount = self._refcounter.incr(block.block_id)
  180. assert refcount != 1, "can't fork free'd block"
  181. forked_blocks.append(
  182. self._create_block(
  183. prev_block=prev_block,
  184. token_ids=block.token_ids,
  185. block_id=block.block_id,
  186. block_size=self._block_size,
  187. allocator=self,
  188. ))
  189. prev_block = forked_blocks[-1]
  190. return forked_blocks
  191. def get_num_free_blocks(self) -> int:
  192. # The number of free blocks is the number of hashless free blocks
  193. # plus the number of hashful blocks that are unused.
  194. return self._hashless_allocator.get_num_free_blocks() + len(
  195. self._unused_cached_blocks)
  196. @property
  197. def all_block_ids(self) -> frozenset[int]:
  198. return self._hashless_allocator.all_block_ids
  199. def promote_to_immutable_block(self,
  200. block: "PrefixCachingBlock") -> BlockId:
  201. """Once a mutable block is full, it can be promoted to an immutable
  202. block. This means that its content can be referenced by future blocks
  203. having the same prefix.
  204. Note that if we already have a cached block with the same content, we
  205. will replace the newly-promoted block's mapping with the existing cached
  206. block.
  207. Args:
  208. block (PrefixCachingBlock): The mutable block to be promoted.
  209. Returns:
  210. BlockId: Either the original block index, or the block index of
  211. the previously cached block matching the same content.
  212. """
  213. assert block.content_hash is not None
  214. assert block.block_id is not None
  215. assert self._refcounter.get(block.block_id) > 0
  216. # If the content hash does not have a corresponding cached block,
  217. # set this block as the cached block.
  218. if block.content_hash not in self._cached_blocks:
  219. self._cached_blocks[block.content_hash] = block.block_id
  220. else:
  221. self._free_block_id_for_block(block.block_id, block)
  222. self._incr_refcount_cached_block(
  223. block.content_hash, self._cached_blocks[block.content_hash])
  224. return self._cached_blocks[block.content_hash]
  225. def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
  226. """Performs a copy-on-write operation on the given block if it is not
  227. appendable.
  228. Args:
  229. block (Block): The block to check for copy-on-write.
  230. Returns:
  231. Optional[BlockId]: The block index of the new block if a copy-on
  232. -write operation was performed, or the original block index if
  233. no copy-on-write was necessary.
  234. """
  235. return self._cow_tracker.cow_block_if_not_appendable(block)
  236. def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
  237. """Returns the copy-on-write source->destination mapping and clears it.
  238. Returns:
  239. Dict[BlockId, List[BlockId]]: A dictionary mapping source
  240. block indices to lists of destination block indices.
  241. """
  242. return self._cow_tracker.clear_cows()
  243. def mark_blocks_as_computed(self) -> None:
  244. """Mark blocks as computed, used in prefix caching."""
  245. # TODO Track computed blocks.
  246. pass
  247. def get_common_computed_block_ids(
  248. self, seq_block_ids: List[List[int]]) -> List[int]:
  249. """Return the block ids that are common for a given sequence group.
  250. Used in prefill (can skip prefill of some blocks).
  251. """
  252. # TODO: Track computed blocks.
  253. computed = lambda block_id: False
  254. # NOTE We exclude the last block to avoid the case where the entire
  255. # prompt is cached. This would cause erroneous behavior in model
  256. # runner.
  257. ids_list = [
  258. takewhile(lambda block_id: computed(block_id), seq[:-1])
  259. for seq in seq_block_ids
  260. ]
  261. return commonprefix([ids for ids in ids_list if ids != []])
  262. class PrefixCachingBlock(Block):
  263. """A block implementation that supports prefix caching.
  264. The PrefixCachingBlock class represents a block of token IDs with prefix
  265. caching capabilities. It wraps a NaiveBlock internally and provides
  266. additional functionality for content hashing and promoting immutable blocks
  267. with the prefix caching allocator.
  268. Args:
  269. prev_block (Optional[PrefixCachingBlock]): The previous block in the
  270. sequence.
  271. token_ids (List[int]): The initial token IDs to be stored in the block.
  272. block_size (int): The maximum number of token IDs that can be stored in
  273. the block.
  274. prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix
  275. caching block allocator associated with this block.
  276. block_id (Optional[int], optional): The physical block index
  277. of this block. Defaults to None.
  278. """
  279. def __init__(
  280. self,
  281. prev_block: Optional["PrefixCachingBlock"],
  282. token_ids: List[int],
  283. block_size: int,
  284. prefix_caching_allocator: PrefixCachingBlockAllocator,
  285. block_id: Optional[int] = None,
  286. ):
  287. assert_prefix_caching_block_or_none(prev_block)
  288. self._prev_block = prev_block
  289. self._cached_content_hash: Optional[int] = None
  290. self._prefix_caching_allocator = prefix_caching_allocator
  291. self._block = NaiveBlock(
  292. prev_block=prev_block,
  293. token_ids=token_ids,
  294. block_size=block_size,
  295. block_id=block_id,
  296. allocator=prefix_caching_allocator,
  297. _cow_target=self,
  298. )
  299. def append_token_ids(self, token_ids: List[int]) -> None:
  300. """Appends the given token IDs to the block and registers the block as
  301. immutable if the block becomes full.
  302. Internally, the naive block handles CoW.
  303. Args:
  304. token_ids (List[int]): The token IDs to be appended to the block.
  305. """
  306. assert token_ids
  307. # naive block handles CoW.
  308. self._block.append_token_ids(token_ids)
  309. # If the content hash is present, then the block can be made immutable.
  310. # Register ourselves with the allocator, potentially replacing the
  311. # physical block index.
  312. if self.content_hash is not None:
  313. self.block_id = (self._prefix_caching_allocator.
  314. promote_to_immutable_block(self))
  315. @property
  316. def block_id(self) -> Optional[int]:
  317. return self._block.block_id
  318. @block_id.setter
  319. def block_id(self, value) -> None:
  320. self._block.block_id = value
  321. @property
  322. def is_full(self) -> bool:
  323. return self._block.is_full
  324. @property
  325. def num_empty_slots(self) -> int:
  326. return self._block.num_empty_slots
  327. @property
  328. def block_size(self) -> int:
  329. return self._block.block_size
  330. @property
  331. def token_ids(self) -> List[int]:
  332. return self._block.token_ids
  333. @property
  334. def prev_block(self) -> Optional[Block]:
  335. return self._prev_block
  336. @property
  337. def content_hash(self) -> Optional[int]:
  338. """Return the content-based hash of the current block, or None if it is
  339. not yet defined.
  340. For the content-based hash to be defined, the current block must be
  341. full.
  342. """
  343. # If the hash is already computed, return it.
  344. if self._cached_content_hash is not None:
  345. return self._cached_content_hash
  346. # We cannot compute a hash for the current block because it is not full.
  347. if not self.is_full:
  348. return None
  349. is_first_block = self._prev_block is None
  350. prev_block_hash = (None if is_first_block else
  351. self._prev_block.content_hash)
  352. # Previous block exists but does not yet have a hash.
  353. # Return no hash in this case.
  354. if prev_block_hash is None and not is_first_block:
  355. return None
  356. self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
  357. is_first_block,
  358. prev_block_hash,
  359. cur_block_token_ids=self.token_ids)
  360. return self._cached_content_hash
  361. @staticmethod
  362. def hash_block_tokens(
  363. is_first_block: bool,
  364. prev_block_hash: Optional[int],
  365. cur_block_token_ids: List[int],
  366. ) -> int:
  367. """Computes a hash value corresponding to the contents of a block and
  368. the contents of the preceding block(s). The hash value is used for
  369. prefix caching.
  370. NOTE: Content-based hashing does not yet support LoRA.
  371. Parameters:
  372. - is_first_block (bool): A flag indicating if the block is the first in
  373. the sequence.
  374. - prev_block_hash (Optional[int]): The hash of the previous block. None
  375. if this is the first block.
  376. - cur_block_token_ids (List[int]): A list of token ids in the current
  377. block. The current block is assumed to be full.
  378. Returns:
  379. - int: The computed hash value for the block.
  380. """
  381. assert (prev_block_hash is None) == is_first_block
  382. return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
  383. def assert_prefix_caching_block_or_none(block: Optional[Block]):
  384. if block is None:
  385. return
  386. assert isinstance(block, PrefixCachingBlock)