block_manager_v2.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. """A block manager that manages token blocks."""
  2. from typing import Dict, List, Optional
  3. from typing import Sequence as GenericSequence
  4. from typing import Tuple
  5. from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
  6. from aphrodite.common.utils import Device
  7. from aphrodite.processing.block.block_table import BlockTable
  8. from aphrodite.processing.block.cpu_gpu_block_allocator import \
  9. CpuGpuBlockAllocator
  10. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
# Alias for sequence identifiers; used as the key type of
# ``BlockSpaceManagerV2.block_tables``.
SeqId = int
  12. class BlockSpaceManagerV2(BlockSpaceManager):
  13. """BlockSpaceManager which manages the allocation of KV cache.
  14. It owns responsibility for allocation, swapping, allocating memory for
  15. autoregressively-generated tokens, and other advanced features such as
  16. prefix caching, forking/copy-on-write, and sliding-window memory allocation.
  17. The current implementation is partial; in particular prefix caching and
  18. sliding-window are not feature complete.
  19. Lookahead slots
  20. The block manager has the notion of a "lookahead slot". These are slots
  21. in the KV cache that are allocated for a sequence. Unlike the other
  22. allocated slots, the content of these slots is undefined -- the worker
  23. may use the memory allocations in any way.
  24. In practice, a worker could use these lookahead slots to run multiple
  25. forward passes for a single scheduler invocation. Each successive
  26. forward pass would write KV activations to the corresponding lookahead
  27. slot. This allows low inter-token latency use-cases, where the overhead
  28. of continuous batching scheduling is amortized over >1 generated tokens.
  29. Speculative decoding uses lookahead slots to store KV activations of
  30. proposal tokens.
  31. Args:
  32. block_size (int): The size of each memory block.
  33. num_gpu_blocks (int): The number of memory blocks allocated on GPU.
  34. num_cpu_blocks (int): The number of memory blocks allocated on CPU.
  35. watermark (float, optional): The threshold used for memory swapping.
  36. Defaults to 0.01.
  37. sliding_window (Optional[int], optional): The size of the sliding
  38. window. Defaults to None.
  39. enable_caching (bool, optional): Flag indicating whether caching is
  40. enabled. Defaults to False.
  41. """
  42. def __init__(
  43. self,
  44. block_size: int,
  45. num_gpu_blocks: int,
  46. num_cpu_blocks: int,
  47. watermark: float = 0.01,
  48. sliding_window: Optional[int] = None,
  49. enable_caching: bool = False,
  50. ) -> None:
  51. self.block_size = block_size
  52. self.num_total_gpu_blocks = num_gpu_blocks
  53. self.num_total_cpu_blocks = num_cpu_blocks
  54. assert sliding_window is None, "Sliding window not yet supported"
  55. self.block_sliding_window = None
  56. self.watermark = watermark
  57. assert watermark >= 0.0
  58. self.enable_caching = enable_caching
  59. self.watermark_blocks = int(watermark * num_gpu_blocks)
  60. self.block_allocator = CpuGpuBlockAllocator.create(
  61. allocator_type="prefix_caching" if enable_caching else "naive",
  62. num_gpu_blocks=num_gpu_blocks,
  63. num_cpu_blocks=num_cpu_blocks,
  64. block_size=block_size,
  65. )
  66. self.block_tables: Dict[SeqId, BlockTable] = {}
  67. def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
  68. # FIXME: Here we assume that all sequences in the group share
  69. # the same prompt. This may not be true for preempted sequences.
  70. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
  71. num_required_blocks = BlockTable.get_num_required_blocks(
  72. seq.get_token_ids(),
  73. block_size=self.block_size,
  74. )
  75. assert self.block_sliding_window is None
  76. if self.block_sliding_window is not None:
  77. num_required_blocks = min(num_required_blocks,
  78. self.block_sliding_window)
  79. num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
  80. device=Device.GPU)
  81. # Use watermark to avoid frequent cache eviction.
  82. if (self.num_total_gpu_blocks - num_required_blocks <
  83. self.watermark_blocks):
  84. return AllocStatus.NEVER
  85. if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
  86. return AllocStatus.OK
  87. else:
  88. return AllocStatus.LATER
  89. def allocate(self, seq_group: SequenceGroup) -> None:
  90. waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
  91. assert not (set(seq.seq_id for seq in waiting_seqs)
  92. & self.block_tables.keys()), "block table already exists"
  93. # NOTE: Here we assume that all sequences in the group have the same
  94. # prompt.
  95. seq = waiting_seqs[0]
  96. block_table = BlockTable(
  97. block_size=self.block_size,
  98. block_allocator=self.block_allocator,
  99. )
  100. assert self.block_sliding_window is None
  101. block_table.allocate(seq.get_token_ids())
  102. self.block_tables[seq.seq_id] = block_table
  103. # Assign the block table for each sequence.
  104. for seq in waiting_seqs[1:]:
  105. self.block_tables[seq.seq_id] = block_table.fork()
  106. def can_append_slots(self, seq_group: SequenceGroup,
  107. num_lookahead_slots: int) -> bool:
  108. """Determine if there is enough space in the GPU KV cache to continue
  109. generation of the specified sequence group.
  110. We use a worst-case heuristic: assume each touched block will require a
  111. new allocation (either via CoW or new block). We can append slots if the
  112. number of touched blocks is less than the number of free blocks.
  113. "Lookahead slots" are slots that are allocated in addition to the slots
  114. for known tokens. The contents of the lookahead slots are not defined.
  115. This is used by speculative decoding when speculating future tokens.
  116. """
  117. num_touched_blocks = 0
  118. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  119. block_table = self.block_tables[seq.seq_id]
  120. num_touched_blocks += (
  121. block_table.get_num_blocks_touched_by_append_slots(
  122. token_ids=block_table.get_unseen_token_ids(
  123. seq.get_token_ids()),
  124. num_lookahead_slots=num_lookahead_slots,
  125. ))
  126. num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
  127. Device.GPU)
  128. return num_touched_blocks <= num_free_gpu_blocks
  129. def append_slots(
  130. self,
  131. seq: Sequence,
  132. num_lookahead_slots: int,
  133. ) -> List[Tuple[int, int]]:
  134. block_table = self.block_tables[seq.seq_id]
  135. block_table.append_token_ids(
  136. token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
  137. num_lookahead_slots=num_lookahead_slots,
  138. )
  139. # Return any new copy-on-writes.
  140. new_cows = self.block_allocator.clear_copy_on_writes()
  141. return new_cows
  142. def free(self, seq: Sequence) -> None:
  143. if seq.seq_id not in self.block_tables:
  144. # Already freed or haven't been scheduled yet.
  145. return
  146. self.block_tables[seq.seq_id].free()
  147. del self.block_tables[seq.seq_id]
  148. def get_block_table(self, seq: Sequence) -> List[int]:
  149. assert seq.seq_id in self.block_tables
  150. block_ids = self.block_tables[seq.seq_id].physical_block_ids
  151. assert all(b is not None for b in block_ids)
  152. return block_ids # type: ignore
  153. def access_all_blocks_in_seq(self, seq: Sequence, now: float):
  154. # Update the last accessed time of all the blocks accessed
  155. # in this step.
  156. # And the accessed time is only useful for prefix caching now,
  157. # as it support internal evictor policy for which cached
  158. # block could be refilled, to keep cached content could be reused
  159. # at max extend.
  160. if self.enable_caching:
  161. block_table = self.block_tables[seq.seq_id]
  162. block_ids = []
  163. for block_id in block_table.physical_block_ids:
  164. block_ids.append(block_id)
  165. self.block_allocator.mark_blocks_as_accessed(
  166. block_ids, # type: ignore
  167. now)
  168. def mark_blocks_as_computed(self, seq_group: SequenceGroup):
  169. # The only need for mark block as computed is for prefix caching,
  170. # while currently we could determine whether one block is computed
  171. # or not by check whether it has content hash.
  172. # So this function is useless for block_v2.
  173. pass
  174. def get_common_computed_block_ids(
  175. self, seqs: List[Sequence]) -> GenericSequence[int]:
  176. """Determine which blocks for which we skip prefill.
  177. With prefix caching we can skip prefill for previously-generated blocks.
  178. Currently, the attention implementation only supports skipping cached
  179. blocks if they are a contiguous prefix of cached blocks.
  180. This method determines which blocks can be safely skipped for all
  181. sequences in the sequence group.
  182. """
  183. seq_block_ids = [
  184. self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
  185. ]
  186. # NOTE: This assumes seq_block_ids doesn't contain any None.
  187. return self.block_allocator.get_common_computed_block_ids(
  188. seq_block_ids) # type: ignore
  189. def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  190. src_block_table = self.block_tables[parent_seq.seq_id]
  191. self.block_tables[child_seq.seq_id] = src_block_table.fork()
  192. def can_swap_in(self, seq_group: SequenceGroup,
  193. num_lookahead_slots: int) -> AllocStatus:
  194. return AllocStatus.LATER
  195. def swap_in(self, seq_group: SequenceGroup,
  196. num_lookahead_slots: int) -> List[Tuple[int, int]]:
  197. raise NotImplementedError
  198. def can_swap_out(self, seq_group: SequenceGroup) -> bool:
  199. return False
  200. def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
  201. raise NotImplementedError
  202. def get_num_free_gpu_blocks(self) -> int:
  203. return self.block_allocator.get_num_free_blocks(Device.GPU)
  204. def get_num_free_cpu_blocks(self) -> int:
  205. return self.block_allocator.get_num_free_blocks(Device.CPU)