# block_manager_v1.py

  1. """A block manager that manages token blocks."""
  2. import math
  3. from abc import ABC, abstractmethod
  4. from itertools import count, takewhile
  5. from os.path import commonprefix
  6. from typing import Dict, List, Optional
  7. from typing import Sequence as GenericSequence
  8. from typing import Set, Tuple
  9. from loguru import logger
  10. from aphrodite.common.block import BlockTable, PhysicalTokenBlock
  11. from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
  12. from aphrodite.common.utils import Device
  13. from aphrodite.processing.block.utils import (
  14. check_no_caching_or_swa_for_blockmgr_encdec)
  15. from aphrodite.processing.evictor_v1 import (EvictionPolicy, Evictor,
  16. make_evictor)
  17. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager

class BlockAllocatorBase(ABC):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free
    list.
    """

    @abstractmethod
    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
        pass

    @abstractmethod
    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        pass

    @abstractmethod
    def free(self, block: PhysicalTokenBlock) -> None:
        pass

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        pass

    @abstractmethod
    def get_num_total_blocks(self) -> int:
        pass

    @abstractmethod
    def contains_block(self, block_hash: int) -> bool:
        pass

    @abstractmethod
    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        pass

class CachedBlockAllocator(BlockAllocatorBase):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free
    list.
    """

    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        self.current_num_blocks = 0
        self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

        self.evictor: Evictor = make_evictor(eviction_policy)

        self.default_hash_ctr = count()

    def allocate_block(self, block_hash: int,
                       num_hashed_tokens: int) -> PhysicalTokenBlock:
        if self.current_num_blocks == self.num_blocks:
            block = self.evictor.evict()
            block.block_hash = block_hash
            block.num_hashed_tokens = num_hashed_tokens
            return block
        block = PhysicalTokenBlock(device=self.device,
                                   block_number=self.current_num_blocks,
                                   block_size=self.block_size,
                                   block_hash=block_hash,
                                   num_hashed_tokens=num_hashed_tokens)
        self.current_num_blocks += 1
        return block

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        if block_hash is None:
            block_hash = next(self.default_hash_ctr)
        if block_hash in self.evictor:
            assert block_hash not in self.cached_blocks
            block = self.evictor.remove(block_hash)
            assert block.ref_count == 0
            self.cached_blocks[block_hash] = block
            block.ref_count += 1
            assert block.block_hash == block_hash
            return block
        if block_hash not in self.cached_blocks:
            self.cached_blocks[block_hash] = self.allocate_block(
                block_hash, num_hashed_tokens)
        block = self.cached_blocks[block_hash]
        assert block.block_hash == block_hash
        block.ref_count += 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            assert block.block_hash not in self.evictor
            self.evictor.add(block)

            # Remove the block from the cached_blocks
            del self.cached_blocks[block.block_hash]

    def get_num_free_blocks(self) -> int:
        return (self.num_blocks - self.current_num_blocks +
                self.evictor.num_blocks)

    def get_num_total_blocks(self) -> int:
        return self.num_blocks

    def contains_block(self, block_hash: int) -> bool:
        return block_hash in self.cached_blocks or block_hash in self.evictor

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        # Update the hash of block and the cached_blocks dictionary.
        assert not self.contains_block(block_hash)
        old_hash = block.block_hash
        block.block_hash = block_hash
        del self.cached_blocks[old_hash]
        self.cached_blocks[block_hash] = block
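
# Illustrative sketch (not executed anywhere in this module; the concrete
# numbers are made up): requesting the same content hash twice returns the
# same physical block with its reference count bumped.
#
#     allocator = CachedBlockAllocator(Device.GPU, block_size=16, num_blocks=8)
#     first = allocator.allocate(block_hash=1234, num_hashed_tokens=16)
#     second = allocator.allocate(block_hash=1234, num_hashed_tokens=16)
#     assert first is second and first.ref_count == 2
#     allocator.free(first)
#     allocator.free(second)  # ref_count reaches 0; block moves to the evictor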

class UncachedBlockAllocator(BlockAllocatorBase):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free
    list.
    """

    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Initialize the free blocks.
        self.free_blocks: List[PhysicalTokenBlock] = []
        for i in range(num_blocks):
            block = PhysicalTokenBlock(device=device,
                                       block_number=i,
                                       block_size=block_size,
                                       block_hash=-1,
                                       num_hashed_tokens=0)
            self.free_blocks.append(block)

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        if not self.free_blocks:
            raise ValueError("Out of memory! No free blocks are available.")
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        return len(self.free_blocks)

    def get_num_total_blocks(self) -> int:
        return self.num_blocks

    def contains_block(self, block_hash: int) -> bool:
        raise NotImplementedError(
            "Invalid codepath for uncached block allocator.")

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        raise NotImplementedError(
            "Invalid codepath for uncached block allocator.")

class BlockSpaceManagerV1(BlockSpaceManager):
    """Manages the mapping between logical and physical token blocks."""

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
        enable_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        if enable_caching and sliding_window is not None:
            raise NotImplementedError(
                "Sliding window is not allowed with prefix caching enabled!")

        self.block_sliding_window = None
        if sliding_window is not None:
            # Round up to nearest block size to regularize sliding window
            # allocation sizes.
            self.block_sliding_window = math.ceil(sliding_window / block_size)

        self.watermark = watermark
        assert watermark >= 0.0

        self.enable_caching = enable_caching

        self.watermark_blocks = int(watermark * num_gpu_blocks)
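        # Worked example (illustrative numbers): with num_gpu_blocks=1000 and
        # the default watermark of 0.01, watermark_blocks == 10, so
        # can_allocate() reports OK only while at least
        # num_required_blocks + 10 GPU blocks remain free.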

        if self.enable_caching:
            logger.info("Automatic prefix caching is enabled.")
            self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
                Device.GPU, block_size, num_gpu_blocks)
            self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
                Device.CPU, block_size, num_cpu_blocks)
        else:
            self.gpu_allocator = UncachedBlockAllocator(
                Device.GPU, block_size, num_gpu_blocks)
            self.cpu_allocator = UncachedBlockAllocator(
                Device.CPU, block_size, num_cpu_blocks)

        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

        # Mapping: req_id -> BlockTable.
        # Note that each SequenceGroup has a unique request ID.
        self.cross_block_tables: Dict[str, BlockTable] = {}

    def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
        return 0 if seq is None else seq.n_blocks

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        # FIXME: Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.

        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)

        self_num_required_blocks = self._get_seq_num_required_blocks(
            seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
        cross_num_required_blocks = self._get_seq_num_required_blocks(
            seq_group.get_encoder_seq())
        num_required_blocks = self_num_required_blocks + \
            cross_num_required_blocks

        if self.block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

        # Use watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER

    def _allocate_sequence(self,
                           seq: Sequence,
                           ref_count: int,
                           is_encoder_decoder: bool = True) -> BlockTable:
        # Allocate new physical token blocks that will store the prompt
        # tokens.
        num_prompt_blocks = seq.n_blocks

        block_table: BlockTable = BlockTable()
        for logical_idx in range(num_prompt_blocks):
            if (self.block_sliding_window is not None
                    and logical_idx >= self.block_sliding_window):
                block = block_table[logical_idx % self.block_sliding_window]
                # Set the reference counts of the token blocks.
                block.ref_count = ref_count
            elif not is_encoder_decoder and self.enable_caching:
                block = self.gpu_allocator.allocate(
                    seq.hash_of_block(logical_idx),
                    seq.num_hashed_tokens_of_block(logical_idx))
            else:
                block = self.gpu_allocator.allocate()
                # Set the reference counts of the token blocks.
                block.ref_count = ref_count
            block_table.append(block)

        return block_table

    def allocate(self, seq_group: SequenceGroup) -> None:
        is_encoder_decoder = seq_group.is_encoder_decoder()
        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)

        # Allocate decoder sequences
        #
        # NOTE: Here we assume that all sequences in the group have the same
        # decoder prompt.
        wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        seq = wait_seqs[0]
        block_table: BlockTable = \
            self._allocate_sequence(seq,
                                    seq_group.num_seqs(),
                                    is_encoder_decoder)

        # Assign the self-attention block tables for each sequence.
        if len(wait_seqs) == 1:
            self.block_tables[seq.seq_id] = block_table
        else:
            for seq in wait_seqs:
                self.block_tables[seq.seq_id] = block_table.copy()

        # Allocate encoder sequence
        if is_encoder_decoder:
            # A SequenceGroup has only a single encoder sequence (at most),
            # thus allocate with a ref count of 1
            block_table = self._allocate_sequence(seq_group.get_encoder_seq(),
                                                  1, is_encoder_decoder)
            # Assign the cross-attention block table for the SequenceGroup.
            self.cross_block_tables[seq_group.request_id] = block_table

    def can_append_slots(self,
                         seq_group: SequenceGroup,
                         num_lookahead_slots: int = 0) -> bool:
        assert (num_lookahead_slots == 0
                ), "lookahead allocation not supported in BlockSpaceManagerV1"

        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def _promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        assert self.enable_caching

        # Compute a new hash for the block so that it can be shared by other
        # Sequences.
        new_hash = seq.hash_of_block(seq.n_blocks - 1)

        # If new_hash is already in the cached table, then free last_block
        # and return the cached version.
        if self.gpu_allocator.contains_block(new_hash):
            self.gpu_allocator.free(last_block)
            return self.gpu_allocator.allocate(new_hash)
        else:
            self.gpu_allocator.update_hash(new_hash, last_block)
            return last_block

    def _is_last_block_full(
        self,
        seq: Sequence,
    ) -> bool:
        token_ids_len = seq.data.get_len()
        return token_ids_len > 0 and token_ids_len % seq.block_size == 0

    def _maybe_promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        if self._is_last_block_full(seq):
            return self._promote_last_block(seq, last_block)
        else:
            return last_block

    def _allocate_last_physical_block(
        self,
        seq: Sequence,
    ) -> PhysicalTokenBlock:
        # Called before a new block is appended.
        # This is in charge of allocating a new physical block (to be
        # appended).

        # The block hash is None if the last block is not full; otherwise it
        # is set to the content hash of the full block.
        if not self.enable_caching:
            return self.gpu_allocator.allocate()
        block_hash: Optional[int] = None
        n_blocks = seq.n_blocks
        if self._is_last_block_full(seq):
            block_hash = seq.hash_of_block(n_blocks - 1)
        num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)

        # num_hashed_tokens is used to compute future hashes
        # (e.g. in the hashing function, it is used to ask the sequence for
        # prefix tokens).
        new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)

        # If the block hash is None, then the block is not full.
        # If the block is not full, then we expect it to have a refcount of 1.
        if block_hash is None:
            assert new_block.ref_count == 1
        return new_block

    def append_slots(
        self,
        seq: Sequence,
        num_lookahead_slots: int = 0,
    ) -> List[Tuple[int, int]]:
        """Allocate a physical slot for a new token."""
        n_blocks = seq.n_blocks
        block_table = self.block_tables[seq.seq_id]
        # If we need to allocate a new physical block
        if len(block_table) < n_blocks:
            # Currently this code only supports adding one physical block
            assert len(block_table) == n_blocks - 1

            if (self.block_sliding_window
                    and len(block_table) >= self.block_sliding_window):
                # Reuse a block.
                block_table.append(block_table[len(block_table) %
                                               self.block_sliding_window])
            else:
                # The sequence has a new logical block.
                # Allocate a new physical block.
                new_block = self._allocate_last_physical_block(seq)
                block_table.append(new_block)
                return []

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            if self.enable_caching:
                # If the last block is now complete, we may reuse an old block
                # to save memory.
                maybe_new_block = self._maybe_promote_last_block(
                    seq, last_block)
                block_table[-1] = maybe_new_block
            return []
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.
            new_block = self._allocate_last_physical_block(seq)

            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
            return [(last_block.block_number, new_block.block_number)]
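
    # Each (src, dst) pair returned by append_slots() describes a
    # copy-on-write: the caller copies the contents of physical block `src`
    # into the newly allocated block `dst` before the sequence writes to it.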

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        if parent_seq.seq_id not in self.block_tables:
            # Parent sequence has either been freed or never existed.
            return
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()

        # When using a sliding window, blocks will be eventually reused.
        # In this case the block tables will contain repeated blocks.
        # When forking, we must make sure that each block's `ref_count`
        # is only incremented by one, so we deduplicate them by wrapping
        # them in a set.
        for block in set(src_block_table):
            block.ref_count += 1

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:

        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        request_id = seq_group.request_id
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            blocks.update(self.block_tables[seq.seq_id])
        # Cross-attention blocks
        if seq_group.is_encoder_decoder():
            blocks.update(self.cross_block_tables[request_id])
        return list(blocks)

    def can_swap_in(self,
                    seq_group: SequenceGroup,
                    num_lookahead_slots: int = 0) -> AllocStatus:
        assert (num_lookahead_slots == 0
                ), "BlockSpaceManagerV1 does not support lookahead allocation"

        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        if seq_group.is_encoder_decoder():
            num_swapped_seqs += 1
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
            return AllocStatus.NEVER
        elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER

    def _swap_block_table(
            self, block_table: BlockTable, src_allocator: BlockAllocatorBase,
            dest_allocator: BlockAllocatorBase,
            mapping: Dict[PhysicalTokenBlock,
                          PhysicalTokenBlock]) -> BlockTable:
        new_block_table: BlockTable = BlockTable()

        for from_block in block_table:
            if from_block in mapping:
                to_block = mapping[from_block]
                to_block.ref_count += 1
            else:
                to_block = dest_allocator.allocate(
                    from_block.block_hash, from_block.num_hashed_tokens)
                mapping[from_block] = to_block
            new_block_table.append(to_block)
            # Free the source block swapped in to destination.
            src_allocator.free(from_block)

        return new_block_table
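
    # The shared `mapping` dict above serves two purposes: it deduplicates
    # blocks that appear in several block tables of the same group (they are
    # swapped once, with ref_count bumped on reuse), and it is what swap_in()
    # and swap_out() flatten into the (source, destination) block-number
    # pairs returned to the caller.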

    def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
        request_id = seq_group.request_id

        # CPU block -> GPU block.
        # dict is efficient in lookup `if cpu_block in mapping`
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            self.block_tables[seq.seq_id] = \
                self._swap_block_table(self.block_tables[seq.seq_id],
                                       self.cpu_allocator, self.gpu_allocator,
                                       mapping)

        if seq_group.is_encoder_decoder():
            self.cross_block_tables[request_id] = \
                self._swap_block_table(self.cross_block_tables[request_id],
                                       self.cpu_allocator,
                                       self.gpu_allocator,
                                       mapping)

        return [(cpu_block.block_number, gpu_block.block_number)
                for cpu_block, gpu_block in mapping.items()]

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
        request_id = seq_group.request_id

        # GPU block -> CPU block.
        # dict is efficient in lookup `if gpu_block in mapping`
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            self.block_tables[seq.seq_id] = \
                self._swap_block_table(self.block_tables[seq.seq_id],
                                       self.gpu_allocator, self.cpu_allocator,
                                       mapping)

        if seq_group.is_encoder_decoder():
            self.cross_block_tables[request_id] = \
                self._swap_block_table(self.cross_block_tables[request_id],
                                       self.gpu_allocator,
                                       self.cpu_allocator,
                                       mapping)

        # Here the mapping keys are GPU (source) blocks and the values are
        # CPU (destination) blocks.
        return [(gpu_block.block_number, cpu_block.block_number)
                for gpu_block, cpu_block in mapping.items()]

    def _free_block_table(self, block_table: BlockTable) -> None:
        # When using a sliding window, each seq will only use up
        # to `self.block_sliding_window` blocks. When freeing
        # the block table, we must make sure to not free blocks more
        # than once. If no sliding window is used, there is no block
        # reuse in the block table, so we must free all blocks.
        blocks_to_free = (block_table[-self.block_sliding_window:]
                          if self.block_sliding_window is not None else
                          block_table)
        for block in set(blocks_to_free):
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        if seq.seq_id not in self.block_tables:
            # Already freed or hasn't been scheduled yet.
            return
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def free_cross(self, seq_group: SequenceGroup) -> None:
        if seq_group.request_id not in self.cross_block_tables:
            # Already freed or hasn't been scheduled yet.
            return
        block_table = self.cross_block_tables[seq_group.request_id]
        self._free_block_table(block_table)
        del self.cross_block_tables[seq_group.request_id]

    def reset(self) -> None:
        # Free decoder block tables
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()
        # Free cross-attention block tables
        for block_table in self.cross_block_tables.values():
            self._free_block_table(block_table)
        self.cross_block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        return self.block_tables[seq.seq_id].ids()

    def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
        block_table = self.cross_block_tables[seq_group.request_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        return self.cpu_allocator.get_num_free_blocks()

    def access_all_blocks_in_seq(
        self,
        seq: Sequence,
        access_time: float,
    ) -> None:
        if self.enable_caching:
            # Update the last accessed time of all the blocks accessed
            # in this step.
            block_table = self.block_tables[seq.seq_id]
            for block in block_table:
                block.last_accessed = access_time

    def compute_full_blocks_in_seq(self, seq: Sequence):
        if seq.seq_id not in self.block_tables:
            return
        max_full_block = seq.get_len() // self.block_size - 1
        block_table = self.block_tables[seq.seq_id]
        if max_full_block == -1:
            return
        for i in reversed(range(max_full_block)):
            if block_table[i].computed:
                break
            block_table[i].computed = True

    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
        if seq.seq_id not in self.block_tables:
            return []
        block_table = self.block_tables[seq.seq_id]
        # NOTE: We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in model
        # runner.
        return [
            b.block_number
            for b in takewhile(lambda b: b.computed, block_table[:-1])
        ]

    def get_common_computed_block_ids(
            self, seqs: List[Sequence]) -> GenericSequence[int]:
        """Return the block ids that are common for a given sequence group.

        Used in prefill (can skip prefill of some blocks).
        """
        # Can return non-empty result only with prefix caching enabled.
        if not self.enable_caching:
            return []

        ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
        return commonprefix([ids for ids in ids_list if ids != []])
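
    # For example, commonprefix([[0, 1, 2], [0, 1, 3]]) == [0, 1]: two
    # sequences that share their first two computed blocks can both skip
    # prefilling those blocks.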

    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        if self.enable_caching:
            for seq in seq_group.get_seqs():
                self.compute_full_blocks_in_seq(seq)
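
# Typical call sequence (sketch, for orientation only): the caller checks
# can_allocate()/allocate() when a sequence group is first scheduled, calls
# append_slots() for each new token of a running sequence (acting on any
# copy-on-write pairs it returns), uses can_swap_out()/swap_out() and
# can_swap_in()/swap_in() to move preempted groups between GPU and CPU, and
# finally calls free()/free_cross() when sequences finish.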