# block_manager_v2.py (9.6 KB)
  1. """A block manager that manages token blocks."""
  2. from typing import Dict, List, Optional
  3. from typing import Sequence as GenericSequence
  4. from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
  5. from aphrodite.common.utils import Device
  6. from aphrodite.processing.block.block_table import BlockTable
  7. from aphrodite.processing.block.cpu_gpu_block_allocator import \
  8. CpuGpuBlockAllocator
  9. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
# Sequence ids are plain ints; this alias is the key type of
# BlockSpaceManagerV2.block_tables.
SeqId = int
  11. class BlockSpaceManagerV2(BlockSpaceManager):
  12. """BlockSpaceManager which manages the allocation of KV cache.
  13. It owns responsibility for allocation, swapping, allocating memory for
  14. autoregressively-generated tokens, and other advanced features such as
  15. prefix caching, forking/copy-on-write, and sliding-window memory allocation.
  16. The current implementation is partial; in particular prefix caching and
  17. sliding-window are not feature complete.
  18. Lookahead slots
  19. The block manager has the notion of a "lookahead slot". These are slots
  20. in the KV cache that are allocated for a sequence. Unlike the other
  21. allocated slots, the content of these slots is undefined -- the worker
  22. may use the memory allocations in any way.
  23. In practice, a worker could use these lookahead slots to run multiple
  24. forward passes for a single scheduler invocation. Each successive
  25. forward pass would write KV activations to the corresponding lookahead
  26. slot. This allows low inter-token latency use-cases, where the overhead
  27. of continuous batching scheduling is amortized over >1 generated tokens.
  28. Speculative decoding uses lookahead slots to store KV activations of
  29. proposal tokens.
  30. Args:
  31. block_size (int): The size of each memory block.
  32. num_gpu_blocks (int): The number of memory blocks allocated on GPU.
  33. num_cpu_blocks (int): The number of memory blocks allocated on CPU.
  34. watermark (float, optional): The threshold used for memory swapping.
  35. Defaults to 0.01.
  36. sliding_window (Optional[int], optional): The size of the sliding
  37. window. Defaults to None.
  38. enable_caching (bool, optional): Flag indicating whether caching is
  39. enabled. Defaults to False.
  40. """
  41. def __init__(
  42. self,
  43. block_size: int,
  44. num_gpu_blocks: int,
  45. num_cpu_blocks: int,
  46. watermark: float = 0.01,
  47. sliding_window: Optional[int] = None,
  48. enable_caching: bool = False,
  49. ) -> None:
  50. self.block_size = block_size
  51. self.num_total_gpu_blocks = num_gpu_blocks
  52. self.num_total_cpu_blocks = num_cpu_blocks
  53. assert sliding_window is None, "Sliding window not yet supported"
  54. self.block_sliding_window = None
  55. self.watermark = watermark
  56. assert watermark >= 0.0
  57. assert not enable_caching, "Prefix caching not yet supported"
  58. self.enable_caching = enable_caching
  59. self.watermark_blocks = int(watermark * num_gpu_blocks)
  60. self.block_allocator = CpuGpuBlockAllocator.create(
  61. # Currently, only naive blocks are supported (no prefix caching).
  62. allocator_type="naive",
  63. num_gpu_blocks=num_gpu_blocks,
  64. num_cpu_blocks=num_cpu_blocks,
  65. block_size=block_size,
  66. )
  67. self.block_tables: Dict[SeqId, BlockTable] = {}
  68. def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
  69. # FIXME: Here we assume that all sequences in the group share
  70. # the same prompt. This may not be true for preempted sequences.
  71. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
  72. num_required_blocks = BlockTable.get_num_required_blocks(
  73. seq.get_token_ids(),
  74. block_size=self.block_size,
  75. )
  76. assert self.block_sliding_window is None
  77. if self.block_sliding_window is not None:
  78. num_required_blocks = min(num_required_blocks,
  79. self.block_sliding_window)
  80. num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
  81. device=Device.GPU)
  82. # Use watermark to avoid frequent cache eviction.
  83. if (self.num_total_gpu_blocks - num_required_blocks <
  84. self.watermark_blocks):
  85. return AllocStatus.NEVER
  86. if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
  87. return AllocStatus.OK
  88. else:
  89. return AllocStatus.LATER
  90. def allocate(self, seq_group: SequenceGroup) -> None:
  91. waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
  92. assert not (set(seq.seq_id for seq in waiting_seqs)
  93. & self.block_tables.keys()), "block table already exists"
  94. # NOTE: Here we assume that all sequences in the group have the same
  95. # prompt.
  96. seq = waiting_seqs[0]
  97. block_table = BlockTable(
  98. block_size=self.block_size,
  99. block_allocator=self.block_allocator,
  100. )
  101. assert self.block_sliding_window is None
  102. block_table.allocate(seq.get_token_ids())
  103. self.block_tables[seq.seq_id] = block_table
  104. # Assign the block table for each sequence.
  105. for seq in waiting_seqs[1:]:
  106. self.block_tables[seq.seq_id] = block_table.fork()
  107. def can_append_slots(self, seq_group: SequenceGroup,
  108. num_lookahead_slots: int) -> bool:
  109. """Determine if there is enough space in the GPU KV cache to continue
  110. generation of the specified sequence group.
  111. We use a worst-case heuristic: assume each touched block will require a
  112. new allocation (either via CoW or new block). We can append slots if the
  113. number of touched blocks is less than the number of free blocks.
  114. "Lookahead slots" are slots that are allocated in addition to the slots
  115. for known tokens. The contents of the lookahead slots are not defined.
  116. This is used by speculative decoding when speculating future tokens.
  117. """
  118. num_touched_blocks = 0
  119. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  120. block_table = self.block_tables[seq.seq_id]
  121. num_touched_blocks += (
  122. block_table.get_num_blocks_touched_by_append_slots(
  123. token_ids=block_table.get_unseen_token_ids(
  124. seq.get_token_ids()),
  125. num_lookahead_slots=num_lookahead_slots,
  126. ))
  127. num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
  128. Device.GPU)
  129. return num_touched_blocks <= num_free_gpu_blocks
  130. def append_slots(
  131. self,
  132. seq: Sequence,
  133. num_lookahead_slots: int,
  134. ) -> Dict[int, List[int]]:
  135. block_table = self.block_tables[seq.seq_id]
  136. block_table.append_token_ids(
  137. token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
  138. num_lookahead_slots=num_lookahead_slots,
  139. )
  140. # Return any new copy-on-writes.
  141. new_cows = self.block_allocator.clear_copy_on_writes()
  142. return new_cows
  143. def free(self, seq: Sequence) -> None:
  144. if seq.seq_id not in self.block_tables:
  145. # Already freed or haven't been scheduled yet.
  146. return
  147. self.block_tables[seq.seq_id].free()
  148. del self.block_tables[seq.seq_id]
  149. def get_block_table(self, seq: Sequence) -> List[int]:
  150. assert seq.seq_id in self.block_tables
  151. block_ids = self.block_tables[seq.seq_id].physical_block_ids
  152. assert all(b is not None for b in block_ids)
  153. return block_ids
  154. def access_all_blocks_in_seq(self, seq, now):
  155. # TODO add prefix caching support.
  156. pass
  157. def mark_blocks_as_computed(self, seq_group: SequenceGroup):
  158. # We ignore the sequence group as its not necessary. After the batch is
  159. # formed by the scheduler, we do not need to mark blocks from individual
  160. # sequence groups as computed -- all blocks in the batch can be marked
  161. # as computed.
  162. self.block_allocator.mark_blocks_as_computed()
  163. def get_common_computed_block_ids(
  164. self, seqs: List[Sequence]) -> GenericSequence[int]:
  165. """Determine which blocks for which we skip prefill.
  166. With prefix caching we can skip prefill for previously-generated blocks.
  167. Currently, the attention implementation only supports skipping cached
  168. blocks if they are a contiguous prefix of cached blocks.
  169. This method determines which blocks can be safely skipped for all
  170. sequences in the sequence group.
  171. """
  172. seq_block_ids = [
  173. self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
  174. ]
  175. return self.block_allocator.get_common_computed_block_ids(
  176. seq_block_ids)
  177. def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  178. src_block_table = self.block_tables[parent_seq.seq_id]
  179. self.block_tables[child_seq.seq_id] = src_block_table.fork()
  180. def can_swap_in(self, seq_group: SequenceGroup,
  181. num_lookahead_slots: int) -> bool:
  182. return False
  183. def swap_in(self, seq_group: SequenceGroup,
  184. num_lookahead_slots: int) -> Dict[int, int]:
  185. raise NotImplementedError
  186. def can_swap_out(self, seq_group: SequenceGroup) -> bool:
  187. return False
  188. def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
  189. raise NotImplementedError
  190. def get_num_free_gpu_blocks(self) -> int:
  191. return self.block_allocator.get_num_free_blocks(Device.GPU)
  192. def get_num_free_cpu_blocks(self) -> int:
  193. return self.block_allocator.get_num_free_blocks(Device.CPU)