123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547 |
- """A block manager that manages token blocks."""
- from itertools import chain
- from typing import Dict, List, Optional
- from typing import Sequence as GenericSequence
- from typing import Tuple
- from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
- from aphrodite.common.utils import Device
- from aphrodite.processing.block.block_table import BlockTable
- from aphrodite.processing.block.cpu_gpu_block_allocator import (
- CpuGpuBlockAllocator)
- from aphrodite.processing.block.interfaces import Block
- from aphrodite.processing.block.prefix_caching_block import (
- ComputedBlocksTracker, LastAccessBlocksTracker)
- from aphrodite.processing.block.utils import (
- check_no_caching_or_swa_for_blockmgr_encdec)
- from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
# Key types for the three block-table maps kept by the block manager.
# Decoder (self-attention) block tables are keyed by the sequence's seq_id.
SeqId = int
# Negative-prompt block tables are keyed by the owning group's request_id.
NegativeSeqId = str
# Encoder (cross-attention) block tables are keyed by the owning group's
# request_id as well.
EncoderSeqId = str
class BlockSpaceManagerV2(BlockSpaceManager):
    """BlockSpaceManager which manages the allocation of KV cache.

    It owns responsibility for allocation, swapping, allocating memory for
    autoregressively-generated tokens, and other advanced features such as
    prefix caching, forking/copy-on-write, and sliding-window memory
    allocation.

    The current implementation is partial; in particular prefix caching and
    sliding-window are not feature complete.

    Lookahead slots
        The block manager has the notion of a "lookahead slot". These are
        slots in the KV cache that are allocated for a sequence. Unlike the
        other allocated slots, the content of these slots is undefined -- the
        worker may use the memory allocations in any way.

        In practice, a worker could use these lookahead slots to run multiple
        forward passes for a single scheduler invocation. Each successive
        forward pass would write KV activations to the corresponding lookahead
        slot. This allows low inter-token latency use-cases, where the
        overhead of continuous batching scheduling is amortized over >1
        generated tokens.

        Speculative decoding uses lookahead slots to store KV activations of
        proposal tokens.

    Block tables
        Three maps are maintained: ``block_tables`` (one table per decoder
        sequence, keyed by ``seq_id``), ``cross_block_tables`` (one table per
        encoder-decoder request, keyed by ``request_id``) and
        ``negative_block_tables`` (one table per request that carries a
        negative prompt, keyed by ``request_id``).

    Args:
        block_size (int): The size of each memory block.
        num_gpu_blocks (int): The number of memory blocks allocated on GPU.
        num_cpu_blocks (int): The number of memory blocks allocated on CPU.
        watermark (float, optional): The threshold used for memory swapping.
            Defaults to 0.01.
        sliding_window (Optional[int], optional): The size of the sliding
            window. Defaults to None.
        enable_caching (bool, optional): Flag indicating whether caching is
            enabled. Defaults to False.
    """

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
        enable_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        self.sliding_window = sliding_window
        # max_block_sliding_window is the max number of blocks that need to be
        # allocated
        self.max_block_sliding_window = None
        if sliding_window is not None:
            # +1 here because // rounds down
            num_blocks = sliding_window // block_size + 1
            # +1 here because the last block may not be full,
            # and so the sequence stretches one more block at the beginning
            # For example, if sliding_window is 3 and block_size is 4,
            # we may need 2 blocks when the second block only holds 1 token.
            self.max_block_sliding_window = num_blocks + 1

        self.watermark = watermark
        assert watermark >= 0.0

        self.enable_caching = enable_caching

        # Number of GPU blocks kept in reserve by the admission checks.
        self.watermark_blocks = int(watermark * num_gpu_blocks)

        self.block_allocator = CpuGpuBlockAllocator.create(
            allocator_type="prefix_caching" if enable_caching else "naive",
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            block_size=block_size,
        )

        self.block_tables: Dict[SeqId, BlockTable] = {}
        self.negative_block_tables: Dict[NegativeSeqId, BlockTable] = {}
        self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}

        self._computed_blocks_tracker = ComputedBlocksTracker(
            self.block_allocator)
        self._last_access_blocks_tracker = LastAccessBlocksTracker(
            self.block_allocator)

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        """Report whether the prompt blocks of `seq_group` fit on the GPU.

        The requirement is the sum of the blocks needed by the first WAITING
        decoder sequence, the encoder sequence (encoder-decoder groups only)
        and the negative sequence (groups with a negative prompt only),
        capped at `max_block_sliding_window` when a sliding window is set.

        Returns:
            AllocStatus.NEVER if the requirement cannot be met even with all
            GPU blocks free (keeping the watermark), AllocStatus.OK if it can
            be met now, AllocStatus.LATER otherwise.
        """
        # FIXME: Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.

        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)

        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
        num_required_blocks = BlockTable.get_num_required_blocks(
            seq.get_token_ids(),
            block_size=self.block_size,
        )

        if seq_group.is_encoder_decoder():
            num_required_blocks += BlockTable.get_num_required_blocks(
                seq_group.get_encoder_seq().get_token_ids(),
                block_size=self.block_size,
            )

        if seq_group.has_negative_prompt():
            num_required_blocks += BlockTable.get_num_required_blocks(
                seq_group.get_negative_seq().get_token_ids(),
                block_size=self.block_size)

        if self.max_block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.max_block_sliding_window)

        num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
            device=Device.GPU)

        # Use watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        return AllocStatus.LATER

    def _allocate_sequence(self, seq: Sequence) -> BlockTable:
        """Create a new BlockTable and allocate it for `seq`'s token ids."""
        block_table = BlockTable(
            block_size=self.block_size,
            block_allocator=self.block_allocator,
            max_block_sliding_window=self.max_block_sliding_window,
        )
        block_table.allocate(seq.get_token_ids())

        return block_table

    def allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate block tables for every WAITING sequence of `seq_group`.

        The first waiting sequence gets a freshly allocated table; the
        remaining ones get forks of it. A negative-prompt table and a
        cross-attention table are allocated per request when the group has a
        negative prompt / is encoder-decoder, respectively.
        """
        # Allocate self-attention block tables for decoder sequences
        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        assert not (set(seq.seq_id for seq in waiting_seqs)
                    & self.block_tables.keys()), "block table already exists"

        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        seq = waiting_seqs[0]
        block_table: BlockTable = self._allocate_sequence(seq)
        self.block_tables[seq.seq_id] = block_table

        # Track seq
        self._computed_blocks_tracker.add_seq(seq.seq_id)
        self._last_access_blocks_tracker.add_seq(seq.seq_id)

        # Assign the block table for each sequence.
        for seq in waiting_seqs[1:]:
            self.block_tables[seq.seq_id] = block_table.fork()

            # Track seq
            self._computed_blocks_tracker.add_seq(seq.seq_id)
            self._last_access_blocks_tracker.add_seq(seq.seq_id)

        # Per-request tables (encoder and negative prompt) are keyed by
        # request_id and must not exist yet.
        request_id = seq_group.request_id

        assert (request_id
                not in self.cross_block_tables), \
            "block table already exists"
        assert (request_id
                not in self.negative_block_tables), \
            "block table already exists"

        # Allocate the block table for the negative prompt, if any.
        if seq_group.has_negative_prompt():
            block_table = self._allocate_sequence(
                seq_group.get_negative_seq())
            self.negative_block_tables[request_id] = block_table

        # Allocate cross-attention block table for encoder sequence
        #
        # NOTE: Here we assume that all sequences in the group have the same
        # encoder prompt.
        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)

        if seq_group.is_encoder_decoder():
            block_table = self._allocate_sequence(seq_group.get_encoder_seq())
            self.cross_block_tables[request_id] = block_table

    def can_append_slots(self, seq_group: SequenceGroup,
                         num_lookahead_slots: int) -> bool:
        """Determine if there is enough space in the GPU KV cache to continue
        generation of the specified sequence group.

        We use a worst-case heuristic: assume each touched block will require
        a new allocation (either via CoW or new block). We can append slots if
        the number of touched blocks is less than the number of free blocks.

        "Lookahead slots" are slots that are allocated in addition to the
        slots for known tokens. The contents of the lookahead slots are not
        defined. This is used by speculative decoding when speculating future
        tokens.

        Blocks touched by the group's negative-prompt table are included when
        such a table has been allocated for the request.
        """
        num_touched_blocks = 0
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            block_table = self.block_tables[seq.seq_id]

            num_touched_blocks += (
                block_table.get_num_blocks_touched_by_append_slots(
                    token_ids=block_table.get_unseen_token_ids(
                        seq.get_token_ids()),
                    num_lookahead_slots=num_lookahead_slots,
                ))

        # allocate() only creates a negative table for groups that have a
        # negative prompt, so its absence is normal and must not raise.
        negative_block_table = self.negative_block_tables.get(
            seq_group.request_id)
        if negative_block_table is not None:
            num_touched_blocks += (
                negative_block_table.get_num_blocks_touched_by_append_slots(
                    token_ids=negative_block_table.get_unseen_token_ids(
                        seq_group.get_negative_seq().get_token_ids()),
                    num_lookahead_slots=num_lookahead_slots,
                ))

        num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
            Device.GPU)
        return num_touched_blocks <= num_free_gpu_blocks

    def append_slots(
        self,
        seq: Sequence,
        num_lookahead_slots: int,
        seq_group: SequenceGroup,
    ) -> List[Tuple[int, int]]:
        """Append the unseen tokens of `seq` (plus lookahead slots) to its
        block table, and do the same for the group's negative-prompt table
        when one exists.

        Returns:
            The copy-on-write records accumulated by the block allocator,
            which are cleared by this call.
        """
        block_table = self.block_tables[seq.seq_id]

        block_table.append_token_ids(
            token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
            num_lookahead_slots=num_lookahead_slots,
            num_computed_slots=seq.data.get_num_computed_tokens(),
        )

        # allocate() only creates a negative table for groups that have a
        # negative prompt, so its absence is normal and must not raise.
        negative_block_table = self.negative_block_tables.get(
            seq_group.request_id)
        if negative_block_table is not None:
            negative_seq = seq_group.get_negative_seq()
            negative_block_table.append_token_ids(
                token_ids=negative_block_table.get_unseen_token_ids(
                    negative_seq.get_token_ids()),
                num_lookahead_slots=num_lookahead_slots,
                num_computed_slots=negative_seq.data.get_num_computed_tokens(
                ),
            )

        # Return any new copy-on-writes.
        new_cows = self.block_allocator.clear_copy_on_writes()
        return new_cows

    def free(self, seq: Sequence) -> None:
        """Free the decoder block table of `seq`; no-op if it has none."""
        seq_id = seq.seq_id

        if seq_id not in self.block_tables:
            # Already freed or haven't been scheduled yet.
            return

        # Update seq block ids with the latest access time
        self._last_access_blocks_tracker.update_seq_blocks_last_access(
            seq_id, self.block_tables[seq_id].physical_block_ids)

        # Untrack seq
        self._last_access_blocks_tracker.remove_seq(seq_id)
        self._computed_blocks_tracker.remove_seq(seq_id)

        # Free table/blocks
        self.block_tables[seq_id].free()
        del self.block_tables[seq_id]

    def free_cross(self, seq_group: SequenceGroup) -> None:
        """Free the cross-attention block table of `seq_group`, if any."""
        request_id = seq_group.request_id
        if request_id not in self.cross_block_tables:
            # Already freed or hasn't been scheduled yet.
            return
        self.cross_block_tables[request_id].free()
        del self.cross_block_tables[request_id]

    def free_negative(self, seq_group: SequenceGroup) -> None:
        """Free the negative-prompt block table of `seq_group`, if any."""
        request_id = seq_group.request_id
        if request_id not in self.negative_block_tables:
            # Nothing allocated for this request (or already freed).
            return
        self.negative_block_tables[request_id].free()
        del self.negative_block_tables[request_id]

    def get_block_table(self, seq: Sequence) -> List[int]:
        """Return the physical block ids of `seq`'s decoder block table."""
        block_ids = self.block_tables[seq.seq_id].physical_block_ids
        return block_ids  # type: ignore

    def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
        """Return the physical block ids of the group's cross-attention
        block table, which must exist."""
        request_id = seq_group.request_id
        assert request_id in self.cross_block_tables
        block_ids = self.cross_block_tables[request_id].physical_block_ids
        assert all(b is not None for b in block_ids)
        return block_ids  # type: ignore

    def get_negative_block_table(self, seq_group: SequenceGroup) -> List[int]:
        """Return the physical block ids of the group's negative-prompt
        block table, which must exist."""
        request_id = seq_group.request_id
        assert request_id in self.negative_block_tables
        block_ids = self.negative_block_tables[request_id].physical_block_ids
        assert all(b is not None for b in block_ids)
        return block_ids  # type: ignore

    def access_all_blocks_in_seq(self, seq: Sequence, now: float):
        """Record `now` as the last access time of `seq` (caching only)."""
        if self.enable_caching:
            # Record the latest access time for the sequence. The actual update
            # of the block ids is deferred to the sequence free(..) call, since
            # only during freeing of block ids, the blocks are actually added to
            # the evictor (which is when the most updated time is required)
            # (This avoids expensive calls to mark_blocks_as_accessed(..))
            self._last_access_blocks_tracker.update_last_access(
                seq.seq_id, now)

    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        """Intentionally a no-op; see the comment below."""
        # The only need for mark block as computed is for prefix caching,
        # while currently we could determine whether one block is computed
        # or not by check whether it has content hash.
        # So this function is useless for block_v2.
        pass

    def get_common_computed_block_ids(
            self, seqs: List[Sequence]) -> GenericSequence[int]:
        """Determine which blocks for which we skip prefill.

        With prefix caching we can skip prefill for previously-generated
        blocks. Currently, the attention implementation only supports skipping
        cached blocks if they are a contiguous prefix of cached blocks.

        This method determines which blocks can be safely skipped for all
        sequences in the sequence group.
        """
        tracker = self._computed_blocks_tracker
        computed_seq_block_ids = [
            tracker.get_cached_computed_blocks_and_update(
                seq.seq_id,
                self.block_tables[seq.seq_id].physical_block_ids)
            for seq in seqs
        ]

        # NOTE: This assumes seq_block_ids doesn't contain any None.
        return self.block_allocator.get_common_computed_block_ids(
            computed_seq_block_ids)  # type: ignore

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Give `child_seq` a fork of `parent_seq`'s block table.

        Does nothing when the parent has no block table.
        """
        if parent_seq.seq_id not in self.block_tables:
            # Parent sequence has either been freed or never existed.
            return
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.fork()

        # Track child seq
        self._computed_blocks_tracker.add_seq(child_seq.seq_id)
        self._last_access_blocks_tracker.add_seq(child_seq.seq_id)

    def can_swap_in(self, seq_group: SequenceGroup,
                    num_lookahead_slots: int) -> AllocStatus:
        """Returns the AllocStatus for swapping the given seq_group in to
        the GPU with num_lookahead_slots.

        Args:
            seq_group (SequenceGroup): The sequence group to swap in.
            num_lookahead_slots (int): Number of lookahead slots used in
                speculative decoding.

        Returns:
            AllocStatus: The AllocStatus for the given sequence group.
        """
        return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED,
                              num_lookahead_slots)

    def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
        """Swaps the SWAPPED sequences of seq_group from CPU to GPU and
        returns the resulting block id mapping.

        Args:
            seq_group (SequenceGroup): The sequence group to swap in.

        Returns:
            List[Tuple[int, int]]: The mapping of swapping block from CPU
                to GPU.
        """
        return self._swap_seq_group(seq_group,
                                    status=SequenceStatus.SWAPPED,
                                    src_device=Device.CPU,
                                    dst_device=Device.GPU)

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """Returns whether we can swap out the given seq_group to the CPU.

        Args:
            seq_group (SequenceGroup): The sequence group to swap out.

        Returns:
            bool: Whether it's possible to swap out current sequence group.
        """
        alloc_status = self._can_swap(seq_group, Device.CPU,
                                      SequenceStatus.RUNNING)
        return alloc_status == AllocStatus.OK

    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
        """Swaps the RUNNING sequences of seq_group from GPU to CPU and
        returns the resulting block id mapping.

        Args:
            seq_group (SequenceGroup): The sequence group to swap out.

        Returns:
            List[Tuple[int, int]]: The mapping of swapping block from
                GPU to CPU.
        """
        return self._swap_seq_group(seq_group,
                                    status=SequenceStatus.RUNNING,
                                    src_device=Device.GPU,
                                    dst_device=Device.CPU)

    def _swap_seq_group(self, seq_group: SequenceGroup,
                        status: SequenceStatus, src_device: Device,
                        dst_device: Device) -> List[Tuple[int, int]]:
        """Swap the blocks of every sequence of seq_group that is in
        `status` from src_device to dst_device.

        Returns:
            List[Tuple[int, int]]: (source, destination) physical block id
                pairs, in sequence order.
        """
        physical_block_id_mapping: List[Tuple[int, int]] = []
        for seq in seq_group.get_seqs(status=status):
            block_table = self.block_tables[seq.seq_id]
            blocks = block_table.blocks
            if len(blocks) == 0:
                continue

            seq_swap_mapping = self.block_allocator.swap(
                blocks=blocks,
                src_device=src_device,
                dst_device=dst_device)

            # Refresh the block ids of the table (post-swap)
            block_table.update(blocks)

            # Translate the allocator's block ids into per-device physical
            # block ids.
            seq_physical_block_id_mapping = {
                self.block_allocator.get_physical_block_id(
                    src_device, src_block_id):
                self.block_allocator.get_physical_block_id(
                    dst_device, dst_block_id)
                for src_block_id, dst_block_id in seq_swap_mapping.items()
            }

            physical_block_id_mapping.extend(
                seq_physical_block_id_mapping.items())

        return physical_block_id_mapping

    def get_num_free_gpu_blocks(self) -> int:
        """Return the allocator's number of free GPU blocks."""
        return self.block_allocator.get_num_free_blocks(Device.GPU)

    def get_num_free_cpu_blocks(self) -> int:
        """Return the allocator's number of free CPU blocks."""
        return self.block_allocator.get_num_free_blocks(Device.CPU)

    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        """Return the allocator's prefix cache hit rate for `device`."""
        return self.block_allocator.get_prefix_cache_hit_rate(device)

    def _can_swap(self,
                  seq_group: SequenceGroup,
                  device: Device,
                  status: SequenceStatus,
                  num_lookahead_slots: int = 0) -> AllocStatus:
        """Returns the AllocStatus for swapping in/out the given seq_group
        on to the 'device'.

        Args:
            seq_group (SequenceGroup): The sequence group to swap.
            device (Device): device to swap the 'seq_group' on.
            status (SequenceStatus): The status of sequence which is needed
                for action. RUNNING for swap out and SWAPPED for swap in
            num_lookahead_slots (int): Number of lookahead slots used in
                speculative decoding, default to 0.

        Returns:
            AllocStatus: The AllocStatus for swapping in/out the given
                seq_group on to the 'device'.
        """
        blocks = self._get_blocks_for_swap(seq_group, status)
        num_blocks_touched = self.block_allocator.get_num_blocks_touched(
            blocks, device, num_lookahead_slots)

        # The watermark reserve is only applied when the target is the GPU.
        watermark_blocks = 0
        if device == Device.GPU:
            watermark_blocks = self.watermark_blocks

        if self.block_allocator.get_num_total_blocks(
                device) < num_blocks_touched:
            return AllocStatus.NEVER
        elif self.block_allocator.get_num_free_blocks(
                device) - num_blocks_touched >= watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER

    def _get_blocks_for_swap(self, seq_group: SequenceGroup,
                             status: SequenceStatus) -> List[Block]:
        """Returns the list of blocks those are touched by the seq_group

        Args:
            seq_group (SequenceGroup): The sequence group to swap.
            status (SequenceStatus): The status of sequence which is needed
                for action. RUNNING for swap out and SWAPPED for swap in

        Returns:
            The list of blocks those are touched by the seq_group.
        """
        blocks_by_seq: Dict[int, List[Block]] = {}
        for seq in seq_group.get_seqs(status=status):
            block_table = self.block_tables[seq.seq_id]
            if block_table.blocks is not None:
                blocks_by_seq[seq.seq_id] = block_table.blocks

        # Flatten the per-sequence lists, preserving sequence order.
        return list(chain.from_iterable(blocks_by_seq.values()))
|