1
0

block_manager_v1.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629
  1. """A block manager that manages token blocks."""
  2. import math
  3. from abc import ABC, abstractmethod
  4. from itertools import count, takewhile
  5. from os.path import commonprefix
  6. from typing import Dict, List, Optional
  7. from typing import Sequence as GenericSequence
  8. from typing import Set, Tuple
  9. from loguru import logger
  10. from aphrodite.common.block import BlockTable, PhysicalTokenBlock
  11. from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
  12. from aphrodite.common.utils import Device
  13. from aphrodite.processing.evictor_v1 import (EvictionPolicy, Evictor,
  14. make_evictor)
  15. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
  16. class BlockAllocatorBase(ABC):
  17. """Manages free physical token blocks for a device.
  18. The allocator maintains a list of free blocks and allocates a block when
  19. requested. When a block is freed, its reference count is decremented. If
  20. the reference count becomes zero, the block is added back to the free list.
  21. """
  22. @abstractmethod
  23. def __init__(self,
  24. device: Device,
  25. block_size: int,
  26. num_blocks: int,
  27. eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
  28. pass
  29. @abstractmethod
  30. def allocate(self,
  31. block_hash: Optional[int] = None,
  32. num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
  33. pass
  34. @abstractmethod
  35. def free(self, block: PhysicalTokenBlock) -> None:
  36. pass
  37. @abstractmethod
  38. def get_num_free_blocks(self) -> int:
  39. pass
  40. @abstractmethod
  41. def get_num_total_blocks(self) -> int:
  42. pass
  43. @abstractmethod
  44. def contains_block(self, block_hash: int) -> bool:
  45. pass
  46. @abstractmethod
  47. def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
  48. pass
  49. class CachedBlockAllocator(BlockAllocatorBase):
  50. """Manages free physical token blocks for a device.
  51. The allocator maintains a list of free blocks and allocates a block when
  52. requested. When a block is freed, its reference count is decremented. If
  53. the reference count becomes zero, the block is added back to the free list.
  54. """
  55. def __init__(self,
  56. device: Device,
  57. block_size: int,
  58. num_blocks: int,
  59. eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
  60. self.device = device
  61. self.block_size = block_size
  62. self.num_blocks = num_blocks
  63. self.current_num_blocks = 0
  64. self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}
  65. self.evictor: Evictor = make_evictor(eviction_policy)
  66. self.default_hash_ctr = count()
  67. def allocate_block(self, block_hash: int,
  68. num_hashed_tokens: int) -> PhysicalTokenBlock:
  69. if self.current_num_blocks == self.num_blocks:
  70. block = self.evictor.evict()
  71. block.block_hash = block_hash
  72. block.num_hashed_tokens = num_hashed_tokens
  73. return block
  74. block = PhysicalTokenBlock(device=self.device,
  75. block_number=self.current_num_blocks,
  76. block_size=self.block_size,
  77. block_hash=block_hash,
  78. num_hashed_tokens=num_hashed_tokens)
  79. self.current_num_blocks += 1
  80. return block
  81. def allocate(self,
  82. block_hash: Optional[int] = None,
  83. num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
  84. if block_hash is None:
  85. block_hash = next(self.default_hash_ctr)
  86. if block_hash in self.evictor:
  87. assert block_hash not in self.cached_blocks
  88. block = self.evictor.remove(block_hash)
  89. assert block.ref_count == 0
  90. self.cached_blocks[block_hash] = block
  91. block.ref_count += 1
  92. assert block.block_hash == block_hash
  93. return block
  94. if block_hash not in self.cached_blocks:
  95. self.cached_blocks[block_hash] = self.allocate_block(
  96. block_hash, num_hashed_tokens)
  97. block = self.cached_blocks[block_hash]
  98. assert block.block_hash == block_hash
  99. block.ref_count += 1
  100. return block
  101. def free(self, block: PhysicalTokenBlock) -> None:
  102. if block.ref_count == 0:
  103. raise ValueError(f"Double free! {block} is already freed.")
  104. block.ref_count -= 1
  105. if block.ref_count == 0:
  106. assert block.block_hash not in self.evictor
  107. self.evictor.add(block)
  108. # Remove the block from the cached_blocks
  109. del self.cached_blocks[block.block_hash]
  110. def get_num_free_blocks(self) -> int:
  111. return (self.num_blocks - self.current_num_blocks +
  112. self.evictor.num_blocks)
  113. def get_num_total_blocks(self) -> int:
  114. return self.num_blocks
  115. def contains_block(self, block_hash: int) -> bool:
  116. return block_hash in self.cached_blocks or block_hash in self.evictor
  117. def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
  118. # Update the hash of block and the cached_blocks dictionary.
  119. assert not self.contains_block(block_hash)
  120. old_hash = block.block_hash
  121. block.block_hash = block_hash
  122. del self.cached_blocks[old_hash]
  123. self.cached_blocks[block_hash] = block
  124. class UncachedBlockAllocator(BlockAllocatorBase):
  125. """Manages free physical token blocks for a device.
  126. The allocator maintains a list of free blocks and allocates a block when
  127. requested. When a block is freed, its reference count is decremented. If
  128. the reference count becomes zero, the block is added back to the free list.
  129. """
  130. def __init__(
  131. self,
  132. device: Device,
  133. block_size: int,
  134. num_blocks: int,
  135. ) -> None:
  136. self.device = device
  137. self.block_size = block_size
  138. self.num_blocks = num_blocks
  139. # Initialize the free blocks.
  140. self.free_blocks: BlockTable = []
  141. for i in range(num_blocks):
  142. block = PhysicalTokenBlock(device=device,
  143. block_number=i,
  144. block_size=block_size,
  145. block_hash=-1,
  146. num_hashed_tokens=0)
  147. self.free_blocks.append(block)
  148. def allocate(self,
  149. block_hash: Optional[int] = None,
  150. num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
  151. if not self.free_blocks:
  152. raise ValueError("Out of memory! No free blocks are available.")
  153. block = self.free_blocks.pop()
  154. block.ref_count = 1
  155. return block
  156. def free(self, block: PhysicalTokenBlock) -> None:
  157. if block.ref_count == 0:
  158. raise ValueError(f"Double free! {block} is already freed.")
  159. block.ref_count -= 1
  160. if block.ref_count == 0:
  161. self.free_blocks.append(block)
  162. def get_num_free_blocks(self) -> int:
  163. return len(self.free_blocks)
  164. def get_num_total_blocks(self) -> int:
  165. return self.num_blocks
  166. def contains_block(self, block_hash: int) -> bool:
  167. raise NotImplementedError(
  168. "Invalid codepath for uncached block allocator.")
  169. def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
  170. raise NotImplementedError(
  171. "Invalid codepath for uncached block allocator.")
  172. class BlockSpaceManagerV1(BlockSpaceManager):
  173. """Manages the mapping between logical and physical token blocks."""
  174. def __init__(
  175. self,
  176. block_size: int,
  177. num_gpu_blocks: int,
  178. num_cpu_blocks: int,
  179. watermark: float = 0.01,
  180. sliding_window: Optional[int] = None,
  181. enable_caching: bool = False,
  182. ) -> None:
  183. self.block_size = block_size
  184. self.num_total_gpu_blocks = num_gpu_blocks
  185. self.num_total_cpu_blocks = num_cpu_blocks
  186. if enable_caching and sliding_window is not None:
  187. raise NotImplementedError(
  188. "Sliding window is not allowed with prefix caching enabled!")
  189. self.block_sliding_window = None
  190. if sliding_window is not None:
  191. # Round up to nearest block size to regularize sliding window
  192. # allocation sizes.
  193. self.block_sliding_window = math.ceil(sliding_window / block_size)
  194. self.watermark = watermark
  195. assert watermark >= 0.0
  196. self.enable_caching = enable_caching
  197. self.watermark_blocks = int(watermark * num_gpu_blocks)
  198. if self.enable_caching:
  199. logger.info("Automatic prefix caching is enabled.")
  200. self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
  201. Device.GPU, block_size, num_gpu_blocks)
  202. self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
  203. Device.CPU, block_size, num_cpu_blocks)
  204. else:
  205. self.gpu_allocator = UncachedBlockAllocator(
  206. Device.GPU, block_size, num_gpu_blocks)
  207. self.cpu_allocator = UncachedBlockAllocator(
  208. Device.CPU, block_size, num_cpu_blocks)
  209. # Mapping: seq_id -> BlockTable.
  210. self.block_tables: Dict[int, BlockTable] = {}
  211. def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
  212. # FIXME(woosuk): Here we assume that all sequences in the group share
  213. # the same prompt. This may not be true for preempted sequences.
  214. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
  215. num_required_blocks = len(seq.logical_token_blocks)
  216. if self.block_sliding_window is not None:
  217. num_required_blocks = min(num_required_blocks,
  218. self.block_sliding_window)
  219. num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
  220. # Use watermark to avoid frequent cache eviction.
  221. if (self.num_total_gpu_blocks - num_required_blocks <
  222. self.watermark_blocks):
  223. return AllocStatus.NEVER
  224. if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
  225. return AllocStatus.OK
  226. else:
  227. return AllocStatus.LATER
  228. def allocate(self, seq_group: SequenceGroup) -> None:
  229. # NOTE: Here we assume that all sequences in the group have the same
  230. # prompt.
  231. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
  232. # Allocate new physical token blocks that will store the prompt tokens.
  233. num_prompt_blocks = len(seq.logical_token_blocks)
  234. block_table: BlockTable = []
  235. for logical_idx in range(num_prompt_blocks):
  236. if (self.block_sliding_window is not None
  237. and logical_idx >= self.block_sliding_window):
  238. block = block_table[logical_idx % self.block_sliding_window]
  239. # Set the reference counts of the token blocks.
  240. block.ref_count = seq_group.num_seqs()
  241. elif self.enable_caching:
  242. block = self.gpu_allocator.allocate(
  243. seq.hash_of_block(logical_idx),
  244. seq.num_hashed_tokens_of_block(logical_idx))
  245. else:
  246. block = self.gpu_allocator.allocate()
  247. # Set the reference counts of the token blocks.
  248. block.ref_count = seq_group.num_seqs()
  249. block_table.append(block)
  250. # Assign the block table for each sequence.
  251. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
  252. self.block_tables[seq.seq_id] = block_table.copy()
  253. def can_append_slots(self,
  254. seq_group: SequenceGroup,
  255. num_lookahead_slots: int = 0) -> bool:
  256. assert (num_lookahead_slots == 0
  257. ), "lookahead allocation not supported in BlockSpaceManagerV1"
  258. # Simple heuristic: If there is at least one free block
  259. # for each sequence, we can append.
  260. num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
  261. num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
  262. return num_seqs <= num_free_gpu_blocks
  263. def _promote_last_block(
  264. self,
  265. seq: Sequence,
  266. last_block: PhysicalTokenBlock,
  267. ) -> PhysicalTokenBlock:
  268. assert self.enable_caching
  269. # Compute a new hash for the block so that it can be shared by other
  270. # Sequences
  271. new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
  272. # if new_hash is already in the cached table, then free last_block
  273. # and return the cached version
  274. if self.gpu_allocator.contains_block(new_hash):
  275. self.gpu_allocator.free(last_block)
  276. return self.gpu_allocator.allocate(new_hash)
  277. else:
  278. self.gpu_allocator.update_hash(new_hash, last_block)
  279. return last_block
  280. def _is_last_block_full(
  281. self,
  282. seq: Sequence,
  283. ) -> bool:
  284. token_ids_len = seq.data.get_len()
  285. return token_ids_len > 0 and token_ids_len % seq.block_size == 0
  286. def _maybe_promote_last_block(
  287. self,
  288. seq: Sequence,
  289. last_block: PhysicalTokenBlock,
  290. ) -> PhysicalTokenBlock:
  291. if self._is_last_block_full(seq):
  292. return self._promote_last_block(seq, last_block)
  293. else:
  294. return last_block
  295. def _allocate_last_physical_block(
  296. self,
  297. seq: Sequence,
  298. ) -> PhysicalTokenBlock:
  299. # Called before a new block is appended.
  300. # This is in charge of allocating a new physical block (to be appended).
  301. # None if the last block is not full. Otherwise, we set it to the
  302. # content hash.
  303. if not self.enable_caching:
  304. return self.gpu_allocator.allocate()
  305. block_hash: Optional[int] = None
  306. if (self._is_last_block_full(seq)):
  307. block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
  308. num_hashed_tokens = seq.num_hashed_tokens_of_block(
  309. len(seq.logical_token_blocks) - 1)
  310. # num_hashed_tokens is used to compute future hashes
  311. # (e.g. in the hashing function, it is used to ask the sequence for
  312. # prefix tokens)
  313. new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
  314. # If the block has is None, then the block is not full.
  315. # If the block is not full, then we expect it to have a refcount of 1.
  316. if block_hash is None:
  317. assert new_block.ref_count == 1
  318. return new_block
  319. def append_slots(
  320. self,
  321. seq: Sequence,
  322. num_lookahead_slots: int = 0,
  323. ) -> List[Tuple[int, int]]:
  324. """Allocate a physical slot for a new token."""
  325. logical_blocks = seq.logical_token_blocks
  326. block_table = self.block_tables[seq.seq_id]
  327. # If we need to allocate a new physical block
  328. if len(block_table) < len(logical_blocks):
  329. # Currently this code only supports adding one physical block
  330. assert len(block_table) == len(logical_blocks) - 1
  331. if (self.block_sliding_window
  332. and len(block_table) >= self.block_sliding_window):
  333. # reuse a block
  334. block_table.append(block_table[len(block_table) %
  335. self.block_sliding_window])
  336. else:
  337. # The sequence hash a new logical block.
  338. # Allocate a new physical block.
  339. new_block = self._allocate_last_physical_block(seq)
  340. block_table.append(new_block)
  341. return []
  342. # We want to append the token to the last physical block.
  343. last_block = block_table[-1]
  344. assert last_block.device == Device.GPU
  345. if last_block.ref_count == 1:
  346. # Not shared with other sequences. Appendable.
  347. if self.enable_caching:
  348. # If the last block is now complete, we may reuse an old block
  349. # to save memory.
  350. maybe_new_block = self._maybe_promote_last_block(
  351. seq, last_block)
  352. block_table[-1] = maybe_new_block
  353. return []
  354. else:
  355. # The last block is shared with other sequences.
  356. # Copy on Write: Allocate a new block and copy the tokens.
  357. new_block = self._allocate_last_physical_block(seq)
  358. block_table[-1] = new_block
  359. self.gpu_allocator.free(last_block)
  360. return [(last_block.block_number, new_block.block_number)]
  361. def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  362. # NOTE: fork does not allocate a new physical block.
  363. # Thus, it is always safe from OOM.
  364. src_block_table = self.block_tables[parent_seq.seq_id]
  365. self.block_tables[child_seq.seq_id] = src_block_table.copy()
  366. # When using a sliding window, blocks will be eventually reused.
  367. # In this case the block tables will contain repeated blocks.
  368. # When forking, we must make sure that each block's `ref_count`
  369. # is only incremented by one, so we deduplicate them by wrapping
  370. # them in a set.
  371. for block in set(src_block_table):
  372. block.ref_count += 1
  373. def _get_physical_blocks(
  374. self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
  375. # NOTE: Here, we assume that the physical blocks are only shared by
  376. # the sequences in the same group.
  377. blocks: Set[PhysicalTokenBlock] = set()
  378. for seq in seq_group.get_seqs():
  379. if seq.is_finished():
  380. continue
  381. blocks.update(self.block_tables[seq.seq_id])
  382. return list(blocks)
  383. def can_swap_in(self,
  384. seq_group: SequenceGroup,
  385. num_lookahead_slots: int = 0) -> AllocStatus:
  386. assert (num_lookahead_slots == 0
  387. ), "BlockSpaceManagerV1 does not support lookahead allocation"
  388. blocks = self._get_physical_blocks(seq_group)
  389. num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
  390. num_free_blocks = self.gpu_allocator.get_num_free_blocks()
  391. # NOTE: Conservatively, we assume that every sequence will allocate
  392. # at least one free block right after the swap-in.
  393. # NOTE: This should match the logic in can_append_slot().
  394. num_required_blocks = len(blocks) + num_swapped_seqs
  395. if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
  396. return AllocStatus.NEVER
  397. elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
  398. return AllocStatus.OK
  399. else:
  400. return AllocStatus.LATER
  401. def swap_in(self,
  402. seq_group: SequenceGroup,
  403. num_lookahead_slots: int = 0) -> List[Tuple[int, int]]:
  404. assert (num_lookahead_slots == 0
  405. ), "BlockSpaceManagerV1 does not support lookahead allocation"
  406. # CPU block -> GPU block.
  407. # dict is efficient in lookup `if cpu_block in mapping`
  408. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
  409. for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
  410. new_block_table: BlockTable = []
  411. block_table = self.block_tables[seq.seq_id]
  412. for cpu_block in block_table:
  413. if cpu_block in mapping:
  414. gpu_block = mapping[cpu_block]
  415. gpu_block.ref_count += 1
  416. else:
  417. gpu_block = self.gpu_allocator.allocate(
  418. cpu_block.block_hash, cpu_block.num_hashed_tokens)
  419. mapping[cpu_block] = gpu_block
  420. new_block_table.append(gpu_block)
  421. # Free the CPU block swapped in to GPU.
  422. self.cpu_allocator.free(cpu_block)
  423. self.block_tables[seq.seq_id] = new_block_table
  424. block_number_mapping = {
  425. cpu_block.block_number: gpu_block.block_number
  426. for cpu_block, gpu_block in mapping.items()
  427. }
  428. # convert to list of tuples once here
  429. return list(block_number_mapping.items())
  430. def can_swap_out(self, seq_group: SequenceGroup) -> bool:
  431. blocks = self._get_physical_blocks(seq_group)
  432. return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
  433. def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
  434. # GPU block -> CPU block.
  435. # dict is efficient in lookup `if gpu_block in mapping`
  436. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
  437. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  438. new_block_table: BlockTable = []
  439. block_table = self.block_tables[seq.seq_id]
  440. for gpu_block in block_table:
  441. if gpu_block in mapping:
  442. cpu_block = mapping[gpu_block]
  443. cpu_block.ref_count += 1
  444. else:
  445. cpu_block = self.cpu_allocator.allocate(
  446. gpu_block.block_hash, gpu_block.num_hashed_tokens)
  447. mapping[gpu_block] = cpu_block
  448. new_block_table.append(cpu_block)
  449. # Free the GPU block swapped out to CPU.
  450. self.gpu_allocator.free(gpu_block)
  451. self.block_tables[seq.seq_id] = new_block_table
  452. block_number_mapping = {
  453. gpu_block.block_number: cpu_block.block_number
  454. for gpu_block, cpu_block in mapping.items()
  455. }
  456. # convert to list of tuples once here
  457. return list(block_number_mapping.items())
  458. def _free_block_table(self, block_table: BlockTable) -> None:
  459. # when using a sliding window, each seq will only use up
  460. # to `self.block_sliding_window` blocks. When freeing
  461. # the block table, we must make sure to not free blocks more
  462. # than once. If no sliding window is used, there is no block
  463. # reuse in the block table, so we must free all blocks.
  464. blocks_to_free = (block_table[-self.block_sliding_window:]
  465. if self.block_sliding_window is not None else
  466. block_table)
  467. for block in set(blocks_to_free):
  468. if block.device == Device.GPU:
  469. self.gpu_allocator.free(block)
  470. else:
  471. self.cpu_allocator.free(block)
  472. def free(self, seq: Sequence) -> None:
  473. if seq.seq_id not in self.block_tables:
  474. # Already freed or haven't been scheduled yet.
  475. return
  476. block_table = self.block_tables[seq.seq_id]
  477. self._free_block_table(block_table)
  478. del self.block_tables[seq.seq_id]
  479. def reset(self) -> None:
  480. for block_table in self.block_tables.values():
  481. self._free_block_table(block_table)
  482. self.block_tables.clear()
  483. def get_block_table(self, seq: Sequence) -> List[int]:
  484. block_table = self.block_tables[seq.seq_id]
  485. return [block.block_number for block in block_table]
  486. def get_num_free_gpu_blocks(self) -> int:
  487. return self.gpu_allocator.get_num_free_blocks()
  488. def get_num_free_cpu_blocks(self) -> int:
  489. return self.cpu_allocator.get_num_free_blocks()
  490. def access_all_blocks_in_seq(
  491. self,
  492. seq: Sequence,
  493. access_time: float,
  494. ) -> None:
  495. if self.enable_caching:
  496. # Update the last accessed time of all the blocks accessed
  497. # in this step.
  498. block_table = self.block_tables[seq.seq_id]
  499. for block in block_table:
  500. block.last_accessed = access_time
  501. def compute_full_blocks_in_seq(self, seq: Sequence):
  502. if seq.seq_id not in self.block_tables:
  503. return
  504. max_full_block = seq.get_len() // self.block_size - 1
  505. block_table = self.block_tables[seq.seq_id]
  506. if max_full_block == -1:
  507. return
  508. for i in reversed(range(max_full_block)):
  509. if block_table[i].computed:
  510. break
  511. block_table[i].computed = True
  512. def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
  513. if seq.seq_id not in self.block_tables:
  514. return []
  515. block_table = self.block_tables[seq.seq_id]
  516. # NOTE We exclude the last block to avoid the case where the entire
  517. # prompt is cached. This would cause erroneous behavior in model
  518. # runner.
  519. return [
  520. b.block_number
  521. for b in takewhile(lambda b: b.computed, block_table[:-1])
  522. ]
  523. def get_common_computed_block_ids(
  524. self, seqs: List[Sequence]) -> GenericSequence[int]:
  525. """Return the block ids that are common for a given sequence group.
  526. Used in prefill (can skip prefill of some blocks).
  527. """
  528. # Can return non-empty result only with prefix caching enabled.
  529. if not self.enable_caching:
  530. return []
  531. ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
  532. return commonprefix([ids for ids in ids_list if ids != []])
  533. def mark_blocks_as_computed(self, seq_group: SequenceGroup):
  534. if self.enable_caching:
  535. for seq in seq_group.seqs_dict.values():
  536. self.compute_full_blocks_in_seq(seq)