block_manager_v1.py

  1. """A block manager that manages token blocks."""
  2. from abc import ABC, abstractmethod
  3. from itertools import count, takewhile
  4. from os.path import commonprefix
  5. from typing import Dict, List, Optional
  6. from typing import Sequence as GenericSequence
  7. from typing import Set
  8. from loguru import logger
  9. from aphrodite.common.block import BlockTable, PhysicalTokenBlock
  10. from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
  11. from aphrodite.common.utils import Device
  12. from aphrodite.processing.evictor import EvictionPolicy, Evictor, make_evictor
  13. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager


class BlockAllocatorBase(ABC):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.
    """

    @abstractmethod
    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
        pass

    @abstractmethod
    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        pass

    @abstractmethod
    def free(self, block: PhysicalTokenBlock) -> None:
        pass

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        pass

    @abstractmethod
    def contains_block(self, block_hash: int) -> bool:
        pass

    @abstractmethod
    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        pass


class CachedBlockAllocator(BlockAllocatorBase):
    """Manages free physical token blocks for a device, with prefix caching.

    Blocks are keyed by a content hash so that blocks holding the same token
    prefix can be shared between sequences. When a block's reference count
    drops to zero, it is handed to the evictor instead of being discarded, so
    its contents remain reusable until the block is actually evicted.
    """

    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.current_num_blocks = 0
        self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

        self.evictor: Evictor = make_evictor(eviction_policy)

        self.default_hash_ctr = count()

    def allocate_block(self, block_hash: int,
                       num_hashed_tokens: int) -> PhysicalTokenBlock:
        if self.current_num_blocks == self.num_blocks:
            block = self.evictor.evict()
            block.block_hash = block_hash
            block.num_hashed_tokens = num_hashed_tokens
            return block
        block = PhysicalTokenBlock(device=self.device,
                                   block_number=self.current_num_blocks,
                                   block_size=self.block_size,
                                   block_hash=block_hash,
                                   num_hashed_tokens=num_hashed_tokens)
        self.current_num_blocks += 1
        return block

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        if block_hash is None:
            block_hash = next(self.default_hash_ctr)
        if block_hash in self.evictor:
            assert block_hash not in self.cached_blocks
            block = self.evictor.remove(block_hash)
            assert block.ref_count == 0
            self.cached_blocks[block_hash] = block
            block.ref_count += 1
            assert block.block_hash == block_hash
            return block
        if block_hash not in self.cached_blocks:
            self.cached_blocks[block_hash] = self.allocate_block(
                block_hash, num_hashed_tokens)
        block = self.cached_blocks[block_hash]
        assert block.block_hash == block_hash
        block.ref_count += 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            assert block.block_hash not in self.evictor
            self.evictor.add(block)
            # Remove the block from the cached_blocks.
            del self.cached_blocks[block.block_hash]

    def get_num_free_blocks(self) -> int:
        return (self.num_blocks - self.current_num_blocks +
                self.evictor.num_blocks)

    def contains_block(self, block_hash: int) -> bool:
        return block_hash in self.cached_blocks or block_hash in self.evictor

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        # Update the hash of the block and the cached_blocks dictionary.
        assert not self.contains_block(block_hash)
        old_hash = block.block_hash
        block.block_hash = block_hash
        del self.cached_blocks[old_hash]
        self.cached_blocks[block_hash] = block
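
# A minimal usage sketch of the cached allocator (illustrative only; the
# device, block size, block count, and hash values below are assumed, not
# taken from any caller in this file):
#
#     allocator = CachedBlockAllocator(Device.GPU, block_size=16, num_blocks=4)
#     block = allocator.allocate(block_hash=1234, num_hashed_tokens=16)
#     assert allocator.contains_block(1234)
#     allocator.free(block)      # ref_count drops to 0; the block moves to
#                                # the evictor but stays reusable.
#     reused = allocator.allocate(block_hash=1234)  # served from the evictor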


class UncachedBlockAllocator(BlockAllocatorBase):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.
    """

    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Initialize the free blocks.
        self.free_blocks: BlockTable = []
        for i in range(num_blocks):
            block = PhysicalTokenBlock(device=device,
                                       block_number=i,
                                       block_size=block_size,
                                       block_hash=-1,
                                       num_hashed_tokens=0)
            self.free_blocks.append(block)

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        if not self.free_blocks:
            raise ValueError("Out of memory! No free blocks are available.")
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        return len(self.free_blocks)

    def contains_block(self, block_hash: int) -> bool:
        raise NotImplementedError(
            "Invalid codepath for uncached block allocator.")

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        raise NotImplementedError(
            "Invalid codepath for uncached block allocator.")


class BlockSpaceManagerV1(BlockSpaceManager):
    """Manages the mapping between logical and physical token blocks."""

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
        enable_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        if enable_caching and sliding_window is not None:
            raise NotImplementedError(
                "Sliding window is not allowed with prefix caching enabled!")

        self.block_sliding_window = None
        if sliding_window is not None:
            assert sliding_window % block_size == 0, (sliding_window,
                                                      block_size)
            self.block_sliding_window = sliding_window // block_size

        self.watermark = watermark
        assert watermark >= 0.0

        self.enable_caching = enable_caching

        self.watermark_blocks = int(watermark * num_gpu_blocks)

        if self.enable_caching:
            logger.info("Automatic prefix caching is enabled.")
            self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
                Device.GPU, block_size, num_gpu_blocks)
            self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
                Device.CPU, block_size, num_cpu_blocks)
        else:
            self.gpu_allocator = UncachedBlockAllocator(
                Device.GPU, block_size, num_gpu_blocks)
            self.cpu_allocator = UncachedBlockAllocator(
                Device.CPU, block_size, num_cpu_blocks)
        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        # FIXME: Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
        num_required_blocks = len(seq.logical_token_blocks)

        if self.block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

        # Use the watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER
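
    # Worked example for the watermark check above (numbers are illustrative):
    # with num_gpu_blocks=1000 and watermark=0.01, watermark_blocks=10. A
    # prompt needing 995 blocks can never fit (1000 - 995 < 10 -> NEVER); one
    # needing 200 blocks is OK only while at least 210 GPU blocks are free,
    # otherwise the request is deferred (LATER).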

    def allocate(self, seq_group: SequenceGroup) -> None:
        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

        # Allocate new physical token blocks that will store the prompt tokens.
        num_prompt_blocks = len(seq.logical_token_blocks)

        block_table: BlockTable = []
        for logical_idx in range(num_prompt_blocks):
            if (self.block_sliding_window is not None
                    and logical_idx >= self.block_sliding_window):
                block = block_table[logical_idx % self.block_sliding_window]
                # Set the reference counts of the token blocks.
                block.ref_count = seq_group.num_seqs()
            elif self.enable_caching:
                block = self.gpu_allocator.allocate(
                    seq.hash_of_block(logical_idx),
                    seq.num_hashed_tokens_of_block(logical_idx))
            else:
                block = self.gpu_allocator.allocate()
                # Set the reference counts of the token blocks.
                block.ref_count = seq_group.num_seqs()
            block_table.append(block)

        # Assign the block table for each sequence.
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            self.block_tables[seq.seq_id] = block_table.copy()

    def can_append_slots(self,
                         seq_group: SequenceGroup,
                         num_lookahead_slots: int = 0) -> bool:
        assert (num_lookahead_slots == 0
                ), "lookahead allocation not supported in BlockSpaceManagerV1"

        # Simple heuristic: if there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def _promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        assert self.enable_caching

        # Compute a new hash for the block so that it can be shared by other
        # sequences.
        new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)

        # If new_hash is already in the cached table, then free last_block
        # and return the cached version.
        if self.gpu_allocator.contains_block(new_hash):
            self.gpu_allocator.free(last_block)
            return self.gpu_allocator.allocate(new_hash)
        else:
            self.gpu_allocator.update_hash(new_hash, last_block)
            return last_block

    def _is_last_block_full(
        self,
        seq: Sequence,
    ) -> bool:
        token_ids_len = seq.data.get_len()
        return token_ids_len > 0 and token_ids_len % seq.block_size == 0

    def _maybe_promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        if self._is_last_block_full(seq):
            return self._promote_last_block(seq, last_block)
        else:
            return last_block

    def _allocate_last_physical_block(
        self,
        seq: Sequence,
    ) -> PhysicalTokenBlock:
        # Called before a new block is appended.
        # This is in charge of allocating a new physical block (to be
        # appended).
        # The block hash is None if the last block is not full; otherwise we
        # set it to the content hash.
        if not self.enable_caching:
            return self.gpu_allocator.allocate()
        block_hash: Optional[int] = None
        if self._is_last_block_full(seq):
            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
        num_hashed_tokens = seq.num_hashed_tokens_of_block(
            len(seq.logical_token_blocks) - 1)

        # num_hashed_tokens is used to compute future hashes
        # (e.g. in the hashing function, it is used to ask the sequence for
        # prefix tokens).
        new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)

        # If the block hash is None, then the block is not full.
        # If the block is not full, then we expect it to have a refcount of 1.
        if block_hash is None:
            assert new_block.ref_count == 1
        return new_block

    def append_slots(
        self,
        seq: Sequence,
        num_lookahead_slots: int = 0,
    ) -> Dict[int, List[int]]:
        """Allocate a physical slot for a new token."""
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]
        # If we need to allocate a new physical block
        if len(block_table) < len(logical_blocks):
            # Currently this code only supports adding one physical block
            assert len(block_table) == len(logical_blocks) - 1

            if (self.block_sliding_window
                    and len(block_table) >= self.block_sliding_window):
                # Reuse a block.
                block_table.append(block_table[len(block_table) %
                                               self.block_sliding_window])
            else:
                # The sequence has a new logical block.
                # Allocate a new physical block.
                new_block = self._allocate_last_physical_block(seq)
                block_table.append(new_block)
                return {}

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            if self.enable_caching:
                # If the last block is now complete, we may reuse an old block
                # to save memory.
                maybe_new_block = self._maybe_promote_last_block(
                    seq, last_block)
                block_table[-1] = maybe_new_block
            return {}
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.
            new_block = self._allocate_last_physical_block(seq)
            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
            return {last_block.block_number: [new_block.block_number]}
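
    # The dict returned by `append_slots` maps a source block number to the
    # list of block numbers it must be copied into (copy-on-write). For
    # example, {7: [12]} means "copy block 7 into block 12 before writing the
    # new token"; the block numbers here are illustrative only.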

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()

        # When using a sliding window, blocks will eventually be reused.
        # In this case the block tables will contain repeated blocks.
        # When forking, we must make sure that each block's `ref_count`
        # is only incremented by one, so we deduplicate them by wrapping
        # them in a set.
        for block in set(src_block_table):
            block.ref_count += 1

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            blocks.update(self.block_tables[seq.seq_id])
        return list(blocks)

    def can_swap_in(self,
                    seq_group: SequenceGroup,
                    num_lookahead_slots: int = 0) -> bool:
        assert (num_lookahead_slots == 0
                ), "BlockSpaceManagerV1 does not support lookahead allocation"
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        return num_free_blocks - num_required_blocks >= self.watermark_blocks

    def swap_in(self,
                seq_group: SequenceGroup,
                num_lookahead_slots: int = 0) -> Dict[int, int]:
        assert (num_lookahead_slots == 0
                ), "BlockSpaceManagerV1 does not support lookahead allocation"

        # CPU block -> GPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for cpu_block in block_table:
                if cpu_block in mapping:
                    gpu_block = mapping[cpu_block]
                    gpu_block.ref_count += 1
                else:
                    gpu_block = self.gpu_allocator.allocate(
                        cpu_block.block_hash, cpu_block.num_hashed_tokens)
                    mapping[cpu_block] = gpu_block
                new_block_table.append(gpu_block)
                # Free the CPU block swapped in to GPU.
                self.cpu_allocator.free(cpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            cpu_block.block_number: gpu_block.block_number
            for cpu_block, gpu_block in mapping.items()
        }
        return block_number_mapping

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        # GPU block -> CPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for gpu_block in block_table:
                if gpu_block in mapping:
                    cpu_block = mapping[gpu_block]
                    cpu_block.ref_count += 1
                else:
                    cpu_block = self.cpu_allocator.allocate(
                        gpu_block.block_hash, gpu_block.num_hashed_tokens)
                    mapping[gpu_block] = cpu_block
                new_block_table.append(cpu_block)
                # Free the GPU block swapped out to CPU.
                self.gpu_allocator.free(gpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            gpu_block.block_number: cpu_block.block_number
            for gpu_block, cpu_block in mapping.items()
        }
        return block_number_mapping

    def _free_block_table(self, block_table: BlockTable) -> None:
        # When using a sliding window, each seq will only use up
        # to `self.block_sliding_window` blocks. When freeing
        # the block table, we must make sure not to free blocks more
        # than once. If no sliding window is used, there is no block
        # reuse in the block table, so we must free all blocks.
        blocks_to_free = (block_table[-self.block_sliding_window:]
                          if self.block_sliding_window is not None else
                          block_table)
        for block in set(blocks_to_free):
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        if seq.seq_id not in self.block_tables:
            # Already freed or hasn't been scheduled yet.
            return
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        block_table = self.block_tables[seq.seq_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        return self.cpu_allocator.get_num_free_blocks()

    def access_all_blocks_in_seq(
        self,
        seq: Sequence,
        access_time: float,
    ) -> None:
        if self.enable_caching:
            # Update the last accessed time of all the blocks accessed
            # in this step.
            block_table = self.block_tables[seq.seq_id]
            for block in block_table:
                block.last_accessed = access_time

    def compute_full_blocks_in_seq(self, seq: Sequence):
        if seq.seq_id not in self.block_tables:
            return
        max_full_block = seq.get_len() // self.block_size - 1
        block_table = self.block_tables[seq.seq_id]
        if max_full_block == -1:
            return
        for i in reversed(range(max_full_block)):
            if block_table[i].computed:
                break
            block_table[i].computed = True

    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
        if seq.seq_id not in self.block_tables:
            return []
        block_table = self.block_tables[seq.seq_id]
        # NOTE: We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in the model
        # runner.
        return [
            b.block_number
            for b in takewhile(lambda b: b.computed, block_table[:-1])
        ]

    def get_common_computed_block_ids(
            self, seqs: List[Sequence]) -> GenericSequence[int]:
        """Return the block ids that are common for a given sequence group.

        Used in prefill (can skip prefill of some blocks).
        """
        # Can return a non-empty result only with prefix caching enabled.
        if not self.enable_caching:
            return []

        ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
        return commonprefix([ids for ids in ids_list if ids != []])

    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        if self.enable_caching:
            for seq in seq_group.seqs_dict.values():
                self.compute_full_blocks_in_seq(seq)
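
# How a scheduler typically drives this manager, sketched in comments only
# (the scheduler lives outside this file; `seq_group` and `seq` construction
# is elided, and the block counts and call sites below are assumptions):
#
#     manager = BlockSpaceManagerV1(block_size=16,
#                                   num_gpu_blocks=1024,
#                                   num_cpu_blocks=256,
#                                   enable_caching=True)
#     if manager.can_allocate(seq_group) == AllocStatus.OK:
#         manager.allocate(seq_group)          # prefill: map prompt blocks
#     if manager.can_append_slots(seq_group):
#         cows = manager.append_slots(seq)     # decode: may return CoW pairs
#     if manager.can_swap_out(seq_group):
#         manager.swap_out(seq_group)          # preemption: GPU -> CPU blocks
#     manager.free(seq)                        # release blocks when finished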