# prefix_caching_block.py
  1. """Token blocks."""
  2. from os.path import commonprefix
  3. from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
  4. from aphrodite.common.utils import cdiv
  5. from aphrodite.processing.block.common import (CacheMetricData,
  6. CopyOnWriteTracker,
  7. get_all_blocks_recursively)
  8. from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
  9. BlockId, Device)
  10. from aphrodite.processing.block.naive_block import (BlockPool, NaiveBlock,
  11. NaiveBlockAllocator)
  12. from aphrodite.processing.evictor_v2 import (EvictionPolicy, Evictor,
  13. make_evictor)

PrefixHash = int

# By default, a block's access time is initialized to
# _DEFAULT_LAST_ACCESSED_TIME, so if we find that a block still holds
# _DEFAULT_LAST_ACCESSED_TIME, we know it has not been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1


class BlockTracker:
    """Used to track the status of a block inside the prefix caching allocator
    """
    __slots__ = ("active", "last_accessed", "computed")

    def reset(self):
        self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
        self.computed: bool = False

    def __init__(self):
        self.active: bool = False
        self.reset()

    def enable(self):
        assert not self.active
        self.active = True
        self.reset()

    def disable(self):
        assert self.active
        self.active = False
        self.reset()


class PrefixCachingBlockAllocator(BlockAllocator):
    """A block allocator that implements prefix caching.

    The PrefixCachingBlockAllocator maintains a cache of blocks based on their
    content hash. It reuses blocks with the same content hash to avoid
    redundant memory allocation. The allocator also supports copy-on-write
    operations.

    Args:
        num_blocks (int): The total number of blocks to manage.
        block_size (int): The size of each block in tokens.
        block_ids (Optional[Iterable[int]], optional): An optional iterable of
            block IDs. If not provided, block IDs will be assigned sequentially
            from 0 to num_blocks - 1.
    """

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        block_ids: Optional[Iterable[int]] = None,
        eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
    ):
        if block_ids is None:
            block_ids = range(num_blocks)

        self._block_size = block_size

        # A mapping of prefix hash to block index. All blocks which have a
        # prefix hash will be in this dict, even if they have refcount 0.
        self._cached_blocks: Dict[PrefixHash, BlockId] = {}

        # A list of immutable block IDs that have been touched by the scheduler
        # and should be marked as computed after an entire batch of sequences
        # is scheduled.
        self._touched_blocks: Set[BlockId] = set()

        # Used to track the status of each physical block id.
        self._block_tracker: Dict[BlockId, BlockTracker] = {}
        for block_id in block_ids:
            self._block_tracker[block_id] = BlockTracker()

        # Pre-allocate "num_blocks * extra_factor" block objects.
        # The "* extra_factor" is a buffer to allow more block objects
        # than physical blocks.
        extra_factor = 4
        self._block_pool = BlockPool(self._block_size, self._create_block,
                                     self, num_blocks * extra_factor)

        # An allocator for blocks that do not have prefix hashes.
        self._hashless_allocator = NaiveBlockAllocator(
            create_block=self._create_block,  # type: ignore
            num_blocks=num_blocks,
            block_size=block_size,
            block_ids=block_ids,
            block_pool=self._block_pool,  # Share block pool here
        )

        # Evictor used to decide how computed blocks are reclaimed when
        # memory pressure is high.
        self.evictor: Evictor = make_evictor(eviction_policy)

        # We share the refcounter between allocators. This allows us to promote
        # blocks originally allocated in the hashless allocator to immutable
        # blocks.
        self._refcounter = self._hashless_allocator.refcounter

        self._cow_tracker = CopyOnWriteTracker(
            refcounter=self._refcounter.as_readonly())

        self.metric_data = CacheMetricData()

    # Implements Block.Factory.
    def _create_block(
        self,
        prev_block: Optional[Block],
        token_ids: List[int],
        block_size: int,
        allocator: BlockAllocator,
        block_id: Optional[int] = None,
        computed: bool = False,
    ) -> Block:
        # Bind block to self.
        allocator = self

        return PrefixCachingBlock(
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            block_id=block_id,
            allocator=allocator,
            computed=computed,
        )

    def allocate_immutable_block(self,
                                 prev_block: Optional[Block],
                                 token_ids: List[int],
                                 device: Optional[Device] = None) -> Block:
        """Allocates an immutable block with the given token IDs, reusing
        cached blocks if possible.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
            token_ids (List[int]): The token IDs to be stored in the block.

        Returns:
            Block: The allocated immutable block.
        """
        assert device is None
        assert_prefix_caching_block_or_none(prev_block)

        # First, try to create a block that points to cached data.
        block = self._block_pool.init_block(prev_block=prev_block,
                                            token_ids=token_ids,
                                            block_size=self._block_size,
                                            physical_block_id=None)
        assert block.content_hash is not None

        cached_block_id = self._cached_blocks.get(block.content_hash, None)
        if cached_block_id is not None:
            self.metric_data.query(hit=True)
            block.block_id = cached_block_id
            self._incr_refcount_cached_block(block)
            return block
        self.metric_data.query(hit=False)
        self._block_pool.free_block(block)

        # No cached block => Allocate a new block
        block = self.allocate_mutable_block(prev_block)
        block.append_token_ids(token_ids)
        return block
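
    # Illustrative usage sketch (a hypothetical caller, not part of this
    # module): with block_size=4, allocating the same full block twice makes
    # the second call a prefix-cache hit, so both Block objects end up
    # pointing at the same physical block id.
    #
    #   allocator = PrefixCachingBlockAllocator(num_blocks=8, block_size=4)
    #   a = allocator.allocate_immutable_block(prev_block=None,
    #                                          token_ids=[1, 2, 3, 4])
    #   b = allocator.allocate_immutable_block(prev_block=None,
    #                                          token_ids=[1, 2, 3, 4])
    #   assert a.block_id == b.block_id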

    def allocate_immutable_blocks(
            self,
            prev_block: Optional[Block],
            block_token_ids: List[List[int]],
            device: Optional[Device] = None) -> List[Block]:
        blocks = []
        for token_ids in block_token_ids:
            prev_block = self.allocate_immutable_block(prev_block=prev_block,
                                                       token_ids=token_ids,
                                                       device=device)
            blocks.append(prev_block)
        return blocks

    def allocate_mutable_block(self,
                               prev_block: Optional[Block],
                               device: Optional[Device] = None) -> Block:
        """Allocates a mutable block. If there are no free blocks, this will
        evict unused cached blocks.

        Args:
            prev_block (Block): The previous block in the sequence.
                Unlike in the superclass, None is not allowed.

        Returns:
            Block: The allocated mutable block.
        """
        assert device is None
        assert_prefix_caching_block_or_none(prev_block)

        block_id = self._allocate_block_id()
        block = self._block_pool.init_block(prev_block=prev_block,
                                            token_ids=[],
                                            block_size=self._block_size,
                                            physical_block_id=block_id)
        assert not block.computed
        assert block.content_hash is None
        return block

    def _incr_refcount_cached_block(self, block: Block) -> None:
        # Set this block to be "computed", since it is pointing to a
        # cached block id (which was already computed).
        block.computed = True

        block_id = block.block_id
        assert block_id is not None

        refcount = self._refcounter.incr(block_id)
        if refcount == 1:
            # In case a cached block was evicted, restore its tracking.
            if block_id in self.evictor:
                self.evictor.remove(block_id)
            self._track_block_id(block_id, computed=True)

    def _decr_refcount_cached_block(self, block: Block) -> None:
        # Ensure this is an immutable/cached block.
        assert block.content_hash is not None
        block_id = block.block_id
        assert block_id is not None

        refcount = self._refcounter.decr(block_id)
        if refcount > 0:
            block.block_id = None
            return
        else:
            assert refcount == 0

        # No longer used
        assert block.content_hash in self._cached_blocks

        # Add the cached block to the evictor
        # (this keeps the cached block around so it can be reused).
        self.evictor.add(block_id, block.content_hash, block.num_tokens_total,
                         self._block_tracker[block_id].last_accessed)

        # Stop tracking the block
        self._untrack_block_id(block_id)

        block.block_id = None

    def _decr_refcount_hashless_block(self, block: Block) -> None:
        block_id = block.block_id
        assert block_id is not None

        # We may have a fork case where the block is shared,
        # in which case we cannot remove it from tracking.
        refcount = self._refcounter.get(block_id)
        if refcount == 1:
            self._untrack_block_id(block_id)

        # Decrement refcount of the block_id, but do not free the block object
        # itself (that is handled by the caller).
        self._hashless_allocator.free(block, keep_block_object=True)

    def _allocate_block_id(self) -> BlockId:
        """First tries to allocate a block id from the hashless allocator,
        and if there are no blocks, then tries to evict an unused cached block.
        """
        hashless_block_id = self._maybe_allocate_hashless_block_id()
        if hashless_block_id is not None:
            return hashless_block_id

        evicted_block_id = self._maybe_allocate_evicted_block_id()
        if evicted_block_id is not None:
            return evicted_block_id

        # No block available in hashless allocator, nor in unused cache blocks.
        raise BlockAllocator.NoFreeBlocksError()

    def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]:
        try:
            # Allocate mutable block and extract its block_id
            block = self._hashless_allocator.allocate_mutable_block(
                prev_block=None)
            block_id = block.block_id
            self._block_pool.free_block(block)

            self._track_block_id(block_id, computed=False)
            return block_id
        except BlockAllocator.NoFreeBlocksError:
            return None

    def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]:
        if self.evictor.num_blocks == 0:
            return None

        # Here we get an evicted block, which is only added to the evictor
        # once its ref counter reaches 0. Since its content is about to
        # change, we also remove it from _cached_blocks' tracking.
        block_id, content_hash_to_evict = self.evictor.evict()

        # Sanity checks
        assert content_hash_to_evict in self._cached_blocks
        _block_id = self._cached_blocks[content_hash_to_evict]
        assert self._refcounter.get(_block_id) == 0
        assert _block_id == block_id

        self._cached_blocks.pop(content_hash_to_evict)

        self._refcounter.incr(block_id)
        self._track_block_id(block_id, computed=False)

        return block_id

    def _free_block_id(self, block: Block) -> None:
        """Decrements the refcount of the block. The block may be in two
        possible states: (1) immutable/cached or (2) mutable/hashless.

        In the first case, the refcount is decremented directly and the block
        may possibly be added to the evictor. In the other case, the hashless
        allocator's free(..) is called with keep_block_object=True to free only
        the block id (since the block object may be reused by the caller).
        """
        block_id = block.block_id
        assert block_id is not None, "Freeing unallocated block is undefined"

        if block.content_hash is not None:
            # Immutable: This type of block is always cached, and we want to
            # keep it in the evictor for future reuse.
            self._decr_refcount_cached_block(block)
        else:
            # Mutable: This type of block is not cached, so we release it
            # directly to the hashless allocator.
            self._decr_refcount_hashless_block(block)

        assert block.block_id is None

    def free(self, block: Block, keep_block_object: bool = False) -> None:
        """Release the block (see the _free_block_id(..) docs).
        """
        # Release the physical block index
        self._free_block_id(block)

        # Release the block object to the pool
        if not keep_block_object:
            self._block_pool.free_block(block)

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: The new sequence of blocks that shares the same memory
                as the original sequence.
        """
        source_blocks = get_all_blocks_recursively(last_block)

        forked_blocks: List[Block] = []
        prev_block = None
        for block in source_blocks:
            block_id = block.block_id
            assert block_id is not None

            refcount = self._refcounter.incr(block_id)
            assert refcount != 1, "can't fork free'd block_id = {}".format(
                block_id)

            forked_block = self._block_pool.init_block(
                prev_block=prev_block,
                token_ids=block.token_ids,
                block_size=self._block_size,
                physical_block_id=block_id)

            forked_blocks.append(forked_block)
            prev_block = forked_blocks[-1]

        return forked_blocks
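
    # Sketch of fork semantics (illustrative): no token data is copied; the
    # forked Block objects point at the same physical block ids and the
    # refcounts are incremented. A later append to a shared mutable block
    # goes through cow_block_if_not_appendable(), which allocates a new
    # block id and records the copy.
    #
    #   original = allocator.allocate_mutable_block(prev_block=None)
    #   original.append_token_ids([1, 2])
    #   forked = allocator.fork(original)
    #   assert forked[-1].block_id == original.block_id  # shared, refcount 2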

    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
        assert device is None
        # The number of free blocks is the number of hashless free blocks
        # plus the number of blocks the evictor could free from its list.
        return (self._hashless_allocator.get_num_free_blocks() +
                self.evictor.num_blocks)

    def get_num_total_blocks(self) -> int:
        return self._hashless_allocator.get_num_total_blocks()

    def get_physical_block_id(self, absolute_id: int) -> int:
        """Returns the zero-offset block id on a certain block allocator
        given the absolute block id.

        Args:
            absolute_id (int): The absolute block id for the block
                in the whole allocator.

        Returns:
            int: The zero-offset block id on a certain device.
        """
        return sorted(self.all_block_ids).index(absolute_id)

    @property
    def all_block_ids(self) -> FrozenSet[int]:
        return self._hashless_allocator.all_block_ids

    def get_prefix_cache_hit_rate(self) -> float:
        return self.metric_data.get_hit_rate()

    def is_block_cached(self, block: Block) -> bool:
        assert block.content_hash is not None
        if block.content_hash in self._cached_blocks:
            return True
        return False

    def promote_to_immutable_block(self, block: Block) -> BlockId:
        """Once a mutable block is full, it can be promoted to an immutable
        block. This means that its content can be referenced by future blocks
        having the same prefix.

        Note that if we already have a cached block with the same content, we
        will replace the newly-promoted block's mapping with the existing
        cached block id.

        Args:
            block: The mutable block to be promoted.

        Returns:
            BlockId: Either the original block index, or the block index of
                the previously cached block matching the same content.
        """
        # Ensure block can be promoted
        assert block.content_hash is not None
        assert block.block_id is not None
        assert self._refcounter.get(block.block_id) > 0

        if block.content_hash not in self._cached_blocks:
            # No cached content hash => Set this block as cached.
            # Note that this block cannot be marked as computed yet
            # because other sequences in the same batch cannot reuse
            # this block.
            self._cached_blocks[block.content_hash] = block.block_id
            # Mark this block as touched so that it can be marked as
            # computed after the entire batch of sequences are scheduled.
            self._touched_blocks.add(block.block_id)
            return block.block_id

        # Reuse the cached content hash
        self._decr_refcount_hashless_block(block)
        block.block_id = self._cached_blocks[block.content_hash]

        # Increment refcount of the cached block and (possibly) restore
        # it from the evictor.
        # Note that in this case, the block is marked as computed.
        self._incr_refcount_cached_block(block)

        return block.block_id
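
    # Sketch of the promotion path (illustrative), assuming block_size == 4:
    # a mutable block promotes itself from append_token_ids() as soon as it
    # becomes full, because its content_hash becomes defined at that point.
    #
    #   blk = allocator.allocate_mutable_block(prev_block=None)
    #   blk.append_token_ids([1, 2, 3, 4])  # now full
    #   # append_token_ids() called promote_to_immutable_block() internally,
    #   # so blk.content_hash is set and blk.block_id may now point at a
    #   # previously cached block with the same content.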

    def cow_block_if_not_appendable(self, block: Block) -> BlockId:
        """Performs a copy-on-write operation on the given block if it is not
        appendable.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            BlockId: The block index of the new block if a copy-on-write
                operation was performed, or the original block index if
                no copy-on-write was necessary.
        """
        src_block_id = block.block_id
        assert src_block_id is not None

        if self._cow_tracker.is_appendable(block):
            return src_block_id

        self._free_block_id(block)
        trg_block_id = self._allocate_block_id()

        self._cow_tracker.record_cow(src_block_id, trg_block_id)

        return trg_block_id

    def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
        """Returns the copy-on-write source->destination mapping and clears it.

        Returns:
            List[Tuple[BlockId, BlockId]]: A list mapping source
                block indices to destination block indices.
        """
        return self._cow_tracker.clear_cows()
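
    # Illustrative sketch of how a scheduler-side caller might drain the CoW
    # mapping after appends have been applied (the copy mechanism itself is
    # outside this module):
    #
    #   for src_block_id, dst_block_id in allocator.clear_copy_on_writes():
    #       ...  # issue a copy of the block data from src to dst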

    def mark_blocks_as_accessed(self, block_ids: List[int],
                                now: float) -> None:
        """Mark blocks as accessed, used in prefix caching.

        If a block has been added to the evictor, we need to update the
        corresponding info in the evictor's metadata.
        """
        for block_id in block_ids:
            if self._block_tracker[block_id].active:
                self._block_tracker[block_id].last_accessed = now
            elif block_id in self.evictor:
                self.evictor.update(block_id, now)
            else:
                raise ValueError(
                    "Mark block as accessed which does not belong to GPU")

    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        # Mark all touched blocks as computed.
        for block_id in self._touched_blocks:
            self._block_tracker[block_id].computed = True
        self._touched_blocks.clear()

    def _track_block_id(self, block_id: Optional[BlockId],
                        computed: bool) -> None:
        assert block_id is not None
        self._block_tracker[block_id].enable()
        self._block_tracker[block_id].computed = computed

    def _untrack_block_id(self, block_id: Optional[BlockId]) -> None:
        assert block_id is not None
        self._block_tracker[block_id].disable()

    def block_is_computed(self, block_id: int) -> bool:
        if self._block_tracker[block_id].active:
            return self._block_tracker[block_id].computed
        else:
            return block_id in self.evictor

    def get_computed_block_ids(self,
                               prev_computed_block_ids: List[int],
                               block_ids: List[int],
                               skip_last_block_id: bool = True) -> List[int]:
        prev_prefix_size = len(prev_computed_block_ids)
        cur_size = len(block_ids)
        if skip_last_block_id:
            cur_size -= 1

        # Sanity checks
        assert cur_size >= 0
        assert prev_prefix_size <= cur_size

        ret = prev_computed_block_ids
        for i in range(prev_prefix_size, cur_size):
            block_id = block_ids[i]
            if self.block_is_computed(block_id):
                ret.append(block_id)
        return ret

    def get_common_computed_block_ids(
            self, computed_seq_block_ids: List[List[int]]) -> List[int]:
        """Return the block ids that are common for a given sequence group.

        Only blocks that are immutable and have already been marked as
        computed are taken into consideration.
        """
        # NOTE: We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in the model
        # runner.

        # commonprefix returns a list of int here, although its type
        # annotation says list of string.
        if len(computed_seq_block_ids) == 1:
            return computed_seq_block_ids[0]

        return commonprefix([
            ids for ids in computed_seq_block_ids  # type: ignore
            if ids
        ])
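
    # Illustrative example: for two sequences whose computed block ids are
    # [0, 1, 2] and [0, 1, 5], the common computed prefix is [0, 1]:
    #
    #   allocator.get_common_computed_block_ids([[0, 1, 2], [0, 1, 5]])
    #   # -> [0, 1]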

    def get_num_blocks_touched(self,
                               blocks: List[Block],
                               num_lookahead_slots: int = 0) -> int:
        """Determine the number of blocks that will be touched by
        swapping in/out the given blocks from a certain sequence
        group with the provided num_lookahead_slots.

        Args:
            blocks (List[Block]): The potential blocks to swap.
            num_lookahead_slots (int): number of lookahead slots (0 for
                swap out).

        Returns:
            int: the number of blocks that will be touched by
                swapping in/out the given blocks and num_lookahead_slots.
        """
        num_touched_blocks = 0
        for block in blocks:
            if not block.is_full:
                num_touched_blocks += 1
                if num_lookahead_slots > block.num_empty_slots:
                    num_touched_blocks += cdiv(
                        num_lookahead_slots - block.num_empty_slots,
                        self._block_size)
            else:
                # If the block has a match in the cache and the cached block
                # is not referenced, then we still count it as a touched block
                if (not self.is_block_cached(block)
                        or (block.content_hash is not None
                            and self._cached_blocks[block.content_hash]
                            in self.evictor)):
                    num_touched_blocks += 1
        return num_touched_blocks
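
    # Worked example (illustrative), assuming block_size == 4: a partially
    # filled block with one empty slot and num_lookahead_slots == 6 counts
    # as 1 (the block itself) + cdiv(6 - 1, 4) == 2 extra blocks, so it
    # contributes 3 touched blocks in total.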

    def swap_out(self, blocks: List[Block]) -> None:
        """Execute the swap out actions. Basically just free the
        given blocks.

        Args:
            blocks: List of blocks to be swapped out.
        """
        for block in blocks:
            self._free_block_id(block)

    def swap_in(self, blocks: List[Block]) -> None:
        """Execute the swap in actions. Change the block id from the
        old allocator to the current allocator for each block to finish
        the block table update.

        Args:
            blocks: List of blocks to be swapped in.
        """
        for block in blocks:
            # Here we allocate either an immutable or a mutable block and then
            # extract its block_id. Note that the block object is released
            # and the block_id is assigned to "block" to allow reusing the
            # existing "block" object.
            if block.is_full:
                tmp_block = self.allocate_immutable_block(
                    prev_block=block.prev_block, token_ids=block.token_ids)
            else:
                tmp_block = self.allocate_mutable_block(
                    prev_block=block.prev_block)
                tmp_block.append_token_ids(block.token_ids)

            block_id = tmp_block.block_id
            self._block_pool.free_block(tmp_block)

            block.block_id = block_id  # Assign block_id


class PrefixCachingBlock(Block):
    """A block implementation that supports prefix caching.

    The PrefixCachingBlock class represents a block of token IDs with prefix
    caching capabilities. It wraps a NaiveBlock internally and provides
    additional functionality for content hashing and promoting immutable
    blocks with the prefix caching allocator.

    Args:
        prev_block (Optional[PrefixCachingBlock]): The previous block in the
            sequence.
        token_ids (List[int]): The initial token IDs to be stored in the block.
        block_size (int): The maximum number of token IDs that can be stored in
            the block.
        allocator (BlockAllocator): The prefix caching block allocator
            associated with this block.
        block_id (Optional[int], optional): The physical block index
            of this block. Defaults to None.
    """

    def __init__(
        self,
        prev_block: Optional[Block],
        token_ids: List[int],
        block_size: int,
        allocator: BlockAllocator,
        block_id: Optional[int] = None,
        computed: bool = False,
    ):
        assert isinstance(allocator, PrefixCachingBlockAllocator), (
            "Currently this class is only tested with "
            "PrefixCachingBlockAllocator. Got instead allocator = {}".format(
                allocator))
        assert_prefix_caching_block_or_none(prev_block)

        self._prev_block = prev_block
        self._cached_content_hash: Optional[int] = None
        self._cached_num_tokens_total: int = 0
        self._allocator = allocator
        self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
        self._computed = computed

        # The first time around we create the block object; afterwards we only
        # reinitialize it.
        if hasattr(self, "_block"):
            self._block.__init__(  # type: ignore[has-type]
                prev_block=prev_block,
                token_ids=token_ids,
                block_size=block_size,
                block_id=block_id,
                allocator=self._allocator)
        else:
            self._block = NaiveBlock(prev_block=prev_block,
                                     token_ids=token_ids,
                                     block_size=block_size,
                                     block_id=block_id,
                                     allocator=self._allocator)

        self._update_num_tokens_total()

    def _update_num_tokens_total(self):
        """Incrementally computes the number of tokens up to and including
        the current block.
        """
        res = 0

        # Add all previous blocks
        if self._prev_block is not None:
            res += self._prev_block.num_tokens_total

        # Add current block
        res += len(self.token_ids)

        self._cached_num_tokens_total = res
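
    # Example (illustrative), assuming block_size == 4: with two full
    # previous blocks in the chain and 3 tokens in the current block,
    # num_tokens_total == 4 + 4 + 3 == 11.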

    @property
    def computed(self) -> bool:
        return self._computed

    @computed.setter
    def computed(self, value) -> None:
        self._computed = value

    @property
    def last_accessed(self) -> float:
        return self._last_accessed

    @last_accessed.setter
    def last_accessed(self, last_accessed_ts: float):
        self._last_accessed = last_accessed_ts

    def append_token_ids(self, token_ids: List[int]) -> None:
        """Appends the given token IDs to the block and registers the block as
        immutable if the block becomes full.

        Args:
            token_ids (List[int]): The token IDs to be appended to the block.
        """
        # Ensure this is a mutable block (not promoted).
        assert self.content_hash is None
        assert not self.computed

        if len(token_ids) == 0:
            return

        # Ensure there are input tokens
        assert token_ids, "Got token_ids = {}".format(token_ids)

        # Naive block handles CoW.
        self._block.append_token_ids(token_ids)
        self._update_num_tokens_total()

        # If the content hash is present, then the block can be made immutable.
        # Register ourselves with the allocator, potentially replacing the
        # physical block index.
        if self.content_hash is not None:
            self.block_id = self._allocator.promote_to_immutable_block(self)

    @property
    def block_id(self) -> Optional[int]:
        return self._block.block_id

    @block_id.setter
    def block_id(self, value) -> None:
        self._block.block_id = value

    @property
    def is_full(self) -> bool:
        return self._block.is_full

    @property
    def num_empty_slots(self) -> int:
        return self._block.num_empty_slots

    @property
    def num_tokens_total(self) -> int:
        return self._cached_num_tokens_total

    @property
    def block_size(self) -> int:
        return self._block.block_size

    @property
    def token_ids(self) -> List[int]:
        return self._block.token_ids

    @property
    def prev_block(self) -> Optional[Block]:
        return self._prev_block

    @property
    def content_hash(self) -> Optional[int]:
        """Return the content-based hash of the current block, or None if it is
        not yet defined.

        For the content-based hash to be defined, the current block must be
        full.
        """
        # If the hash is already computed, return it.
        if self._cached_content_hash is not None:
            return self._cached_content_hash

        # We cannot compute a hash for the current block because it is not
        # full.
        if not self.is_full:
            return None

        is_first_block = self._prev_block is None
        prev_block_hash = (
            None if is_first_block else
            self._prev_block.content_hash  # type: ignore
        )

        # Previous block exists but does not yet have a hash.
        # Return no hash in this case.
        if prev_block_hash is None and not is_first_block:
            return None

        self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
            is_first_block,
            prev_block_hash,
            cur_block_token_ids=self.token_ids)
        return self._cached_content_hash

    @staticmethod
    def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
                          cur_block_token_ids: List[int]) -> int:
        """Computes a hash value corresponding to the contents of a block and
        the contents of the preceding block(s). The hash value is used for
        prefix caching.

        NOTE: Content-based hashing does not yet support LoRA.

        Parameters:
        - is_first_block (bool): A flag indicating if the block is the first in
          the sequence.
        - prev_block_hash (Optional[int]): The hash of the previous block. None
          if this is the first block.
        - cur_block_token_ids (List[int]): A list of token ids in the current
          block. The current block is assumed to be full.

        Returns:
        - int: The computed hash value for the block.
        """
        assert (prev_block_hash is None) == is_first_block
        return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
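
    # Illustrative example of chained content hashes (token values are
    # arbitrary):
    #
    #   first = PrefixCachingBlock.hash_block_tokens(
    #       is_first_block=True, prev_block_hash=None,
    #       cur_block_token_ids=[1, 2, 3, 4])
    #   second = PrefixCachingBlock.hash_block_tokens(
    #       is_first_block=False, prev_block_hash=first,
    #       cur_block_token_ids=[5, 6, 7, 8])
    #
    # Because the previous hash is folded in, `second` identifies the whole
    # prefix [1, ..., 8], not just the tokens of the second block.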


class ComputedBlocksTracker:
    """Handles caching of per-sequence computed block ids.

    When a sequence appears for the first time, it traverses all of the
    blocks and detects the prefix of blocks that is computed. On subsequent
    calls, it only traverses the new blocks that were added and extends the
    already-recorded computed prefix with any newly computed blocks.

    To avoid redundant traversals, the algorithm also detects when there
    is a "gap" in the computed prefix. For example, if we have blocks =
    [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
    we won't try to add more computed blocks to [1,2,3] in this sequence
    iteration, and will add more computed blocks only after the sequence is
    freed and reused again.

    Note that currently, for a given sequence, we also skip the last
    block id for caching purposes, to avoid caching of a full sequence.
    """

    def __init__(self, allocator):
        self._allocator = allocator
        self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
                                                          bool]] = {}

    def add_seq(self, seq_id: int) -> None:
        """Start tracking seq_id
        """
        assert seq_id not in self._cached_computed_seq_blocks
        self._cached_computed_seq_blocks[seq_id] = ([], False)

    def remove_seq(self, seq_id: int) -> None:
        """Stop tracking seq_id
        """
        assert seq_id in self._cached_computed_seq_blocks
        del self._cached_computed_seq_blocks[seq_id]

    def get_cached_computed_blocks_and_update(
            self, seq_id: int, block_ids: List[int]) -> List[int]:
        """Look at the class documentation for details.
        """
        # Ensure seq_id is already tracked
        assert seq_id in self._cached_computed_seq_blocks

        # Get cached data (may be empty on the first time)
        prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
            seq_id]

        if has_gap:
            # When a gap is detected, we do not add more computed blocks
            # during this sequence iteration.
            return prev_computed_block_ids

        # We do not consider the last block id for caching purposes.
        num_cur_blocks = len(block_ids) - 1
        assert num_cur_blocks >= 0

        if len(prev_computed_block_ids) >= num_cur_blocks:
            # Cache HIT
            assert len(prev_computed_block_ids) == num_cur_blocks
            return prev_computed_block_ids

        # If here, then we may possibly add more computed blocks. As a result,
        # traverse the additional blocks after prev_computed_block_ids to
        # detect more computed blocks and add them.

        # Incremental init for seq_id => Look only at the new blocks
        computed_block_ids = self._allocator.get_computed_block_ids(
            prev_computed_block_ids,
            block_ids,
            skip_last_block_id=True,  # Skip the last block id to avoid
            # caching a full sequence
        )

        # Detect if there is a "gap"
        has_gap = len(computed_block_ids) < num_cur_blocks

        # Record
        self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
                                                    has_gap)

        return computed_block_ids
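
    # Illustrative walk-through: suppose a sequence currently has block ids
    # [0, 1, 2, 3], blocks 0 and 1 are computed, and block 2 is not. The
    # last block (3) is always skipped, so this returns [0, 1] and records
    # has_gap=True; subsequent calls for the same seq_id return [0, 1]
    # without re-scanning, until the sequence is removed and re-added.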


class LastAccessBlocksTracker:
    """Manages the last access time of the tracked sequences, in order to allow
    an efficient update of the allocator's block last access times.
    """

    def __init__(self, allocator):
        self._allocator = allocator
        self._seq_last_access: Dict[int, Optional[float]] = {}

    def add_seq(self, seq_id: int) -> None:
        """Start tracking seq_id
        """
        assert seq_id not in self._seq_last_access
        self._seq_last_access[seq_id] = None

    def remove_seq(self, seq_id: int) -> None:
        """Stop tracking seq_id
        """
        assert seq_id in self._seq_last_access
        del self._seq_last_access[seq_id]

    def update_last_access(self, seq_id: int, time: float) -> None:
        assert seq_id in self._seq_last_access
        self._seq_last_access[seq_id] = time

    def update_seq_blocks_last_access(self, seq_id: int,
                                      block_ids: List[int]) -> None:
        assert seq_id in self._seq_last_access

        ts = self._seq_last_access[seq_id]

        if ts is None:
            # No last access was recorded, no need to update.
            return

        self._allocator.mark_blocks_as_accessed(block_ids, ts)
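
    # Illustrative driver code (hypothetical; seq_id, now and seq_block_ids
    # are placeholders for values owned by the block manager):
    #
    #   tracker = LastAccessBlocksTracker(allocator)
    #   tracker.add_seq(seq_id)
    #   tracker.update_last_access(seq_id, now)
    #   tracker.update_seq_blocks_last_access(seq_id, seq_block_ids)
    #   tracker.remove_seq(seq_id)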


def assert_prefix_caching_block_or_none(block: Optional[Block]):
    if block is None:
        return

    assert isinstance(block,
                      PrefixCachingBlock), "Got block = {}".format(block)