prefix_caching_block.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975
  1. """Token blocks."""
  2. from os.path import commonprefix
  3. from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
  4. from aphrodite.common.utils import cdiv
  5. from aphrodite.processing.block.common import (CacheMetricData,
  6. CopyOnWriteTracker,
  7. get_all_blocks_recursively)
  8. from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
  9. BlockId, Device)
  10. from aphrodite.processing.block.naive_block import (BlockPool, NaiveBlock,
  11. NaiveBlockAllocator)
  12. from aphrodite.processing.evictor_v2 import (EvictionPolicy, Evictor,
  13. make_evictor)
  14. PrefixHash = int
  15. # By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
  16. # so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME,
  17. # then we know this block hasn't been accessed yet.
  18. _DEFAULT_LAST_ACCESSED_TIME = -1
  19. class BlockTracker:
  20. """Used to track the status of a block inside the prefix caching allocator
  21. """
  22. __slots__ = ("active", "last_accessed", "computed")
  23. def reset(self):
  24. self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
  25. self.computed: bool = False
  26. def __init__(self):
  27. self.active: bool = False
  28. self.reset()
  29. def enable(self):
  30. assert not self.active
  31. self.active = True
  32. self.reset()
  33. def disable(self):
  34. assert self.active
  35. self.active = False
  36. self.reset()
  37. class PrefixCachingBlockAllocator(BlockAllocator):
  38. """A block allocator that implements prefix caching.
  39. The PrefixCachingBlockAllocator maintains a cache of blocks based on their
  40. content hash. It reuses blocks with the same content hash to avoid redundant
  41. memory allocation. The allocator also supports copy-on-write operations.
  42. Args:
  43. num_blocks (int): The total number of blocks to manage.
  44. block_size (int): The size of each block in tokens.
  45. block_ids(Optional[Iterable[int]], optional): An optional iterable of
  46. block IDs. If not provided, block IDs will be assigned sequentially
  47. from 0 to num_blocks - 1.
  48. """
  49. def __init__(
  50. self,
  51. num_blocks: int,
  52. block_size: int,
  53. block_ids: Optional[Iterable[int]] = None,
  54. eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
  55. ):
  56. if block_ids is None:
  57. block_ids = range(num_blocks)
  58. self._block_size = block_size
  59. # A mapping of prefix hash to block index. All blocks which have a
  60. # prefix hash will be in this dict, even if they have refcount 0.
  61. self._cached_blocks: Dict[PrefixHash, BlockId] = {}
  62. # Used to track status of each physical block id
  63. self._block_tracker: Dict[BlockId, BlockTracker] = {}
  64. for block_id in block_ids:
  65. self._block_tracker[block_id] = BlockTracker()
  66. # Pre-allocate "num_blocks * extra_factor" block objects.
  67. # The "* extra_factor" is a buffer to allow more block objects
  68. # than physical blocks
  69. extra_factor = 4
  70. self._block_pool = BlockPool(self._block_size, self._create_block,
  71. self, num_blocks * extra_factor)
  72. # An allocator for blocks that do not have prefix hashes.
  73. self._hashless_allocator = NaiveBlockAllocator(
  74. create_block=self._create_block, # type: ignore
  75. num_blocks=num_blocks,
  76. block_size=block_size,
  77. block_ids=block_ids,
  78. block_pool=self._block_pool, # Share block pool here
  79. )
  80. # Evitor used to maintain how we want to handle those computed blocks
  81. # if we find memory pressure is high.
  82. self.evictor: Evictor = make_evictor(eviction_policy)
  83. # We share the refcounter between allocators. This allows us to promote
  84. # blocks originally allocated in the hashless allocator to immutable
  85. # blocks.
  86. self._refcounter = self._hashless_allocator.refcounter
  87. self._cow_tracker = CopyOnWriteTracker(
  88. refcounter=self._refcounter.as_readonly())
  89. self.metric_data = CacheMetricData()
  90. # Implements Block.Factory.
  91. def _create_block(
  92. self,
  93. prev_block: Optional[Block],
  94. token_ids: List[int],
  95. block_size: int,
  96. allocator: BlockAllocator,
  97. block_id: Optional[int] = None,
  98. computed: bool = False,
  99. ) -> Block:
  100. # Bind block to self.
  101. allocator = self
  102. return PrefixCachingBlock(
  103. prev_block=prev_block,
  104. token_ids=token_ids,
  105. block_size=block_size,
  106. block_id=block_id,
  107. allocator=allocator,
  108. computed=computed,
  109. )
  110. def allocate_immutable_block(self,
  111. prev_block: Optional[Block],
  112. token_ids: List[int],
  113. device: Optional[Device] = None) -> Block:
  114. """Allocates an immutable block with the given token IDs, reusing cached
  115. blocks if possible.
  116. Args:
  117. prev_block (Optional[Block]): The previous block in the sequence.
  118. token_ids (List[int]): The token IDs to be stored in the block.
  119. Returns:
  120. Block: The allocated immutable block.
  121. """
  122. assert device is None
  123. assert_prefix_caching_block_or_none(prev_block)
  124. # First, try to create a block that points to cached data
  125. block = self._block_pool.init_block(prev_block=prev_block,
  126. token_ids=token_ids,
  127. block_size=self._block_size,
  128. physical_block_id=None)
  129. assert block.content_hash is not None
  130. cached_block_id = self._cached_blocks.get(block.content_hash, None)
  131. if cached_block_id is not None:
  132. self.metric_data.query(hit=True)
  133. block.block_id = cached_block_id
  134. self._incr_refcount_cached_block(block)
  135. return block
  136. self.metric_data.query(hit=False)
  137. self._block_pool.free_block(block)
  138. # No cached block => Allocate a new block
  139. block = self.allocate_mutable_block(prev_block)
  140. block.append_token_ids(token_ids)
  141. return block
  142. def allocate_immutable_blocks(
  143. self,
  144. prev_block: Optional[Block],
  145. block_token_ids: List[List[int]],
  146. device: Optional[Device] = None) -> List[Block]:
  147. blocks = []
  148. for token_ids in block_token_ids:
  149. prev_block = self.allocate_immutable_block(prev_block=prev_block,
  150. token_ids=token_ids,
  151. device=device)
  152. blocks.append(prev_block)
  153. return blocks
  154. def allocate_mutable_block(self,
  155. prev_block: Optional[Block],
  156. device: Optional[Device] = None) -> Block:
  157. """Allocates a mutable block. If there are no free blocks, this will
  158. evict unused cached blocks.
  159. Args:
  160. prev_block (Block): The previous block in the sequence.
  161. None is not allowed unlike it is super class.
  162. Returns:
  163. Block: The allocated mutable block.
  164. """
  165. assert device is None
  166. assert_prefix_caching_block_or_none(prev_block)
  167. block_id = self._allocate_block_id()
  168. block = self._block_pool.init_block(prev_block=prev_block,
  169. token_ids=[],
  170. block_size=self._block_size,
  171. physical_block_id=block_id)
  172. assert not block.computed
  173. assert block.content_hash is None
  174. return block
  175. def _incr_refcount_cached_block(self, block: Block) -> None:
  176. # Set this block to be "computed" since it is pointing to a
  177. # cached block id (which was already computed)
  178. block.computed = True
  179. block_id = block.block_id
  180. assert block_id is not None
  181. refcount = self._refcounter.incr(block_id)
  182. if refcount == 1:
  183. # In case a cached block was evicted, restore its tracking
  184. if block_id in self.evictor:
  185. self.evictor.remove(block_id)
  186. self._track_block_id(block_id, computed=True)
  187. def _decr_refcount_cached_block(self, block: Block) -> None:
  188. # Ensure this is immutable/cached block
  189. assert block.content_hash is not None
  190. block_id = block.block_id
  191. assert block_id is not None
  192. refcount = self._refcounter.decr(block_id)
  193. if refcount > 0:
  194. block.block_id = None
  195. return
  196. else:
  197. assert refcount == 0
  198. # No longer used
  199. assert block.content_hash in self._cached_blocks
  200. # Add the cached block to the evictor
  201. # (This keeps the cached block around so it can be reused)
  202. self.evictor.add(block_id, block.content_hash, block.num_tokens_total,
  203. self._block_tracker[block_id].last_accessed)
  204. # Stop tracking the block
  205. self._untrack_block_id(block_id)
  206. block.block_id = None
  207. def _decr_refcount_hashless_block(self, block: Block) -> None:
  208. block_id = block.block_id
  209. assert block_id is not None
  210. # We may have a fork case where block is shared,
  211. # in which case, we cannot remove it from tracking
  212. refcount = self._refcounter.get(block_id)
  213. if refcount == 1:
  214. self._untrack_block_id(block_id)
  215. # Decrement refcount of the block_id, but do not free the block object
  216. # itself (will be handled by the caller)
  217. self._hashless_allocator.free(block, keep_block_object=True)
  218. def _allocate_block_id(self) -> BlockId:
  219. """First tries to allocate a block id from the hashless allocator,
  220. and if there are no blocks, then tries to evict an unused cached block.
  221. """
  222. hashless_block_id = self._maybe_allocate_hashless_block_id()
  223. if hashless_block_id is not None:
  224. return hashless_block_id
  225. evicted_block_id = self._maybe_allocate_evicted_block_id()
  226. if evicted_block_id is not None:
  227. return evicted_block_id
  228. # No block available in hashless allocator, nor in unused cache blocks.
  229. raise BlockAllocator.NoFreeBlocksError()
  230. def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]:
  231. try:
  232. # Allocate mutable block and extract its block_id
  233. block = self._hashless_allocator.allocate_mutable_block(
  234. prev_block=None)
  235. block_id = block.block_id
  236. self._block_pool.free_block(block)
  237. self._track_block_id(block_id, computed=False)
  238. return block_id
  239. except BlockAllocator.NoFreeBlocksError:
  240. return None
  241. def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]:
  242. if self.evictor.num_blocks == 0:
  243. return None
  244. # Here we get an evicted block, which is only added
  245. # into evictor if its ref counter is 0
  246. # and since its content would be changed, we need
  247. # to remove it from _cached_blocks's tracking list
  248. block_id, content_hash_to_evict = self.evictor.evict()
  249. # Sanity checks
  250. assert content_hash_to_evict in self._cached_blocks
  251. _block_id = self._cached_blocks[content_hash_to_evict]
  252. assert self._refcounter.get(_block_id) == 0
  253. assert _block_id == block_id
  254. self._cached_blocks.pop(content_hash_to_evict)
  255. self._refcounter.incr(block_id)
  256. self._track_block_id(block_id, computed=False)
  257. return block_id
  258. def _free_block_id(self, block: Block) -> None:
  259. """Decrements the refcount of the block. The block may be in two
  260. possible states: (1) immutable/cached or (2) mutable/hashless.
  261. In the first case, the refcount is decremented directly and the block
  262. may be possibly added to the evictor. In other case, hashless
  263. allocator free(..) with keep_block_object=True is called to only free
  264. the block id (since the block object may be reused by the caller)
  265. """
  266. block_id = block.block_id
  267. assert block_id is not None, "Freeing unallocated block is undefined"
  268. if block.content_hash is not None:
  269. # Immutable: This type of block is always cached, and we want to
  270. # keep it in the evictor for future reuse
  271. self._decr_refcount_cached_block(block)
  272. else:
  273. # Mutable: This type of block is not cached, so we release it
  274. # directly to the hashless allocator
  275. self._decr_refcount_hashless_block(block)
  276. assert block.block_id is None
  277. def free(self, block: Block, keep_block_object: bool = False) -> None:
  278. """Release the block (look at free_block_id(..) docs)
  279. """
  280. # Release the physical block index
  281. self._free_block_id(block)
  282. # Release the block object to the pool
  283. if not keep_block_object:
  284. self._block_pool.free_block(block)
  285. def fork(self, last_block: Block) -> List[Block]:
  286. """Creates a new sequence of blocks that shares the same underlying
  287. memory as the original sequence.
  288. Args:
  289. last_block (Block): The last block in the original sequence.
  290. Returns:
  291. List[Block]: The new sequence of blocks that shares the same memory
  292. as the original sequence.
  293. """
  294. source_blocks = get_all_blocks_recursively(last_block)
  295. forked_blocks: List[Block] = []
  296. prev_block = None
  297. for block in source_blocks:
  298. block_id = block.block_id
  299. assert block_id is not None
  300. refcount = self._refcounter.incr(block_id)
  301. assert refcount != 1, "can't fork free'd block_id = {}".format(
  302. block_id)
  303. forked_block = self._block_pool.init_block(
  304. prev_block=prev_block,
  305. token_ids=block.token_ids,
  306. block_size=self._block_size,
  307. physical_block_id=block_id)
  308. forked_blocks.append(forked_block)
  309. prev_block = forked_blocks[-1]
  310. return forked_blocks
  311. def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
  312. assert device is None
  313. # The number of free blocks is the number of hashless free blocks
  314. # plus the number of blocks evictor could free from its list.
  315. return self._hashless_allocator.get_num_free_blocks(
  316. ) + self.evictor.num_blocks
  317. def get_num_total_blocks(self) -> int:
  318. return self._hashless_allocator.get_num_total_blocks()
  319. def get_physical_block_id(self, absolute_id: int) -> int:
  320. """Returns the zero-offset block id on certain block allocator
  321. given the absolute block id.
  322. Args:
  323. absolute_id (int): The absolute block id for the block
  324. in whole allocator.
  325. Returns:
  326. int: The rzero-offset block id on certain device.
  327. """
  328. return sorted(self.all_block_ids).index(absolute_id)
  329. @property
  330. def all_block_ids(self) -> FrozenSet[int]:
  331. return self._hashless_allocator.all_block_ids
  332. def get_prefix_cache_hit_rate(self) -> float:
  333. return self.metric_data.get_hit_rate()
  334. def is_block_cached(self, block: Block) -> bool:
  335. assert block.content_hash is not None
  336. if block.content_hash in self._cached_blocks:
  337. return True
  338. return False
  339. def promote_to_immutable_block(self, block: Block) -> BlockId:
  340. """Once a mutable block is full, it can be promoted to an immutable
  341. block. This means that its content can be referenced by future blocks
  342. having the same prefix.
  343. Note that if we already have a cached block with the same content, we
  344. will replace the newly-promoted block's mapping with the existing cached
  345. block id.
  346. Args:
  347. block: The mutable block to be promoted.
  348. Returns:
  349. BlockId: Either the original block index, or the block index of
  350. the previously cached block matching the same content.
  351. """
  352. # Ensure block can be promoted
  353. assert block.content_hash is not None
  354. assert block.block_id is not None
  355. assert self._refcounter.get(block.block_id) > 0
  356. if block.content_hash not in self._cached_blocks:
  357. # No cached content hash => Set this block as cached
  358. # (Note that this block is not computed yet =>
  359. # Will be computed after free())
  360. self._cached_blocks[block.content_hash] = block.block_id
  361. return block.block_id
  362. # Reuse the cached content hash
  363. self._decr_refcount_hashless_block(block)
  364. block.block_id = self._cached_blocks[block.content_hash]
  365. # Increment refcount of the cached block and (possibly) restore
  366. # it from the evictor.
  367. # Note that in this case, the block is marked as computed
  368. self._incr_refcount_cached_block(block)
  369. return block.block_id
  370. def cow_block_if_not_appendable(self, block: Block) -> BlockId:
  371. """Performs a copy-on-write operation on the given block if it is not
  372. appendable.
  373. Args:
  374. block (Block): The block to check for copy-on-write.
  375. Returns:
  376. BlockId: The block index of the new block if a copy-on-write
  377. operation was performed, or the original block index if
  378. no copy-on-write was necessary.
  379. """
  380. src_block_id = block.block_id
  381. assert src_block_id is not None
  382. if self._cow_tracker.is_appendable(block):
  383. return src_block_id
  384. self._free_block_id(block)
  385. trg_block_id = self._allocate_block_id()
  386. self._cow_tracker.record_cow(src_block_id, trg_block_id)
  387. return trg_block_id
  388. def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
  389. """Returns the copy-on-write source->destination mapping and clears it.
  390. Returns:
  391. List[Tuple[BlockId, BlockId]]: A list mapping source
  392. block indices to destination block indices.
  393. """
  394. return self._cow_tracker.clear_cows()
  395. def mark_blocks_as_accessed(self, block_ids: List[int],
  396. now: float) -> None:
  397. """Mark blocks as accessed, used in prefix caching.
  398. If the block is added into evictor, we need to update corresponding
  399. info in evictor's metadata.
  400. """
  401. for block_id in block_ids:
  402. if self._block_tracker[block_id].active:
  403. self._block_tracker[block_id].last_accessed = now
  404. elif block_id in self.evictor:
  405. self.evictor.update(block_id, now)
  406. else:
  407. raise ValueError(
  408. "Mark block as accessed which is not belonged to GPU")
  409. def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
  410. raise NotImplementedError("Marking as computed is incremental")
  411. def _track_block_id(self, block_id: Optional[BlockId],
  412. computed: bool) -> None:
  413. assert block_id is not None
  414. self._block_tracker[block_id].enable()
  415. self._block_tracker[block_id].computed = computed
  416. def _untrack_block_id(self, block_id: Optional[BlockId]) -> None:
  417. assert block_id is not None
  418. self._block_tracker[block_id].disable()
  419. def block_is_computed(self, block_id: int) -> bool:
  420. if self._block_tracker[block_id].active:
  421. return self._block_tracker[block_id].computed
  422. else:
  423. return block_id in self.evictor
  424. def get_computed_block_ids(self,
  425. prev_computed_block_ids: List[int],
  426. block_ids: List[int],
  427. skip_last_block_id: bool = True) -> List[int]:
  428. prev_prefix_size = len(prev_computed_block_ids)
  429. cur_size = len(block_ids)
  430. if skip_last_block_id:
  431. cur_size -= 1
  432. # Sanity checks
  433. assert cur_size >= 0
  434. assert prev_prefix_size <= cur_size
  435. ret = prev_computed_block_ids
  436. for i in range(prev_prefix_size, cur_size):
  437. block_id = block_ids[i]
  438. if self.block_is_computed(block_id):
  439. ret.append(block_id)
  440. return ret
  441. def get_common_computed_block_ids(
  442. self, computed_seq_block_ids: List[List[int]]) -> List[int]:
  443. """Return the block ids that are common for a given sequence group.
  444. Only those blocks that are immutable and already be marked
  445. compyted would be taken consideration.
  446. """
  447. # NOTE We exclude the last block to avoid the case where the entire
  448. # prompt is cached. This would cause erroneous behavior in model
  449. # runner.
  450. # It returns a list of int although type annotation says list of string.
  451. if len(computed_seq_block_ids) == 1:
  452. return computed_seq_block_ids[0]
  453. return commonprefix([
  454. ids for ids in computed_seq_block_ids # type: ignore
  455. if ids
  456. ])
  457. def get_num_blocks_touched(self,
  458. blocks: List[Block],
  459. num_lookahead_slots: int = 0) -> int:
  460. """Determine the number of blocks that will be touched by
  461. swapping in/out the given blocks from certain sequence
  462. group with the provided num_lookahead_slots.
  463. Args:
  464. blocks (List[Block]): The potential blocks to swap.
  465. num_lookahead_slots (int): number of lookahead slots (0 for
  466. swap out).
  467. Returns:
  468. int: the number of blocks that will be touched by
  469. swapping in/out the given blocks and num_lookahead_slots.
  470. """
  471. num_touched_blocks = 0
  472. for block in blocks:
  473. if not block.is_full:
  474. num_touched_blocks += 1
  475. if num_lookahead_slots > block.num_empty_slots:
  476. num_touched_blocks += cdiv(
  477. num_lookahead_slots - block.num_empty_slots,
  478. self._block_size)
  479. else:
  480. # If the block has a match in the cache and the cached block
  481. # is not referenced, then we still count it as a touched block
  482. if not self.is_block_cached(block) or \
  483. (block.content_hash is not None and \
  484. self._cached_blocks[block.content_hash] in self.evictor):
  485. num_touched_blocks += 1
  486. return num_touched_blocks
  487. def swap_out(self, blocks: List[Block]) -> None:
  488. """Execute the swap out actions. Basically just free the
  489. given blocks.
  490. Args:
  491. blocks: List of blocks to be swapped out.
  492. """
  493. for block in blocks:
  494. self._free_block_id(block)
  495. def swap_in(self, blocks: List[Block]) -> None:
  496. """Execute the swap in actions. Change the block id from
  497. old allocator to current allocator for each block to finish
  498. the block table update.
  499. Args:
  500. blocks: List of blocks to be swapped in.
  501. """
  502. for block in blocks:
  503. # Here we allocate either immutable or mutable block and then
  504. # extract its block_id. Note that the block object is released
  505. # and the block_id is assigned to "block" to allow reusing the
  506. # existing "block" object
  507. if block.is_full:
  508. tmp_block = self.allocate_immutable_block(
  509. prev_block=block.prev_block, token_ids=block.token_ids)
  510. else:
  511. tmp_block = self.allocate_mutable_block(
  512. prev_block=block.prev_block)
  513. tmp_block.append_token_ids(block.token_ids)
  514. block_id = tmp_block.block_id
  515. self._block_pool.free_block(tmp_block)
  516. block.block_id = block_id # Assign block_id
  517. class PrefixCachingBlock(Block):
  518. """A block implementation that supports prefix caching.
  519. The PrefixCachingBlock class represents a block of token IDs with prefix
  520. caching capabilities. It wraps a NaiveBlock internally and provides
  521. additional functionality for content hashing and promoting immutable blocks
  522. with the prefix caching allocator.
  523. Args:
  524. prev_block (Optional[PrefixCachingBlock]): The previous block in the
  525. sequence.
  526. token_ids (List[int]): The initial token IDs to be stored in the block.
  527. block_size (int): The maximum number of token IDs that can be stored in
  528. the block.
  529. allocator (BlockAllocator): The prefix
  530. caching block allocator associated with this block.
  531. block_id (Optional[int], optional): The physical block index
  532. of this block. Defaults to None.
  533. """
  534. def __init__(
  535. self,
  536. prev_block: Optional[Block],
  537. token_ids: List[int],
  538. block_size: int,
  539. allocator: BlockAllocator,
  540. block_id: Optional[int] = None,
  541. computed: bool = False,
  542. ):
  543. assert isinstance(allocator, PrefixCachingBlockAllocator), (
  544. "Currently this class is only tested with "
  545. "PrefixCachingBlockAllocator. Got instead allocator = {}".format(
  546. allocator))
  547. assert_prefix_caching_block_or_none(prev_block)
  548. self._prev_block = prev_block
  549. self._cached_content_hash: Optional[int] = None
  550. self._cached_num_tokens_total: int = 0
  551. self._allocator = allocator
  552. self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
  553. self._computed = computed
  554. # On the first time, we create the block object, and next we only
  555. # reinitialize it
  556. if hasattr(self, "_block"):
  557. self._block.__init__( # type: ignore[has-type]
  558. prev_block=prev_block,
  559. token_ids=token_ids,
  560. block_size=block_size,
  561. block_id=block_id,
  562. allocator=self._allocator)
  563. else:
  564. self._block = NaiveBlock(prev_block=prev_block,
  565. token_ids=token_ids,
  566. block_size=block_size,
  567. block_id=block_id,
  568. allocator=self._allocator)
  569. self._update_num_tokens_total()
  570. def _update_num_tokens_total(self):
  571. """Incrementally computes the number of tokens that there is
  572. till the current block (included)
  573. """
  574. res = 0
  575. # Add all previous blocks
  576. if self._prev_block is not None:
  577. res += self._prev_block.num_tokens_total
  578. # Add current block
  579. res += len(self.token_ids)
  580. self._cached_num_tokens_total = res
  581. @property
  582. def computed(self) -> bool:
  583. return self._computed
  584. @computed.setter
  585. def computed(self, value) -> None:
  586. self._computed = value
  587. @property
  588. def last_accessed(self) -> float:
  589. return self._last_accessed
  590. @last_accessed.setter
  591. def last_accessed(self, last_accessed_ts: float):
  592. self._last_accessed = last_accessed_ts
  593. def append_token_ids(self, token_ids: List[int]) -> None:
  594. """Appends the given token IDs to the block and registers the block as
  595. immutable if the block becomes full.
  596. Args:
  597. token_ids (List[int]): The token IDs to be appended to the block.
  598. """
  599. # Ensure this is mutable block (not promoted)
  600. assert self.content_hash is None
  601. assert not self.computed
  602. if len(token_ids) == 0:
  603. return
  604. # Ensure there are input tokens
  605. assert token_ids, "Got token_ids = {}".format(token_ids)
  606. # Naive block handles CoW.
  607. self._block.append_token_ids(token_ids)
  608. self._update_num_tokens_total()
  609. # If the content hash is present, then the block can be made immutable.
  610. # Register ourselves with the allocator, potentially replacing the
  611. # physical block index.
  612. if self.content_hash is not None:
  613. self.block_id = self._allocator.promote_to_immutable_block(self)
  614. @property
  615. def block_id(self) -> Optional[int]:
  616. return self._block.block_id
  617. @block_id.setter
  618. def block_id(self, value) -> None:
  619. self._block.block_id = value
  620. @property
  621. def is_full(self) -> bool:
  622. return self._block.is_full
  623. @property
  624. def num_empty_slots(self) -> int:
  625. return self._block.num_empty_slots
  626. @property
  627. def num_tokens_total(self) -> int:
  628. return self._cached_num_tokens_total
  629. @property
  630. def block_size(self) -> int:
  631. return self._block.block_size
  632. @property
  633. def token_ids(self) -> List[int]:
  634. return self._block.token_ids
  635. @property
  636. def prev_block(self) -> Optional[Block]:
  637. return self._prev_block
  638. @property
  639. def content_hash(self) -> Optional[int]:
  640. """Return the content-based hash of the current block, or None if it is
  641. not yet defined.
  642. For the content-based hash to be defined, the current block must be
  643. full.
  644. """
  645. # If the hash is already computed, return it.
  646. if self._cached_content_hash is not None:
  647. return self._cached_content_hash
  648. # We cannot compute a hash for the current block because it is not full.
  649. if not self.is_full:
  650. return None
  651. is_first_block = self._prev_block is None
  652. prev_block_hash = (
  653. None if is_first_block else
  654. self._prev_block.content_hash # type: ignore
  655. )
  656. # Previous block exists but does not yet have a hash.
  657. # Return no hash in this case.
  658. if prev_block_hash is None and not is_first_block:
  659. return None
  660. self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
  661. is_first_block,
  662. prev_block_hash,
  663. cur_block_token_ids=self.token_ids)
  664. return self._cached_content_hash
  665. @staticmethod
  666. def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
  667. cur_block_token_ids: List[int]) -> int:
  668. """Computes a hash value corresponding to the contents of a block and
  669. the contents of the preceding block(s). The hash value is used for
  670. prefix caching.
  671. NOTE: Content-based hashing does not yet support LoRA.
  672. Parameters:
  673. - is_first_block (bool): A flag indicating if the block is the first in
  674. the sequence.
  675. - prev_block_hash (Optional[int]): The hash of the previous block. None
  676. if this is the first block.
  677. - cur_block_token_ids (List[int]): A list of token ids in the current
  678. block. The current block is assumed to be full.
  679. Returns:
  680. - int: The computed hash value for the block.
  681. """
  682. assert (prev_block_hash is None) == is_first_block
  683. return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
  684. class ComputedBlocksTracker:
  685. """Handles caching of per-sequence computed block ids.
  686. When a sequence appears for the first time, it traverses all of the
  687. blocks and detects the prefix of blocks that is computed. On the
  688. subsequent times, it only traverses the new blocks that were added
  689. and updates the already recorded prefix of blocks with the newly
  690. computed blocks.
  691. To avoid redundant traversals, the algorithm also detects when there
  692. is a "gap" in the computed prefix. For example, if we have blocks =
  693. [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
  694. we won't try to add more computed blocks to [1,2,3] in this sequence
  695. iteration, and will add more computed blocks only after the sequence is
  696. freed and reused again.
  697. Note that currently, for a given sequence, we also skip the last
  698. block id for caching purposes, to avoid caching of a full sequence
  699. """
  700. def __init__(self, allocator):
  701. self._allocator = allocator
  702. self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
  703. bool]] = {}
  704. def add_seq(self, seq_id: int) -> None:
  705. """Start tracking seq_id
  706. """
  707. assert seq_id not in self._cached_computed_seq_blocks
  708. self._cached_computed_seq_blocks[seq_id] = ([], False)
  709. def remove_seq(self, seq_id: int) -> None:
  710. """Stop tracking seq_id
  711. """
  712. assert seq_id in self._cached_computed_seq_blocks
  713. del self._cached_computed_seq_blocks[seq_id]
  714. def get_cached_computed_blocks_and_update(
  715. self, seq_id: int, block_ids: List[int]) -> List[int]:
  716. """ Look at the class documentation for details
  717. """
  718. # Ensure seq_id is already tracked
  719. assert seq_id in self._cached_computed_seq_blocks
  720. # Get cached data (may be empty on the first time)
  721. prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
  722. seq_id]
  723. if has_gap:
  724. # When gap is detected, we do not add more computed blocks at this
  725. # sequence iteration
  726. return prev_computed_block_ids
  727. # We do not consider the last block id for caching purposes.
  728. num_cur_blocks = len(block_ids) - 1
  729. assert num_cur_blocks >= 0
  730. if len(prev_computed_block_ids) >= num_cur_blocks:
  731. # Cache HIT
  732. assert len(prev_computed_block_ids) == num_cur_blocks
  733. return prev_computed_block_ids
  734. # If here, then we may possibly add more computed blocks. As a result,
  735. # traverse the additional blocks after prev_computed_block_ids to
  736. # detect more computed blocks and add them.
  737. # Incremental init for seq_id => Look only at the new blocks
  738. computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501
  739. prev_computed_block_ids,
  740. block_ids,
  741. skip_last_block_id=
  742. True, # We skip last block id to avoid caching of full seq
  743. )
  744. # Detect if there is a "gap"
  745. has_gap = len(computed_block_ids) < num_cur_blocks
  746. # Record
  747. self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
  748. has_gap)
  749. return computed_block_ids
  750. class LastAccessBlocksTracker:
  751. """Manages the last access time of the tracked sequences, in order to allow
  752. an efficient update of allocator's block last access times
  753. """
  754. def __init__(self, allocator):
  755. self._allocator = allocator
  756. self._seq_last_access: Dict[int, Optional[float]] = {}
  757. def add_seq(self, seq_id: int) -> None:
  758. """Start tracking seq_id
  759. """
  760. assert seq_id not in self._seq_last_access
  761. self._seq_last_access[seq_id] = None
  762. def remove_seq(self, seq_id: int) -> None:
  763. """Stop tracking seq_id
  764. """
  765. assert seq_id in self._seq_last_access
  766. del self._seq_last_access[seq_id]
  767. def update_last_access(self, seq_id: int, time: float) -> None:
  768. assert seq_id in self._seq_last_access
  769. self._seq_last_access[seq_id] = time
  770. def update_seq_blocks_last_access(self, seq_id: int,
  771. block_ids: List[int]) -> None:
  772. assert seq_id in self._seq_last_access
  773. ts = self._seq_last_access[seq_id]
  774. if ts is None:
  775. # No last access was recorded, no need to update.
  776. return
  777. self._allocator.mark_blocks_as_accessed(block_ids, ts)
  778. def assert_prefix_caching_block_or_none(block: Optional[Block]):
  779. if block is None:
  780. return
  781. assert isinstance(block,
  782. PrefixCachingBlock), "Got block = {}".format(block)