block_manager_v2.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. """A block manager that manages token blocks."""
  2. from itertools import chain
  3. from typing import Dict, List, Optional
  4. from typing import Sequence as GenericSequence
  5. from typing import Tuple
  6. from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
  7. from aphrodite.common.utils import Device
  8. from aphrodite.processing.block.block_table import BlockTable
  9. from aphrodite.processing.block.cpu_gpu_block_allocator import (
  10. CpuGpuBlockAllocator)
  11. from aphrodite.processing.block.interfaces import Block
  12. from aphrodite.processing.block.prefix_caching_block import (
  13. ComputedBlocksTracker, LastAccessBlocksTracker)
  14. from aphrodite.processing.block.utils import (
  15. check_no_caching_or_swa_for_blockmgr_encdec)
  16. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
  17. SeqId = int
  18. EncoderSeqId = str
  19. class BlockSpaceManagerV2(BlockSpaceManager):
  20. """BlockSpaceManager which manages the allocation of KV cache.
  21. It owns responsibility for allocation, swapping, allocating memory for
  22. autoregressively-generated tokens, and other advanced features such as
  23. prefix caching, forking/copy-on-write, and sliding-window memory allocation.
  24. The current implementation is partial; in particular prefix caching and
  25. sliding-window are not feature complete.
  26. Lookahead slots
  27. The block manager has the notion of a "lookahead slot". These are slots
  28. in the KV cache that are allocated for a sequence. Unlike the other
  29. allocated slots, the content of these slots is undefined -- the worker
  30. may use the memory allocations in any way.
  31. In practice, a worker could use these lookahead slots to run multiple
  32. forward passes for a single scheduler invocation. Each successive
  33. forward pass would write KV activations to the corresponding lookahead
  34. slot. This allows low inter-token latency use-cases, where the overhead
  35. of continuous batching scheduling is amortized over >1 generated tokens.
  36. Speculative decoding uses lookahead slots to store KV activations of
  37. proposal tokens.
  38. Args:
  39. block_size (int): The size of each memory block.
  40. num_gpu_blocks (int): The number of memory blocks allocated on GPU.
  41. num_cpu_blocks (int): The number of memory blocks allocated on CPU.
  42. watermark (float, optional): The threshold used for memory swapping.
  43. Defaults to 0.01.
  44. sliding_window (Optional[int], optional): The size of the sliding
  45. window. Defaults to None.
  46. enable_caching (bool, optional): Flag indicating whether caching is
  47. enabled. Defaults to False.
  48. """
  49. def __init__(
  50. self,
  51. block_size: int,
  52. num_gpu_blocks: int,
  53. num_cpu_blocks: int,
  54. watermark: float = 0.01,
  55. sliding_window: Optional[int] = None,
  56. enable_caching: bool = False,
  57. ) -> None:
  58. self.block_size = block_size
  59. self.num_total_gpu_blocks = num_gpu_blocks
  60. self.num_total_cpu_blocks = num_cpu_blocks
  61. self.sliding_window = sliding_window
  62. # max_block_sliding_window is the max number of blocks that need to be
  63. # allocated
  64. self.max_block_sliding_window = None
  65. if sliding_window is not None:
  66. # +1 here because // rounds down
  67. num_blocks = sliding_window // block_size + 1
  68. # +1 here because the last block may not be full,
  69. # and so the sequence stretches one more block at the beginning
  70. # For example, if sliding_window is 3 and block_size is 4,
  71. # we may need 2 blocks when the second block only holds 1 token.
  72. self.max_block_sliding_window = num_blocks + 1
  73. self.watermark = watermark
  74. assert watermark >= 0.0
  75. self.enable_caching = enable_caching
  76. self.watermark_blocks = int(watermark * num_gpu_blocks)
  77. self.block_allocator = CpuGpuBlockAllocator.create(
  78. allocator_type="prefix_caching" if enable_caching else "naive",
  79. num_gpu_blocks=num_gpu_blocks,
  80. num_cpu_blocks=num_cpu_blocks,
  81. block_size=block_size,
  82. )
  83. self.block_tables: Dict[SeqId, BlockTable] = {}
  84. self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
  85. self._computed_blocks_tracker = ComputedBlocksTracker(
  86. self.block_allocator)
  87. self._last_access_blocks_tracker = LastAccessBlocksTracker(
  88. self.block_allocator)
  89. def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
  90. # FIXME: Here we assume that all sequences in the group share
  91. # the same prompt. This may not be true for preempted sequences.
  92. check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
  93. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
  94. num_required_blocks = BlockTable.get_num_required_blocks(
  95. seq.get_token_ids(),
  96. block_size=self.block_size,
  97. )
  98. if seq_group.is_encoder_decoder():
  99. num_required_blocks += BlockTable.get_num_required_blocks(
  100. seq_group.get_encoder_seq().get_token_ids(),
  101. block_size=self.block_size,
  102. )
  103. if self.max_block_sliding_window is not None:
  104. num_required_blocks = min(num_required_blocks,
  105. self.max_block_sliding_window)
  106. num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
  107. device=Device.GPU)
  108. # Use watermark to avoid frequent cache eviction.
  109. if (self.num_total_gpu_blocks - num_required_blocks <
  110. self.watermark_blocks):
  111. return AllocStatus.NEVER
  112. if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
  113. return AllocStatus.OK
  114. else:
  115. return AllocStatus.LATER
  116. def _allocate_sequence(self, seq: Sequence) -> BlockTable:
  117. block_table = BlockTable(
  118. block_size=self.block_size,
  119. block_allocator=self.block_allocator,
  120. max_block_sliding_window=self.max_block_sliding_window,
  121. )
  122. block_table.allocate(seq.get_token_ids())
  123. return block_table
  124. def allocate(self, seq_group: SequenceGroup) -> None:
  125. # Allocate self-attention block tables for decoder sequences
  126. waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
  127. assert not (set(seq.seq_id for seq in waiting_seqs)
  128. & self.block_tables.keys()), "block table already exists"
  129. # NOTE: Here we assume that all sequences in the group have the same
  130. # prompt.
  131. seq = waiting_seqs[0]
  132. block_table: BlockTable = self._allocate_sequence(seq)
  133. self.block_tables[seq.seq_id] = block_table
  134. # Track seq
  135. self._computed_blocks_tracker.add_seq(seq.seq_id)
  136. self._last_access_blocks_tracker.add_seq(seq.seq_id)
  137. # Assign the block table for each sequence.
  138. for seq in waiting_seqs[1:]:
  139. self.block_tables[seq.seq_id] = block_table.fork()
  140. # Track seq
  141. self._computed_blocks_tracker.add_seq(seq.seq_id)
  142. self._last_access_blocks_tracker.add_seq(seq.seq_id)
  143. # Allocate cross-attention block table for encoder sequence
  144. #
  145. # NOTE: Here we assume that all sequences in the group have the same
  146. # encoder prompt.
  147. request_id = seq_group.request_id
  148. assert (request_id
  149. not in self.cross_block_tables), \
  150. "block table already exists"
  151. check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
  152. if seq_group.is_encoder_decoder():
  153. block_table = self._allocate_sequence(seq_group.get_encoder_seq())
  154. self.cross_block_tables[request_id] = block_table
  155. def can_append_slots(self, seq_group: SequenceGroup,
  156. num_lookahead_slots: int) -> bool:
  157. """Determine if there is enough space in the GPU KV cache to continue
  158. generation of the specified sequence group.
  159. We use a worst-case heuristic: assume each touched block will require a
  160. new allocation (either via CoW or new block). We can append slots if the
  161. number of touched blocks is less than the number of free blocks.
  162. "Lookahead slots" are slots that are allocated in addition to the slots
  163. for known tokens. The contents of the lookahead slots are not defined.
  164. This is used by speculative decoding when speculating future tokens.
  165. """
  166. num_touched_blocks = 0
  167. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  168. block_table = self.block_tables[seq.seq_id]
  169. num_touched_blocks += (
  170. block_table.get_num_blocks_touched_by_append_slots(
  171. token_ids=block_table.get_unseen_token_ids(
  172. seq.get_token_ids()),
  173. num_lookahead_slots=num_lookahead_slots,
  174. ))
  175. num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
  176. Device.GPU)
  177. return num_touched_blocks <= num_free_gpu_blocks
  178. def append_slots(
  179. self,
  180. seq: Sequence,
  181. num_lookahead_slots: int,
  182. ) -> List[Tuple[int, int]]:
  183. block_table = self.block_tables[seq.seq_id]
  184. block_table.append_token_ids(
  185. token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
  186. num_lookahead_slots=num_lookahead_slots,
  187. num_computed_slots=seq.data.get_num_computed_tokens(),
  188. )
  189. # Return any new copy-on-writes.
  190. new_cows = self.block_allocator.clear_copy_on_writes()
  191. return new_cows
  192. def free(self, seq: Sequence) -> None:
  193. seq_id = seq.seq_id
  194. if seq_id not in self.block_tables:
  195. # Already freed or haven't been scheduled yet.
  196. return
  197. # Update seq block ids with the latest access time
  198. self._last_access_blocks_tracker.update_seq_blocks_last_access(
  199. seq_id, self.block_tables[seq.seq_id].physical_block_ids)
  200. # Untrack seq
  201. self._last_access_blocks_tracker.remove_seq(seq_id)
  202. self._computed_blocks_tracker.remove_seq(seq_id)
  203. # Free table/blocks
  204. self.block_tables[seq_id].free()
  205. del self.block_tables[seq_id]
  206. def free_cross(self, seq_group: SequenceGroup) -> None:
  207. request_id = seq_group.request_id
  208. if request_id not in self.cross_block_tables:
  209. # Already freed or hasn't been scheduled yet.
  210. return
  211. self.cross_block_tables[request_id].free()
  212. del self.cross_block_tables[request_id]
  213. def get_block_table(self, seq: Sequence) -> List[int]:
  214. block_ids = self.block_tables[seq.seq_id].physical_block_ids
  215. return block_ids # type: ignore
  216. def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
  217. request_id = seq_group.request_id
  218. assert request_id in self.cross_block_tables
  219. block_ids = self.cross_block_tables[request_id].physical_block_ids
  220. assert all(b is not None for b in block_ids)
  221. return block_ids # type: ignore
  222. def access_all_blocks_in_seq(self, seq: Sequence, now: float):
  223. if self.enable_caching:
  224. # Record the latest access time for the sequence. The actual update
  225. # of the block ids is deferred to the sequence free(..) call, since
  226. # only during freeing of block ids, the blocks are actually added to
  227. # the evictor (which is when the most updated time is required)
  228. # (This avoids expensive calls to mark_blocks_as_accessed(..))
  229. self._last_access_blocks_tracker.update_last_access(
  230. seq.seq_id, now)
  231. def mark_blocks_as_computed(self, seq_group: SequenceGroup,
  232. token_chunk_size: int):
  233. # If prefix caching is enabled, mark immutable blocks as computed
  234. # right after they have been scheduled (for prefill). This assumes
  235. # the scheduler is synchronous so blocks are actually computed when
  236. # scheduling the next batch.
  237. self.block_allocator.mark_blocks_as_computed([])
  238. def get_common_computed_block_ids(
  239. self, seqs: List[Sequence]) -> GenericSequence[int]:
  240. """Determine which blocks for which we skip prefill.
  241. With prefix caching we can skip prefill for previously-generated blocks.
  242. Currently, the attention implementation only supports skipping cached
  243. blocks if they are a contiguous prefix of cached blocks.
  244. This method determines which blocks can be safely skipped for all
  245. sequences in the sequence group.
  246. """
  247. computed_seq_block_ids = []
  248. for seq in seqs:
  249. computed_seq_block_ids.append(
  250. self._computed_blocks_tracker.
  251. get_cached_computed_blocks_and_update(
  252. seq.seq_id,
  253. self.block_tables[seq.seq_id].physical_block_ids))
  254. # NOTE: This assumes seq_block_ids doesn't contain any None.
  255. return self.block_allocator.get_common_computed_block_ids(
  256. computed_seq_block_ids) # type: ignore
  257. def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  258. if parent_seq.seq_id not in self.block_tables:
  259. # Parent sequence has either been freed or never existed.
  260. return
  261. src_block_table = self.block_tables[parent_seq.seq_id]
  262. self.block_tables[child_seq.seq_id] = src_block_table.fork()
  263. # Track child seq
  264. self._computed_blocks_tracker.add_seq(child_seq.seq_id)
  265. self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
  266. def can_swap_in(self, seq_group: SequenceGroup,
  267. num_lookahead_slots: int) -> AllocStatus:
  268. """Returns the AllocStatus for the given sequence_group
  269. with num_lookahead_slots.
  270. Args:
  271. sequence_group (SequenceGroup): The sequence group to swap in.
  272. num_lookahead_slots (int): Number of lookahead slots used in
  273. speculative decoding, default to 0.
  274. Returns:
  275. AllocStatus: The AllocStatus for the given sequence group.
  276. """
  277. return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED,
  278. num_lookahead_slots)
  279. def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
  280. """Returns the block id mapping (from CPU to GPU) generated by
  281. swapping in the given seq_group with num_lookahead_slots.
  282. Args:
  283. seq_group (SequenceGroup): The sequence group to swap in.
  284. Returns:
  285. List[Tuple[int, int]]: The mapping of swapping block from CPU
  286. to GPU.
  287. """
  288. physical_block_id_mapping = []
  289. for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
  290. blocks = self.block_tables[seq.seq_id].blocks
  291. if len(blocks) == 0:
  292. continue
  293. seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
  294. src_device=Device.CPU,
  295. dst_device=Device.GPU)
  296. # Refresh the block ids of the table (post-swap)
  297. self.block_tables[seq.seq_id].update(blocks)
  298. seq_physical_block_id_mapping = {
  299. self.block_allocator.get_physical_block_id(
  300. Device.CPU, cpu_block_id):
  301. self.block_allocator.get_physical_block_id(
  302. Device.GPU, gpu_block_id)
  303. for cpu_block_id, gpu_block_id in seq_swap_mapping.items()
  304. }
  305. physical_block_id_mapping.extend(
  306. list(seq_physical_block_id_mapping.items()))
  307. return physical_block_id_mapping
  308. def can_swap_out(self, seq_group: SequenceGroup) -> bool:
  309. """Returns whether we can swap out the given sequence_group
  310. with num_lookahead_slots.
  311. Args:
  312. seq_group (SequenceGroup): The sequence group to swap in.
  313. num_lookahead_slots (int): Number of lookahead slots used in
  314. speculative decoding, default to 0.
  315. Returns:
  316. bool: Whether it's possible to swap out current sequence group.
  317. """
  318. alloc_status = self._can_swap(seq_group, Device.CPU,
  319. SequenceStatus.RUNNING)
  320. if alloc_status == AllocStatus.OK:
  321. return True
  322. return False
  323. def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
  324. """Returns the block id mapping (from GPU to CPU) generated by
  325. swapping out the given sequence_group with num_lookahead_slots.
  326. Args:
  327. sequence_group (SequenceGroup): The sequence group to swap in.
  328. Returns:
  329. List[Tuple[int, int]]: The mapping of swapping block from
  330. GPU to CPU.
  331. """
  332. physical_block_id_mapping = []
  333. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  334. blocks = self.block_tables[seq.seq_id].blocks
  335. if len(blocks) == 0:
  336. continue
  337. seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
  338. src_device=Device.GPU,
  339. dst_device=Device.CPU)
  340. # Refresh the block ids of the table (post-swap)
  341. self.block_tables[seq.seq_id].update(blocks)
  342. seq_physical_block_id_mapping = {
  343. self.block_allocator.get_physical_block_id(
  344. Device.GPU, gpu_block_id):
  345. self.block_allocator.get_physical_block_id(
  346. Device.CPU, cpu_block_id)
  347. for gpu_block_id, cpu_block_id in seq_swap_mapping.items()
  348. }
  349. physical_block_id_mapping.extend(
  350. list(seq_physical_block_id_mapping.items()))
  351. return physical_block_id_mapping
  352. def get_num_free_gpu_blocks(self) -> int:
  353. return self.block_allocator.get_num_free_blocks(Device.GPU)
  354. def get_num_free_cpu_blocks(self) -> int:
  355. return self.block_allocator.get_num_free_blocks(Device.CPU)
  356. def get_prefix_cache_hit_rate(self, device: Device) -> float:
  357. return self.block_allocator.get_prefix_cache_hit_rate(device)
  358. def _can_swap(self,
  359. seq_group: SequenceGroup,
  360. device: Device,
  361. status: SequenceStatus,
  362. num_lookahead_slots: int = 0) -> AllocStatus:
  363. """Returns the AllocStatus for swapping in/out the given sequence_group
  364. on to the 'device'.
  365. Args:
  366. sequence_group (SequenceGroup): The sequence group to swap in.
  367. device (Device): device to swap the 'seq_group' on.
  368. status (SequenceStatus): The status of sequence which is needed
  369. for action. RUNNING for swap out and SWAPPED for swap in
  370. num_lookahead_slots (int): Number of lookahead slots used in
  371. speculative decoding, default to 0.
  372. Returns:
  373. AllocStatus: The AllocStatus for swapping in/out the given
  374. sequence_group on to the 'device'.
  375. """
  376. blocks = self._get_blocks_for_swap(seq_group, status)
  377. num_blocks_touched = self.block_allocator.get_num_blocks_touched(
  378. blocks, device, num_lookahead_slots)
  379. watermark_blocks = 0
  380. if device == Device.GPU:
  381. watermark_blocks = self.watermark_blocks
  382. if self.block_allocator.get_num_total_blocks(
  383. device) < num_blocks_touched:
  384. return AllocStatus.NEVER
  385. elif self.block_allocator.get_num_free_blocks(
  386. device) - num_blocks_touched >= watermark_blocks:
  387. return AllocStatus.OK
  388. else:
  389. return AllocStatus.LATER
  390. def _get_blocks_for_swap(self, seq_group: SequenceGroup,
  391. status: SequenceStatus) -> List[Block]:
  392. """Returns the list of blocks those are touched by the seq_group
  393. Args:
  394. sequence_group (SequenceGroup): The sequence group to swap in.
  395. status (SequenceStatus): The status of sequence which is needed
  396. for action. RUNNING for swap out and SWAPPED for swap in
  397. Returns:
  398. The list of blocks those are touched by the seq_group.
  399. """
  400. blocks: Dict[int, List[Block]] = {}
  401. for seq in seq_group.get_seqs(status=status):
  402. block_table = self.block_tables[seq.seq_id]
  403. if block_table.blocks is not None:
  404. blocks[seq.seq_id] = block_table.blocks
  405. combined_blocks = list(chain(*blocks.values()))
  406. return combined_blocks