prefix_caching_block.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967
  """Token blocks with prefix caching.

  Provides ``PrefixCachingBlockAllocator``, which reuses a physical block when
  a block with the same content hash (the hash of the block's tokens chained
  with the hash of the previous block) is already cached, the
  ``PrefixCachingBlock`` block type it hands out, and ``ComputedBlocksTracker``
  which caches per-sequence computed block ids.
  """
  2. from os.path import commonprefix
  3. from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
  4. from aphrodite.common.utils import cdiv
  5. from aphrodite.processing.block.common import (CopyOnWriteTracker,
  6. get_all_blocks_recursively)
  7. from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
  8. BlockId, Device)
  9. from aphrodite.processing.block.naive_block import (BlockPool, NaiveBlock,
  10. NaiveBlockAllocator)
  11. from aphrodite.processing.evictor_v2 import (EvictionPolicy, Evictor,
  12. make_evictor)
  # Alias for the content hash that identifies a full block together with
  # its whole prefix (see PrefixCachingBlock.hash_block_tokens).
  PrefixHash = int

  # Sentinel "last accessed" time given to freshly initialised trackers and
  # blocks. A block that still carries this value has never been marked as
  # accessed.
  _DEFAULT_LAST_ACCESSED_TIME = -1
  class BlockTracker:
      """Per-physical-block bookkeeping used by the prefix caching allocator.

      Records whether the block id is currently tracked (``active``), when it
      was last accessed and whether its contents are marked as computed.
      """

      __slots__ = ("active", "last_accessed", "computed")

      def __init__(self):
          self.active: bool = False
          self.reset()

      def reset(self):
          """Clear the computed flag and restore the never-accessed sentinel."""
          self.computed: bool = False
          self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME

      def enable(self):
          """Start tracking the block id; it must not be tracked already."""
          assert not self.active
          self.active = True
          self.reset()

      def disable(self):
          """Stop tracking the block id; it must currently be tracked."""
          assert self.active
          self.active = False
          self.reset()
  class PrefixCachingBlockAllocator(BlockAllocator):
      """A block allocator that implements prefix caching.

      The PrefixCachingBlockAllocator maintains a cache of blocks based on
      their content hash. It reuses blocks with the same content hash to avoid
      redundant memory allocation. The allocator also supports copy-on-write
      operations.

      Args:
          num_blocks (int): The total number of blocks to manage.
          block_size (int): The size of each block in tokens.
          block_ids (Optional[Iterable[int]], optional): An optional iterable
              of block IDs. If not provided, block IDs will be assigned
              sequentially from 0 to num_blocks - 1. Any iterable (including a
              one-shot generator) is accepted; it is materialised once.
          eviction_policy (EvictionPolicy, optional): Policy for the evictor
              that keeps unreferenced cached blocks around for reuse.
              Defaults to EvictionPolicy.LRU.
      """

      def __init__(
          self,
          num_blocks: int,
          block_size: int,
          block_ids: Optional[Iterable[int]] = None,
          eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
      ):
          # block_ids is iterated here (block trackers) and then handed to the
          # hashless allocator, so a one-shot iterable must be materialised
          # first or the second consumer would see no ids.
          if block_ids is None:
              block_ids = range(num_blocks)
          else:
              block_ids = list(block_ids)

          self._block_size = block_size

          # A mapping of prefix hash to block index. All blocks which have a
          # prefix hash will be in this dict, even if they have refcount 0.
          self._cached_blocks: Dict[PrefixHash, BlockId] = {}

          # Used to track status of each physical block id
          self._block_tracker: Dict[BlockId, BlockTracker] = {}
          for block_id in block_ids:
              self._block_tracker[block_id] = BlockTracker()

          # Pre-allocate "num_blocks * extra_factor" block objects.
          # The "* extra_factor" is a buffer to allow more block objects
          # than physical blocks
          extra_factor = 4
          self._block_pool = BlockPool(self._block_size, self._create_block,
                                       self, num_blocks * extra_factor)

          # An allocator for blocks that do not have prefix hashes.
          self._hashless_allocator = NaiveBlockAllocator(
              create_block=self._create_block,  # type: ignore
              num_blocks=num_blocks,
              block_size=block_size,
              block_ids=block_ids,
              block_pool=self._block_pool,  # Share block pool here
          )

          # Evictor holding cached blocks whose refcount dropped to 0; they
          # are reclaimed from here when no hashless block is free.
          self.evictor: Evictor = make_evictor(eviction_policy)

          # We share the refcounter between allocators. This allows us to
          # promote blocks originally allocated in the hashless allocator to
          # immutable blocks.
          self._refcounter = self._hashless_allocator.refcounter

          self._cow_tracker = CopyOnWriteTracker(
              refcounter=self._refcounter.as_readonly())

      # Implements Block.Factory.
      def _create_block(
          self,
          prev_block: Optional[Block],
          token_ids: List[int],
          block_size: int,
          allocator: BlockAllocator,
          block_id: Optional[int] = None,
          computed: bool = False,
      ) -> Block:
          # Bind block to self (the passed-in allocator is ignored).
          allocator = self

          return PrefixCachingBlock(
              prev_block=prev_block,
              token_ids=token_ids,
              block_size=block_size,
              block_id=block_id,
              allocator=allocator,
              computed=computed,
          )

      def allocate_immutable_block(self,
                                   prev_block: Optional[Block],
                                   token_ids: List[int],
                                   device: Optional[Device] = None) -> Block:
          """Allocates an immutable block with the given token IDs, reusing
          cached blocks if possible.

          Args:
              prev_block (Optional[Block]): The previous block in the sequence.
              token_ids (List[int]): The token IDs to be stored in the block.
              device (Optional[Device]): Must be None (asserted).

          Returns:
              Block: The allocated immutable block.
          """
          assert device is None
          assert_prefix_caching_block_or_none(prev_block)

          # First, try to create a block that points to cached data
          block = self._block_pool.init_block(prev_block=prev_block,
                                              token_ids=token_ids,
                                              block_size=self._block_size,
                                              physical_block_id=None)
          assert block.content_hash is not None

          cached_block_id = self._cached_blocks.get(block.content_hash, None)
          if cached_block_id is not None:
              block.block_id = cached_block_id
              self._incr_refcount_cached_block(block)
              return block
          self._block_pool.free_block(block)

          # No cached block => Allocate a new block (promoted to immutable by
          # append_token_ids once it is full)
          block = self.allocate_mutable_block(prev_block)
          block.append_token_ids(token_ids)
          return block

      def allocate_immutable_blocks(
              self,
              prev_block: Optional[Block],
              block_token_ids: List[List[int]],
              device: Optional[Device] = None) -> List[Block]:
          """Allocates one immutable block per entry of ``block_token_ids``,
          chaining each new block to the previously allocated one.

          Args:
              prev_block (Optional[Block]): The block preceding the first new
                  block, or None.
              block_token_ids (List[List[int]]): Token IDs for each block.
              device (Optional[Device]): Must be None.

          Returns:
              List[Block]: The allocated blocks, in order.
          """
          blocks = []
          for token_ids in block_token_ids:
              prev_block = self.allocate_immutable_block(prev_block=prev_block,
                                                         token_ids=token_ids,
                                                         device=device)
              blocks.append(prev_block)
          return blocks

      def allocate_mutable_block(self,
                                 prev_block: Optional[Block],
                                 device: Optional[Device] = None) -> Block:
          """Allocates a mutable block. If there are no free blocks, this will
          evict unused cached blocks.

          Args:
              prev_block (Optional[Block]): The previous block in the sequence,
                  or None for the first block of a sequence.
              device (Optional[Device]): Must be None (asserted).

          Returns:
              Block: The allocated mutable block.

          Raises:
              BlockAllocator.NoFreeBlocksError: If no hashless block is free
                  and the evictor is empty.
          """
          assert device is None
          assert_prefix_caching_block_or_none(prev_block)

          block_id = self._allocate_block_id()
          block = self._block_pool.init_block(prev_block=prev_block,
                                              token_ids=[],
                                              block_size=self._block_size,
                                              physical_block_id=block_id)
          assert not block.computed
          assert block.content_hash is None
          return block

      def _incr_refcount_cached_block(self, block: Block) -> None:
          # Set this block to be "computed" since it is pointing to a
          # cached block id (which was already computed)
          block.computed = True

          block_id = block.block_id
          assert block_id is not None

          refcount = self._refcounter.incr(block_id)
          if refcount == 1:
              # In case a cached block was evicted, restore its tracking
              if block_id in self.evictor:
                  self.evictor.remove(block_id)

              self._track_block_id(block_id, computed=True)

      def _decr_refcount_cached_block(self, block: Block) -> None:
          # Ensure this is immutable/cached block
          assert block.content_hash is not None

          block_id = block.block_id
          assert block_id is not None

          refcount = self._refcounter.decr(block_id)
          if refcount > 0:
              # Still referenced elsewhere: only detach this block object.
              block.block_id = None
              return

          assert refcount == 0
          # No longer used
          assert block.content_hash in self._cached_blocks

          # Add the cached block to the evictor
          # (This keeps the cached block around so it can be reused)
          self.evictor.add(block_id, block.content_hash, block.num_tokens_total,
                           self._block_tracker[block_id].last_accessed)

          # Stop tracking the block
          self._untrack_block_id(block_id)

          block.block_id = None

      def _decr_refcount_hashless_block(self, block: Block) -> None:
          block_id = block.block_id
          assert block_id is not None

          # We may have a fork case where block is shared,
          # in which case, we cannot remove it from tracking
          refcount = self._refcounter.get(block_id)
          if refcount == 1:
              self._untrack_block_id(block_id)

          # Decrement refcount of the block_id, but do not free the block
          # object itself (will be handled by the caller)
          self._hashless_allocator.free(block, keep_block_object=True)

      def _allocate_block_id(self) -> BlockId:
          """First tries to allocate a block id from the hashless allocator,
          and if there are no blocks, then tries to evict an unused cached
          block.

          Raises:
              BlockAllocator.NoFreeBlocksError: If neither source has a block.
          """
          hashless_block_id = self._maybe_allocate_hashless_block_id()
          if hashless_block_id is not None:
              return hashless_block_id

          evicted_block_id = self._maybe_allocate_evicted_block_id()
          if evicted_block_id is not None:
              return evicted_block_id

          # No block available in hashless allocator, nor in unused cache
          # blocks.
          raise BlockAllocator.NoFreeBlocksError()

      def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]:
          try:
              # Allocate mutable block and extract its block_id
              block = self._hashless_allocator.allocate_mutable_block(
                  prev_block=None)
              block_id = block.block_id
              self._block_pool.free_block(block)

              self._track_block_id(block_id, computed=False)
              return block_id
          except BlockAllocator.NoFreeBlocksError:
              return None

      def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]:
          if self.evictor.num_blocks == 0:
              return None

          # Here we get an evicted block, which is only added
          # into evictor if its ref counter is 0
          # and since its content would be changed, we need
          # to remove it from _cached_blocks's tracking list
          block_id, content_hash_to_evict = self.evictor.evict()

          # Sanity checks
          assert content_hash_to_evict in self._cached_blocks
          _block_id = self._cached_blocks[content_hash_to_evict]
          assert self._refcounter.get(_block_id) == 0
          assert _block_id == block_id

          self._cached_blocks.pop(content_hash_to_evict)

          self._refcounter.incr(block_id)
          self._track_block_id(block_id, computed=False)

          return block_id

      def _free_block_id(self, block: Block) -> None:
          """Decrements the refcount of the block. The block may be in two
          possible states: (1) immutable/cached or (2) mutable/hashless.
          In the first case, the refcount is decremented directly and the block
          may be possibly added to the evictor. In other case, hashless
          allocator free(..) with keep_block_object=True is called to only free
          the block id (since the block object may be reused by the caller)
          """
          block_id = block.block_id
          assert block_id is not None, "Freeing unallocated block is undefined"

          if block.content_hash is not None:
              # Immutable: This type of block is always cached, and we want to
              # keep it in the evictor for future reuse
              self._decr_refcount_cached_block(block)
          else:
              # Mutable: This type of block is not cached, so we release it
              # directly to the hashless allocator
              self._decr_refcount_hashless_block(block)

          assert block.block_id is None

      def free(self, block: Block, keep_block_object: bool = False) -> None:
          """Release the block (see _free_block_id(..) docs).

          Args:
              block (Block): The block to release.
              keep_block_object (bool): If True, only the physical block id is
                  released and the block object is not returned to the pool.
          """
          # Release the physical block index
          self._free_block_id(block)

          # Release the block object to the pool
          if not keep_block_object:
              self._block_pool.free_block(block)

      def fork(self, last_block: Block) -> List[Block]:
          """Creates a new sequence of blocks that shares the same underlying
          memory as the original sequence.

          Args:
              last_block (Block): The last block in the original sequence.

          Returns:
              List[Block]: The new sequence of blocks that shares the same
                  memory as the original sequence.
          """
          source_blocks = get_all_blocks_recursively(last_block)

          forked_blocks: List[Block] = []
          prev_block = None
          for block in source_blocks:
              block_id = block.block_id
              assert block_id is not None

              refcount = self._refcounter.incr(block_id)
              assert refcount != 1, "can't fork free'd block_id = {}".format(
                  block_id)

              forked_block = self._block_pool.init_block(
                  prev_block=prev_block,
                  token_ids=block.token_ids,
                  block_size=self._block_size,
                  physical_block_id=block_id)

              forked_blocks.append(forked_block)
              prev_block = forked_blocks[-1]

          return forked_blocks

      def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
          assert device is None
          # The number of free blocks is the number of hashless free blocks
          # plus the number of blocks evictor could free from its list.
          return (self._hashless_allocator.get_num_free_blocks() +
                  self.evictor.num_blocks)

      def get_num_total_blocks(self) -> int:
          return self._hashless_allocator.get_num_total_blocks()

      def get_physical_block_id(self, absolute_id: int) -> int:
          """Returns the zero-offset block id on certain block allocator
          given the absolute block id.

          Args:
              absolute_id (int): The absolute block id for the block
                  in whole allocator.

          Returns:
              int: The zero-offset block id on certain device.
          """
          return sorted(self.all_block_ids).index(absolute_id)

      @property
      def all_block_ids(self) -> FrozenSet[int]:
          return self._hashless_allocator.all_block_ids

      def is_block_cached(self, block: Block) -> bool:
          """Return whether the (full) block's content hash is cached."""
          assert block.content_hash is not None
          return block.content_hash in self._cached_blocks

      def promote_to_immutable_block(self, block: Block) -> BlockId:
          """Once a mutable block is full, it can be promoted to an immutable
          block. This means that its content can be referenced by future blocks
          having the same prefix.

          Note that if we already have a cached block with the same content, we
          will replace the newly-promoted block's mapping with the existing
          cached block id.

          Args:
              block: The mutable block to be promoted.

          Returns:
              BlockId: Either the original block index, or the block index of
                  the previously cached block matching the same content.
          """
          # Ensure block can be promoted
          assert block.content_hash is not None
          assert block.block_id is not None
          assert self._refcounter.get(block.block_id) > 0

          if block.content_hash not in self._cached_blocks:
              # No cached content hash => Set this block as cached
              # (Note that this block is not computed yet =>
              #  Will be computed after free())
              self._cached_blocks[block.content_hash] = block.block_id
              return block.block_id

          # Reuse the cached content hash
          self._decr_refcount_hashless_block(block)
          block.block_id = self._cached_blocks[block.content_hash]

          # Increment refcount of the cached block and (possibly) restore
          # it from the evictor.
          # Note that in this case, the block is marked as computed
          self._incr_refcount_cached_block(block)

          return block.block_id

      def cow_block_if_not_appendable(self, block: Block) -> BlockId:
          """Performs a copy-on-write operation on the given block if it is not
          appendable.

          Args:
              block (Block): The block to check for copy-on-write.

          Returns:
              BlockId: The block index of the new block if a copy-on-write
                  operation was performed, or the original block index if
                  no copy-on-write was necessary.
          """
          src_block_id = block.block_id
          assert src_block_id is not None

          if self._cow_tracker.is_appendable(block):
              return src_block_id

          self._free_block_id(block)
          trg_block_id = self._allocate_block_id()

          self._cow_tracker.record_cow(src_block_id, trg_block_id)

          return trg_block_id

      def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
          """Returns the copy-on-write source->destination mapping and clears
          it.

          Returns:
              List[Tuple[BlockId, BlockId]]: A list mapping source
                  block indices to destination block indices.
          """
          return self._cow_tracker.clear_cows()

      def mark_blocks_as_accessed(self, block_ids: List[int],
                                  now: float) -> None:
          """Mark blocks as accessed, used in prefix caching.

          If the block is added into evictor, we need to update corresponding
          info in evictor's metadata.

          Raises:
              ValueError: If a block id is neither tracked nor in the evictor.
          """
          for block_id in block_ids:
              if self._block_tracker[block_id].active:
                  self._block_tracker[block_id].last_accessed = now
              elif block_id in self.evictor:
                  self.evictor.update(block_id, now)
              else:
                  raise ValueError(
                      "Mark block as accessed which is not belonged to GPU")

      def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
          raise NotImplementedError("Marking as computed is incremental")

      def _track_block_id(self, block_id: Optional[BlockId],
                          computed: bool) -> None:
          assert block_id is not None
          self._block_tracker[block_id].enable()
          self._block_tracker[block_id].computed = computed

      def _untrack_block_id(self, block_id: Optional[BlockId]) -> None:
          assert block_id is not None
          self._block_tracker[block_id].disable()

      def block_is_computed(self, block_id: int) -> bool:
          """Return whether the block id holds computed contents.

          Tracked ids report their tracker's flag; untracked ids count as
          computed only while they sit in the evictor.
          """
          if self._block_tracker[block_id].active:
              return self._block_tracker[block_id].computed
          else:
              return block_id in self.evictor

      def get_computed_block_ids(self,
                                 prev_computed_block_ids: List[int],
                                 block_ids: List[int],
                                 skip_last_block_id: bool = True) -> List[int]:
          """Extend ``prev_computed_block_ids`` with the computed prefix of
          ``block_ids``.

          Only positions from ``len(prev_computed_block_ids)`` onwards are
          inspected, and scanning stops at the first block that is not
          computed. Provided ``prev_computed_block_ids`` is itself a prefix of
          ``block_ids``, the result is therefore a contiguous prefix, never a
          list with holes.

          NOTE: ``prev_computed_block_ids`` is extended in place and is also
          the returned object.

          Args:
              prev_computed_block_ids (List[int]): Previously detected
                  computed prefix.
              block_ids (List[int]): All block ids of the sequence.
              skip_last_block_id (bool): Do not consider the last block id.

          Returns:
              List[int]: The (extended) computed prefix.
          """
          prev_prefix_size = len(prev_computed_block_ids)
          cur_size = len(block_ids)
          if skip_last_block_id:
              cur_size -= 1

          # Sanity checks
          assert cur_size >= 0
          assert prev_prefix_size <= cur_size

          ret = prev_computed_block_ids
          for block_id in block_ids[prev_prefix_size:cur_size]:
              if not self.block_is_computed(block_id):
                  # Blocks after a non-computed block cannot belong to the
                  # computed prefix, even if they are individually computed.
                  break
              ret.append(block_id)
          return ret

      def get_common_computed_block_ids(
              self, computed_seq_block_ids: List[List[int]]) -> List[int]:
          """Return the block ids that are common for a given sequence group.

          Only those blocks that are immutable and already marked computed
          are taken into consideration. Empty per-sequence lists are ignored.
          Always returns a list (empty if there is nothing in common).
          """
          # NOTE We exclude the last block to avoid the case where the entire
          # prompt is cached. This would cause erroneous behavior in model
          # runner.

          if len(computed_seq_block_ids) == 1:
              return computed_seq_block_ids[0]

          # os.path.commonprefix compares any sequences element-wise, but it
          # returns the empty *string* when given no inputs and may return one
          # of the input lists itself, so normalise to a fresh list.
          return list(
              commonprefix([
                  ids for ids in computed_seq_block_ids  # type: ignore
                  if ids
              ]))

      def get_num_blocks_touched(self,
                                 blocks: List[Block],
                                 num_lookahead_slots: int = 0) -> int:
          """Determine the number of blocks that will be touched by
          swapping in/out the given blocks from certain sequence
          group with the provided num_lookahead_slots.

          Args:
              blocks (List[Block]): The potential blocks to swap.
              num_lookahead_slots (int): number of lookahead slots (0 for
                  swap out).

          Returns:
              int: the number of blocks that will be touched by
                  swapping in/out the given blocks and num_lookahead_slots.
          """
          num_touched_blocks = 0
          for block in blocks:
              if not block.is_full:
                  num_touched_blocks += 1
                  if num_lookahead_slots > block.num_empty_slots:
                      num_touched_blocks += cdiv(
                          num_lookahead_slots - block.num_empty_slots,
                          self._block_size)
              else:
                  # If the block has a match in the cache and the cached block
                  # is not referenced, then we still count it as a touched
                  # block
                  if not self.is_block_cached(block) or (
                          block.content_hash is not None
                          and self._cached_blocks[block.content_hash]
                          in self.evictor):
                      num_touched_blocks += 1
          return num_touched_blocks

      def swap_out(self, blocks: List[Block]) -> None:
          """Execute the swap out actions. Basically just free the
          given blocks.

          Args:
              blocks: List of blocks to be swapped out.
          """
          for block in blocks:
              self._free_block_id(block)

      def swap_in(self, blocks: List[Block]) -> None:
          """Execute the swap in actions. Change the block id from
          old allocator to current allocator for each block to finish
          the block table update.

          Args:
              blocks: List of blocks to be swapped in.
          """
          for block in blocks:
              # Here we allocate either immutable or mutable block and then
              # extract its block_id. Note that the block object is released
              # and the block_id is assigned to "block" to allow reusing the
              # existing "block" object
              if block.is_full:
                  tmp_block = self.allocate_immutable_block(
                      prev_block=block.prev_block, token_ids=block.token_ids)
              else:
                  tmp_block = self.allocate_mutable_block(
                      prev_block=block.prev_block)
                  tmp_block.append_token_ids(block.token_ids)

              block_id = tmp_block.block_id
              self._block_pool.free_block(tmp_block)

              block.block_id = block_id  # Assign block_id
  class PrefixCachingBlock(Block):
      """A token block that supports prefix caching.

      Wraps a ``NaiveBlock`` that stores the token ids (and handles
      copy-on-write), and adds a content hash chained over all preceding
      blocks. Once the block is full and hashable, appending to it promotes
      it through the owning ``PrefixCachingBlockAllocator``, which may swap in
      an already cached physical block id.

      Args:
          prev_block (Optional[PrefixCachingBlock]): The previous block in the
              sequence, or None for the first block.
          token_ids (List[int]): The initial token IDs stored in the block.
          block_size (int): Maximum number of token IDs the block can hold.
          allocator (BlockAllocator): The owning prefix caching allocator.
          block_id (Optional[int], optional): The physical block index of
              this block. Defaults to None.
          computed (bool, optional): Whether the block's contents are already
              computed. Defaults to False.
      """

      def __init__(
          self,
          prev_block: Optional[Block],
          token_ids: List[int],
          block_size: int,
          allocator: BlockAllocator,
          block_id: Optional[int] = None,
          computed: bool = False,
      ):
          assert isinstance(allocator, PrefixCachingBlockAllocator), (
              "Currently this class is only tested with "
              "PrefixCachingBlockAllocator. Got instead allocator = {}".format(
                  allocator))
          assert_prefix_caching_block_or_none(prev_block)

          self._prev_block = prev_block
          self._cached_content_hash: Optional[int] = None
          self._cached_num_tokens_total: int = 0
          self._allocator = allocator
          self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
          self._computed = computed

          # The inner NaiveBlock is created only the first time; when this
          # object is initialised again (reused), it is re-initialised in
          # place instead of being replaced.
          naive_kwargs = dict(prev_block=prev_block,
                              token_ids=token_ids,
                              block_size=block_size,
                              block_id=block_id,
                              allocator=self._allocator)
          if hasattr(self, "_block"):
              self._block.__init__(**naive_kwargs)  # type: ignore[has-type]
          else:
              self._block = NaiveBlock(**naive_kwargs)

          self._update_num_tokens_total()

      def _update_num_tokens_total(self):
          """Cache the token count of all blocks up to and including this
          one."""
          tokens_before = (0 if self._prev_block is None else
                           self._prev_block.num_tokens_total)
          self._cached_num_tokens_total = tokens_before + len(self.token_ids)

      @property
      def computed(self) -> bool:
          return self._computed

      @computed.setter
      def computed(self, value) -> None:
          self._computed = value

      @property
      def last_accessed(self) -> float:
          return self._last_accessed

      @last_accessed.setter
      def last_accessed(self, last_accessed_ts: float):
          self._last_accessed = last_accessed_ts

      def append_token_ids(self, token_ids: List[int]) -> None:
          """Append token IDs and promote the block once it becomes hashable.

          Args:
              token_ids (List[int]): The token IDs to be appended to the block.
          """
          # Only a mutable (not yet promoted, not computed) block may grow.
          assert self.content_hash is None
          assert not self.computed

          if len(token_ids) == 0:
              return

          assert token_ids, "Got token_ids = {}".format(token_ids)

          # The NaiveBlock takes care of copy-on-write.
          self._block.append_token_ids(token_ids)
          self._update_num_tokens_total()

          # A hash now means the block is full with a hashable prefix: let the
          # allocator register it, possibly replacing the physical block id.
          if self.content_hash is not None:
              self.block_id = self._allocator.promote_to_immutable_block(self)

      @property
      def block_id(self) -> Optional[int]:
          return self._block.block_id

      @block_id.setter
      def block_id(self, value) -> None:
          self._block.block_id = value

      @property
      def is_full(self) -> bool:
          return self._block.is_full

      @property
      def num_empty_slots(self) -> int:
          return self._block.num_empty_slots

      @property
      def num_tokens_total(self) -> int:
          return self._cached_num_tokens_total

      @property
      def block_size(self) -> int:
          return self._block.block_size

      @property
      def token_ids(self) -> List[int]:
          return self._block.token_ids

      @property
      def prev_block(self) -> Optional[Block]:
          return self._prev_block

      @property
      def content_hash(self) -> Optional[int]:
          """The chained content hash of this block, or None if undefined.

          A hash exists only when this block is full and, unless it is the
          first block, the previous block already has a hash. Once computed,
          the value is cached on the block.
          """
          if self._cached_content_hash is not None:
              return self._cached_content_hash

          if not self.is_full:
              return None

          is_first_block = self._prev_block is None
          if is_first_block:
              prev_block_hash = None
          else:
              prev_block_hash = self._prev_block.content_hash  # type: ignore
              if prev_block_hash is None:
                  # The prefix is not hashable yet, so neither is this block.
                  return None

          self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
              is_first_block,
              prev_block_hash,
              cur_block_token_ids=self.token_ids)
          return self._cached_content_hash

      @staticmethod
      def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
                            cur_block_token_ids: List[int]) -> int:
          """Hash a block's tokens together with the hash of its prefix.

          NOTE: Content-based hashing does not yet support LoRA.

          Parameters:
          - is_first_block (bool): Whether the block starts the sequence.
          - prev_block_hash (Optional[int]): Hash of the previous block; must
            be None exactly when ``is_first_block`` is True.
          - cur_block_token_ids (List[int]): The (full) block's token ids.

          Returns:
          - int: Python ``hash`` of the flag, previous hash and token ids.
          """
          assert (prev_block_hash is None) == is_first_block
          return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
  678. class ComputedBlocksTracker:
  679. """Handles caching of per-sequence computed block ids.
  680. When a sequence appears for the first time, it traverses all of the
  681. blocks and detects the prefix of blocks that is computed. On the
  682. subsequent times, it only traverses the new blocks that were added
  683. and updates the already recorded prefix of blocks with the newly
  684. computed blocks.
  685. To avoid redundant traversals, the algorithm also detects when there
  686. is a "gap" in the computed prefix. For example, if we have blocks =
  687. [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
  688. we won't try to add more computed blocks to [1,2,3] in this sequence
  689. iteration, and will add more computed blocks only after the sequence is
  690. freed and reused again.
  691. Note that currently, for a given sequence, we also skip the last
  692. block id for caching purposes, to avoid caching of a full sequence
  693. """
  694. def __init__(self, allocator):
  695. self._allocator = allocator
  696. self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
  697. bool]] = {}
  698. def add_seq(self, seq_id: int) -> None:
  699. """Start tracking seq_id
  700. """
  701. assert seq_id not in self._cached_computed_seq_blocks
  702. self._cached_computed_seq_blocks[seq_id] = ([], False)
  703. def remove_seq(self, seq_id: int) -> None:
  704. """Stop tracking seq_id
  705. """
  706. assert seq_id in self._cached_computed_seq_blocks
  707. del self._cached_computed_seq_blocks[seq_id]
  708. def get_cached_computed_blocks_and_update(
  709. self, seq_id: int, block_ids: List[int]) -> List[int]:
  710. """ Look at the class documentation for details
  711. """
  712. # Ensure seq_id is already tracked
  713. assert seq_id in self._cached_computed_seq_blocks
  714. # Get cached data (may be empty on the first time)
  715. prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
  716. seq_id]
  717. if has_gap:
  718. # When gap is detected, we do not add more computed blocks at this
  719. # sequence iteration
  720. return prev_computed_block_ids
  721. # We do not consider the last block id for caching purposes.
  722. num_cur_blocks = len(block_ids) - 1
  723. assert num_cur_blocks >= 0
  724. if len(prev_computed_block_ids) >= num_cur_blocks:
  725. # Cache HIT
  726. assert len(prev_computed_block_ids) == num_cur_blocks
  727. return prev_computed_block_ids
  728. # If here, then we may possibly add more computed blocks. As a result,
  729. # traverse the additional blocks after prev_computed_block_ids to
  730. # detect more computed blocks and add them.
  731. # Incremental init for seq_id => Look only at the new blocks
  732. computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501
  733. prev_computed_block_ids,
  734. block_ids,
  735. skip_last_block_id=
  736. True, # We skip last block id to avoid caching of full seq
  737. )
  738. # Detect if there is a "gap"
  739. has_gap = len(computed_block_ids) < num_cur_blocks
  740. # Record
  741. self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
  742. has_gap)
  743. return computed_block_ids
  744. class LastAccessBlocksTracker:
  745. """Manages the last access time of the tracked sequences, in order to allow
  746. an efficient update of allocator's block last access times
  747. """
  748. def __init__(self, allocator):
  749. self._allocator = allocator
  750. self._seq_last_access: Dict[int, Optional[float]] = {}
  751. def add_seq(self, seq_id: int) -> None:
  752. """Start tracking seq_id
  753. """
  754. assert seq_id not in self._seq_last_access
  755. self._seq_last_access[seq_id] = None
  756. def remove_seq(self, seq_id: int) -> None:
  757. """Stop tracking seq_id
  758. """
  759. assert seq_id in self._seq_last_access
  760. del self._seq_last_access[seq_id]
  761. def update_last_access(self, seq_id: int, time: float) -> None:
  762. assert seq_id in self._seq_last_access
  763. self._seq_last_access[seq_id] = time
  764. def update_seq_blocks_last_access(self, seq_id: int,
  765. block_ids: List[int]) -> None:
  766. assert seq_id in self._seq_last_access
  767. ts = self._seq_last_access[seq_id]
  768. if ts is None:
  769. # No last access was recorded, no need to update.
  770. return
  771. self._allocator.mark_blocks_as_accessed(block_ids, ts)
  772. def assert_prefix_caching_block_or_none(block: Optional[Block]):
  773. if block is None:
  774. return
  775. assert isinstance(block,
  776. PrefixCachingBlock), "Got block = {}".format(block)