block_manager_v1.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736
  1. """A block manager that manages token blocks."""
  2. import math
  3. from abc import ABC, abstractmethod
  4. from itertools import count, takewhile
  5. from os.path import commonprefix
  6. from typing import Dict, List, Optional
  7. from typing import Sequence as GenericSequence
  8. from typing import Set, Tuple
  9. from loguru import logger
  10. from aphrodite.common.block import BlockTable, PhysicalTokenBlock
  11. from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
  12. from aphrodite.common.utils import Device
  13. from aphrodite.processing.block.common import CacheMetricData
  14. from aphrodite.processing.block.utils import (
  15. check_no_caching_or_swa_for_blockmgr_encdec)
  16. from aphrodite.processing.evictor_v1 import (EvictionPolicy, Evictor,
  17. make_evictor)
  18. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
  19. class BlockAllocatorBase(ABC):
  20. """Manages free physical token blocks for a device.
  21. The allocator maintains a list of free blocks and allocates a block when
  22. requested. When a block is freed, its reference count is decremented. If
  23. the reference count becomes zero, the block is added back to the free list.
  24. """
  25. @abstractmethod
  26. def __init__(self,
  27. device: Device,
  28. block_size: int,
  29. num_blocks: int,
  30. eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
  31. pass
  32. @abstractmethod
  33. def allocate(self,
  34. block_hash: Optional[int] = None,
  35. num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
  36. pass
  37. @abstractmethod
  38. def free(self, block: PhysicalTokenBlock) -> None:
  39. pass
  40. @abstractmethod
  41. def get_num_free_blocks(self) -> int:
  42. pass
  43. @abstractmethod
  44. def get_num_total_blocks(self) -> int:
  45. pass
  46. @abstractmethod
  47. def contains_block(self, block_hash: int) -> bool:
  48. pass
  49. @abstractmethod
  50. def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
  51. pass
  52. @abstractmethod
  53. def get_prefix_cache_hit_rate(self) -> float:
  54. """Prefix cache hit rate. -1 means not supported or disabled."""
  55. pass
  56. class CachedBlockAllocator(BlockAllocatorBase):
  57. """Manages free physical token blocks for a device.
  58. The allocator maintains a list of free blocks and allocates a block when
  59. requested. When a block is freed, its reference count is decremented. If
  60. the reference count becomes zero, the block is added back to the free list.
  61. """
  62. def __init__(self,
  63. device: Device,
  64. block_size: int,
  65. num_blocks: int,
  66. eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
  67. self.device = device
  68. self.block_size = block_size
  69. self.num_blocks = num_blocks
  70. self.current_num_blocks = 0
  71. self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}
  72. self.evictor: Evictor = make_evictor(eviction_policy)
  73. self.default_hash_ctr = count()
  74. self.cache_metric_data = CacheMetricData()
  75. def allocate_block(self, block_hash: int,
  76. num_hashed_tokens: int) -> PhysicalTokenBlock:
  77. if self.current_num_blocks == self.num_blocks:
  78. block = self.evictor.evict()
  79. block.block_hash = block_hash
  80. block.num_hashed_tokens = num_hashed_tokens
  81. return block
  82. block = PhysicalTokenBlock(device=self.device,
  83. block_number=self.current_num_blocks,
  84. block_size=self.block_size,
  85. block_hash=block_hash,
  86. num_hashed_tokens=num_hashed_tokens)
  87. self.current_num_blocks += 1
  88. return block
  89. def allocate(self,
  90. block_hash: Optional[int] = None,
  91. num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
  92. if block_hash is None:
  93. block_hash = next(self.default_hash_ctr)
  94. if block_hash in self.evictor:
  95. assert block_hash not in self.cached_blocks
  96. block = self.evictor.remove(block_hash)
  97. assert block.ref_count == 0
  98. self.cached_blocks[block_hash] = block
  99. if block_hash in self.cached_blocks:
  100. self.cache_metric_data.query(hit=True)
  101. else:
  102. self.cache_metric_data.query(hit=False)
  103. self.cached_blocks[block_hash] = self.allocate_block(
  104. block_hash, num_hashed_tokens)
  105. block = self.cached_blocks[block_hash]
  106. assert block.block_hash == block_hash
  107. block.ref_count += 1
  108. return block
  109. def free(self, block: PhysicalTokenBlock) -> None:
  110. if block.ref_count == 0:
  111. raise ValueError(f"Double free! {block} is already freed.")
  112. block.ref_count -= 1
  113. if block.ref_count == 0:
  114. assert block.block_hash not in self.evictor
  115. self.evictor.add(block)
  116. # Remove the block from the cached_blocks
  117. del self.cached_blocks[block.block_hash]
  118. def get_num_free_blocks(self) -> int:
  119. return (self.num_blocks - self.current_num_blocks +
  120. self.evictor.num_blocks)
  121. def get_num_total_blocks(self) -> int:
  122. return self.num_blocks
  123. def contains_block(self, block_hash: int) -> bool:
  124. return block_hash in self.cached_blocks or block_hash in self.evictor
  125. def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
  126. # Update the hash of block and the cached_blocks dictionary.
  127. assert not self.contains_block(block_hash)
  128. old_hash = block.block_hash
  129. block.block_hash = block_hash
  130. del self.cached_blocks[old_hash]
  131. self.cached_blocks[block_hash] = block
  132. def get_prefix_cache_hit_rate(self) -> float:
  133. return self.cache_metric_data.get_hit_rate()
  134. class UncachedBlockAllocator(BlockAllocatorBase):
  135. """Manages free physical token blocks for a device.
  136. The allocator maintains a list of free blocks and allocates a block when
  137. requested. When a block is freed, its reference count is decremented. If
  138. the reference count becomes zero, the block is added back to the free list.
  139. """
  140. def __init__(
  141. self,
  142. device: Device,
  143. block_size: int,
  144. num_blocks: int,
  145. ) -> None:
  146. self.device = device
  147. self.block_size = block_size
  148. self.num_blocks = num_blocks
  149. # Initialize the free blocks.
  150. self.free_blocks: List[PhysicalTokenBlock] = []
  151. for i in range(num_blocks):
  152. block = PhysicalTokenBlock(device=device,
  153. block_number=i,
  154. block_size=block_size,
  155. block_hash=-1,
  156. num_hashed_tokens=0)
  157. self.free_blocks.append(block)
  158. def allocate(self,
  159. block_hash: Optional[int] = None,
  160. num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
  161. if not self.free_blocks:
  162. raise ValueError("Out of memory! No free blocks are available.")
  163. block = self.free_blocks.pop()
  164. block.ref_count = 1
  165. return block
  166. def free(self, block: PhysicalTokenBlock) -> None:
  167. if block.ref_count == 0:
  168. raise ValueError(f"Double free! {block} is already freed.")
  169. block.ref_count -= 1
  170. if block.ref_count == 0:
  171. self.free_blocks.append(block)
  172. def get_num_free_blocks(self) -> int:
  173. return len(self.free_blocks)
  174. def get_num_total_blocks(self) -> int:
  175. return self.num_blocks
  176. def contains_block(self, block_hash: int) -> bool:
  177. raise NotImplementedError(
  178. "Invalid codepath for uncached block allocator.")
  179. def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
  180. raise NotImplementedError(
  181. "Invalid codepath for uncached block allocator.")
  182. def get_prefix_cache_hit_rate(self) -> float:
  183. return -1
  184. class BlockSpaceManagerV1(BlockSpaceManager):
  185. """Manages the mapping between logical and physical token blocks."""
  186. def __init__(
  187. self,
  188. block_size: int,
  189. num_gpu_blocks: int,
  190. num_cpu_blocks: int,
  191. watermark: float = 0.01,
  192. sliding_window: Optional[int] = None,
  193. enable_caching: bool = False,
  194. ) -> None:
  195. self.block_size = block_size
  196. self.num_total_gpu_blocks = num_gpu_blocks
  197. self.num_total_cpu_blocks = num_cpu_blocks
  198. if enable_caching and sliding_window is not None:
  199. raise NotImplementedError(
  200. "Sliding window is not allowed with prefix caching enabled!")
  201. self.block_sliding_window = None
  202. if sliding_window is not None:
  203. # Round up to nearest block size to regularize sliding window
  204. # allocation sizes.
  205. self.block_sliding_window = math.ceil(sliding_window / block_size)
  206. self.watermark = watermark
  207. assert watermark >= 0.0
  208. self.enable_caching = enable_caching
  209. self.watermark_blocks = int(watermark * num_gpu_blocks)
  210. if self.enable_caching:
  211. logger.info("Automatic prefix caching is enabled.")
  212. self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
  213. Device.GPU, block_size, num_gpu_blocks)
  214. self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
  215. Device.CPU, block_size, num_cpu_blocks)
  216. else:
  217. self.gpu_allocator = UncachedBlockAllocator(
  218. Device.GPU, block_size, num_gpu_blocks)
  219. self.cpu_allocator = UncachedBlockAllocator(
  220. Device.CPU, block_size, num_cpu_blocks)
  221. # Mapping: seq_id -> BlockTable.
  222. self.block_tables: Dict[int, BlockTable] = {}
  223. # Mapping: req_id -> BlockTable
  224. # Note that each SequenceGroup has a unique
  225. # request ID
  226. self.cross_block_tables: Dict[str, BlockTable] = {}
  227. def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
  228. return 0 if seq is None else seq.n_blocks
  229. def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
  230. # FIXME: Here we assume that all sequences in the group share
  231. # the same prompt. This may not be true for preempted sequences.
  232. check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
  233. self_num_required_blocks = self._get_seq_num_required_blocks(
  234. seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
  235. cross_num_required_blocks = self._get_seq_num_required_blocks(
  236. seq_group.get_encoder_seq())
  237. num_required_blocks = self_num_required_blocks + \
  238. cross_num_required_blocks
  239. if self.block_sliding_window is not None:
  240. num_required_blocks = min(num_required_blocks,
  241. self.block_sliding_window)
  242. num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
  243. # Use watermark to avoid frequent cache eviction.
  244. if (self.num_total_gpu_blocks - num_required_blocks <
  245. self.watermark_blocks):
  246. return AllocStatus.NEVER
  247. if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
  248. return AllocStatus.OK
  249. else:
  250. return AllocStatus.LATER
  251. def _allocate_sequence(self, \
  252. seq: Sequence, \
  253. ref_count: int, \
  254. is_encoder_decoder: bool = True) -> BlockTable:
  255. # Allocate new physical token blocks that will store the prompt tokens.
  256. num_prompt_blocks = seq.n_blocks
  257. block_table: BlockTable = BlockTable()
  258. for logical_idx in range(num_prompt_blocks):
  259. if (self.block_sliding_window is not None
  260. and logical_idx >= self.block_sliding_window):
  261. block = block_table[logical_idx % self.block_sliding_window]
  262. # Set the reference counts of the token blocks.
  263. block.ref_count = ref_count
  264. elif not is_encoder_decoder and self.enable_caching:
  265. block = self.gpu_allocator.allocate(
  266. seq.hash_of_block(logical_idx),
  267. seq.num_hashed_tokens_of_block(logical_idx))
  268. else:
  269. block = self.gpu_allocator.allocate()
  270. # Set the reference counts of the token blocks.
  271. block.ref_count = ref_count
  272. block_table.append(block)
  273. return block_table
  274. def allocate(self, seq_group: SequenceGroup) -> None:
  275. is_encoder_decoder = seq_group.is_encoder_decoder()
  276. check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
  277. # Allocate decoder sequences
  278. #
  279. # NOTE: Here we assume that all sequences in the group have the same
  280. # decoder prompt.
  281. wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
  282. seq = wait_seqs[0]
  283. block_table: BlockTable = \
  284. self._allocate_sequence(seq,
  285. seq_group.num_seqs(),
  286. is_encoder_decoder)
  287. # Assign the self-attention block tables for each sequence.
  288. if len(wait_seqs) == 1:
  289. self.block_tables[seq.seq_id] = block_table
  290. else:
  291. for seq in wait_seqs:
  292. self.block_tables[seq.seq_id] = block_table.copy()
  293. # Allocate encoder sequence
  294. if is_encoder_decoder:
  295. # A SequenceGroup has only a single encoder sequence (at most),
  296. # thus allocate with a ref count of 1
  297. block_table = self._allocate_sequence(seq_group.get_encoder_seq(),
  298. 1, is_encoder_decoder)
  299. # Assign the cross-attention block table for the SequenceGroup.
  300. self.cross_block_tables[seq_group.request_id] = block_table
  301. def can_append_slots(self,
  302. seq_group: SequenceGroup,
  303. num_lookahead_slots: int = 0) -> bool:
  304. assert (num_lookahead_slots == 0
  305. ), "lookahead allocation not supported in BlockSpaceManagerV1"
  306. # Simple heuristic: If there is at least one free block
  307. # for each sequence, we can append.
  308. num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
  309. num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
  310. return num_seqs <= num_free_gpu_blocks
  311. def _promote_last_block(
  312. self,
  313. seq: Sequence,
  314. last_block: PhysicalTokenBlock,
  315. ) -> PhysicalTokenBlock:
  316. assert self.enable_caching
  317. # Compute a new hash for the block so that it can be shared by other
  318. # Sequences
  319. new_hash = seq.hash_of_block(seq.n_blocks - 1)
  320. # if new_hash is already in the cached table, then free last_block
  321. # and return the cached version
  322. if self.gpu_allocator.contains_block(new_hash):
  323. self.gpu_allocator.free(last_block)
  324. return self.gpu_allocator.allocate(new_hash)
  325. else:
  326. self.gpu_allocator.update_hash(new_hash, last_block)
  327. return last_block
  328. def _is_last_block_full(
  329. self,
  330. seq: Sequence,
  331. ) -> bool:
  332. token_ids_len = seq.data.get_len()
  333. return token_ids_len > 0 and token_ids_len % seq.block_size == 0
  334. def _maybe_promote_last_block(
  335. self,
  336. seq: Sequence,
  337. last_block: PhysicalTokenBlock,
  338. ) -> PhysicalTokenBlock:
  339. if self._is_last_block_full(seq):
  340. return self._promote_last_block(seq, last_block)
  341. else:
  342. return last_block
  343. def _allocate_last_physical_block(
  344. self,
  345. seq: Sequence,
  346. ) -> PhysicalTokenBlock:
  347. # Called before a new block is appended.
  348. # This is in charge of allocating a new physical block (to be appended).
  349. # None if the last block is not full. Otherwise, we set it to the
  350. # content hash.
  351. if not self.enable_caching:
  352. return self.gpu_allocator.allocate()
  353. block_hash: Optional[int] = None
  354. n_blocks = seq.n_blocks
  355. if (self._is_last_block_full(seq)):
  356. block_hash = seq.hash_of_block(n_blocks - 1)
  357. num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)
  358. # num_hashed_tokens is used to compute future hashes
  359. # (e.g. in the hashing function, it is used to ask the sequence for
  360. # prefix tokens)
  361. new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
  362. # If the block has is None, then the block is not full.
  363. # If the block is not full, then we expect it to have a refcount of 1.
  364. if block_hash is None:
  365. assert new_block.ref_count == 1
  366. return new_block
  367. def append_slots(
  368. self,
  369. seq: Sequence,
  370. num_lookahead_slots: int = 0,
  371. ) -> List[Tuple[int, int]]:
  372. """Allocate a physical slot for a new token."""
  373. n_blocks = seq.n_blocks
  374. block_table = self.block_tables[seq.seq_id]
  375. # If we need to allocate a new physical block
  376. if len(block_table) < n_blocks:
  377. # Currently this code only supports adding one physical block
  378. assert len(block_table) == n_blocks - 1
  379. if (self.block_sliding_window
  380. and len(block_table) >= self.block_sliding_window):
  381. # reuse a block
  382. block_table.append(block_table[len(block_table) %
  383. self.block_sliding_window])
  384. else:
  385. # The sequence hash a new logical block.
  386. # Allocate a new physical block.
  387. new_block = self._allocate_last_physical_block(seq)
  388. block_table.append(new_block)
  389. return []
  390. # We want to append the token to the last physical block.
  391. last_block = block_table[-1]
  392. assert last_block.device == Device.GPU
  393. if last_block.ref_count == 1:
  394. # Not shared with other sequences. Appendable.
  395. if self.enable_caching:
  396. # If the last block is now complete, we may reuse an old block
  397. # to save memory.
  398. maybe_new_block = self._maybe_promote_last_block(
  399. seq, last_block)
  400. block_table[-1] = maybe_new_block
  401. return []
  402. else:
  403. # The last block is shared with other sequences.
  404. # Copy on Write: Allocate a new block and copy the tokens.
  405. new_block = self._allocate_last_physical_block(seq)
  406. block_table[-1] = new_block
  407. self.gpu_allocator.free(last_block)
  408. return [(last_block.block_number, new_block.block_number)]
  409. def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  410. # NOTE: fork does not allocate a new physical block.
  411. # Thus, it is always safe from OOM.
  412. if parent_seq.seq_id not in self.block_tables:
  413. # Parent sequence has either been freed or never existed.
  414. return
  415. src_block_table = self.block_tables[parent_seq.seq_id]
  416. self.block_tables[child_seq.seq_id] = src_block_table.copy()
  417. # When using a sliding window, blocks will be eventually reused.
  418. # In this case the block tables will contain repeated blocks.
  419. # When forking, we must make sure that each block's `ref_count`
  420. # is only incremented by one, so we deduplicate them by wrapping
  421. # them in a set.
  422. for block in set(src_block_table):
  423. block.ref_count += 1
  424. def _get_physical_blocks(
  425. self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
  426. # NOTE: Here, we assume that the physical blocks are only shared by
  427. # the sequences in the same group.
  428. request_id = seq_group.request_id
  429. blocks: Set[PhysicalTokenBlock] = set()
  430. for seq in seq_group.get_seqs():
  431. if seq.is_finished():
  432. continue
  433. blocks.update(self.block_tables[seq.seq_id])
  434. # Cross-attention blocks
  435. if seq_group.is_encoder_decoder():
  436. blocks.update(self.cross_block_tables[request_id])
  437. return list(blocks)
  438. def can_swap_in(self,
  439. seq_group: SequenceGroup,
  440. num_lookahead_slots: int = 0) -> AllocStatus:
  441. assert (num_lookahead_slots == 0
  442. ), "BlockSpaceManagerV1 does not support lookahead allocation"
  443. blocks = self._get_physical_blocks(seq_group)
  444. num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
  445. if seq_group.is_encoder_decoder():
  446. num_swapped_seqs += 1
  447. num_free_blocks = self.gpu_allocator.get_num_free_blocks()
  448. # NOTE: Conservatively, we assume that every sequence will allocate
  449. # at least one free block right after the swap-in.
  450. # NOTE: This should match the logic in can_append_slot().
  451. num_required_blocks = len(blocks) + num_swapped_seqs
  452. if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
  453. return AllocStatus.NEVER
  454. elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
  455. return AllocStatus.OK
  456. else:
  457. return AllocStatus.LATER
  458. def _swap_block_table(
  459. self, block_table: BlockTable, src_allocator: BlockAllocatorBase,
  460. dest_allocator: BlockAllocatorBase,
  461. mapping: Dict[PhysicalTokenBlock,
  462. PhysicalTokenBlock]) -> BlockTable:
  463. new_block_table: BlockTable = BlockTable()
  464. for from_block in block_table:
  465. if from_block in mapping:
  466. to_block = mapping[from_block]
  467. to_block.ref_count += 1
  468. else:
  469. to_block = dest_allocator.allocate(
  470. from_block.block_hash, from_block.num_hashed_tokens)
  471. mapping[from_block] = to_block
  472. new_block_table.append(to_block)
  473. # Free the source block swapped in to destination.
  474. src_allocator.free(from_block)
  475. return new_block_table
  476. def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
  477. request_id = seq_group.request_id
  478. # CPU block -> GPU block.
  479. # dict is efficient in lookup `if cpu_block in mapping`
  480. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
  481. for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
  482. self.block_tables[seq.seq_id] = \
  483. self._swap_block_table(self.block_tables[seq.seq_id],
  484. self.cpu_allocator, self.gpu_allocator,
  485. mapping)
  486. if seq_group.is_encoder_decoder():
  487. self.cross_block_tables[request_id] = \
  488. self._swap_block_table(self.cross_block_tables[request_id],
  489. self.cpu_allocator,
  490. self.gpu_allocator,
  491. mapping)
  492. return [(cpu_block.block_number, gpu_block.block_number)
  493. for cpu_block, gpu_block in mapping.items()]
  494. def can_swap_out(self, seq_group: SequenceGroup) -> bool:
  495. blocks = self._get_physical_blocks(seq_group)
  496. return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
  497. def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
  498. request_id = seq_group.request_id
  499. # GPU block -> CPU block.
  500. # dict is efficient in lookup `if gpu_block in mapping`
  501. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
  502. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  503. self.block_tables[seq.seq_id] = \
  504. self._swap_block_table(self.block_tables[seq.seq_id],
  505. self.gpu_allocator, self.cpu_allocator,
  506. mapping)
  507. if seq_group.is_encoder_decoder():
  508. self.cross_block_tables[request_id] = \
  509. self._swap_block_table(self.cross_block_tables[request_id],
  510. self.gpu_allocator,
  511. self.cpu_allocator,
  512. mapping)
  513. return [(cpu_block.block_number, gpu_block.block_number)
  514. for cpu_block, gpu_block in mapping.items()]
  515. def _free_block_table(self, block_table: BlockTable) -> None:
  516. # when using a sliding window, each seq will only use up
  517. # to `self.block_sliding_window` blocks. When freeing
  518. # the block table, we must make sure to not free blocks more
  519. # than once. If no sliding window is used, there is no block
  520. # reuse in the block table, so we must free all blocks.
  521. blocks_to_free = (block_table[-self.block_sliding_window:]
  522. if self.block_sliding_window is not None else
  523. block_table)
  524. for block in set(blocks_to_free):
  525. if block.device == Device.GPU:
  526. self.gpu_allocator.free(block)
  527. else:
  528. self.cpu_allocator.free(block)
  529. def free(self, seq: Sequence) -> None:
  530. if seq.seq_id not in self.block_tables:
  531. # Already freed or haven't been scheduled yet.
  532. return
  533. block_table = self.block_tables[seq.seq_id]
  534. self._free_block_table(block_table)
  535. del self.block_tables[seq.seq_id]
  536. def free_cross(self, seq_group: SequenceGroup) -> None:
  537. if seq_group.request_id not in self.cross_block_tables:
  538. # Already freed or hasn't ben scheduled yet.
  539. return
  540. block_table = self.cross_block_tables[seq_group.request_id]
  541. self._free_block_table(block_table)
  542. del self.cross_block_tables[seq_group.request_id]
  543. def reset(self) -> None:
  544. # Free decoder block tables
  545. for block_table in self.block_tables.values():
  546. self._free_block_table(block_table)
  547. self.block_tables.clear()
  548. # Free cross-attention block tables
  549. for block_table in self.cross_block_tables.values():
  550. self._free_block_table(block_table)
  551. self.cross_block_tables.clear()
  552. def get_block_table(self, seq: Sequence) -> List[int]:
  553. return self.block_tables[seq.seq_id].ids()
  554. def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
  555. block_table = self.cross_block_tables[seq_group.request_id]
  556. return [block.block_number for block in block_table]
  557. def get_num_free_gpu_blocks(self) -> int:
  558. return self.gpu_allocator.get_num_free_blocks()
  559. def get_num_free_cpu_blocks(self) -> int:
  560. return self.cpu_allocator.get_num_free_blocks()
  561. def access_all_blocks_in_seq(
  562. self,
  563. seq: Sequence,
  564. access_time: float,
  565. ) -> None:
  566. if self.enable_caching:
  567. # Update the last accessed time of all the blocks accessed
  568. # in this step.
  569. block_table = self.block_tables[seq.seq_id]
  570. for block in block_table:
  571. block.last_accessed = access_time
  572. def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
  573. if seq.seq_id not in self.block_tables:
  574. return
  575. # When chunked prefill is enabled, the computed full blocks
  576. # should be calculated based on the number of computed tokens.
  577. max_computed_tokens = (seq.data.get_num_computed_tokens() +
  578. token_chunk_size)
  579. computed_full_blocks = max_computed_tokens // self.block_size
  580. block_table = self.block_tables[seq.seq_id]
  581. if computed_full_blocks == 0:
  582. return
  583. for i in reversed(range(computed_full_blocks)):
  584. if block_table[i].computed:
  585. break
  586. block_table[i].computed = True
  587. def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
  588. if seq.seq_id not in self.block_tables:
  589. return []
  590. block_table = self.block_tables[seq.seq_id]
  591. # NOTE We exclude the last block to avoid the case where the entire
  592. # prompt is cached. This would cause erroneous behavior in model
  593. # runner.
  594. return [
  595. b.block_number
  596. for b in takewhile(lambda b: b.computed, block_table[:-1])
  597. ]
  598. def get_common_computed_block_ids(
  599. self, seqs: List[Sequence]) -> GenericSequence[int]:
  600. """Return the block ids that are common for a given sequence group.
  601. Used in prefill (can skip prefill of some blocks).
  602. """
  603. # Can return non-empty result only with prefix caching enabled.
  604. if not self.enable_caching:
  605. return []
  606. ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
  607. return commonprefix([ids for ids in ids_list if ids != []])
  608. def mark_blocks_as_computed(self, seq_group: SequenceGroup,
  609. token_chunk_size: int):
  610. if self.enable_caching:
  611. for seq in seq_group.get_seqs():
  612. self.compute_full_blocks_in_seq(seq, token_chunk_size)
  613. def get_prefix_cache_hit_rate(self, device: Device) -> float:
  614. if device == Device.GPU:
  615. return self.gpu_allocator.get_prefix_cache_hit_rate()
  616. if device == Device.CPU:
  617. return self.cpu_allocator.get_prefix_cache_hit_rate()
  618. raise ValueError(f"Invalid device: {device}")