from typing import Dict, FrozenSet, List, Optional, Tuple

from aphrodite.common.utils import Device
from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
                                                   BlockId,
                                                   DeviceAwareBlockAllocator)
from aphrodite.processing.block.naive_block import (NaiveBlock,
                                                    NaiveBlockAllocator)
from aphrodite.processing.block.prefix_caching_block import (
    PrefixCachingBlockAllocator)


class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
    """A block allocator that can allocate blocks on both CPU and GPU memory.

    This class implements the `DeviceAwareBlockAllocator` interface and
    provides functionality for allocating and managing blocks of memory on
    both CPU and GPU devices.

    The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU
    blocks, and allows for allocation, deallocation, forking, and swapping of
    blocks across these memory pools.
    """

    @staticmethod
    def create(
        allocator_type: str,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        block_size: int,
    ) -> DeviceAwareBlockAllocator:
        """Creates a CpuGpuBlockAllocator instance with the specified
        configuration.

        This static method creates and returns a CpuGpuBlockAllocator instance
        based on the provided parameters. It initializes the CPU and GPU block
        allocators with the specified number of blocks, block size, and
        allocator type.

        Args:
            allocator_type (str): The type of block allocator to use for CPU
                and GPU blocks. Currently supported values are "naive" and
                "prefix_caching".
            num_gpu_blocks (int): The number of blocks to allocate for GPU
                memory.
            num_cpu_blocks (int): The number of blocks to allocate for CPU
                memory.
            block_size (int): The size of each block in number of tokens.

        Returns:
            DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with
                the specified configuration.

        Notes:
            - The block IDs are assigned contiguously, with GPU block IDs
              coming before CPU block IDs.
        """
        block_ids = list(range(num_gpu_blocks + num_cpu_blocks))
        gpu_block_ids = block_ids[:num_gpu_blocks]
        cpu_block_ids = block_ids[num_gpu_blocks:]

        if allocator_type == "naive":
            gpu_allocator: BlockAllocator = NaiveBlockAllocator(
                create_block=NaiveBlock,  # type: ignore
                num_blocks=num_gpu_blocks,
                block_size=block_size,
                block_ids=gpu_block_ids,
            )

            cpu_allocator: BlockAllocator = NaiveBlockAllocator(
                create_block=NaiveBlock,  # type: ignore
                num_blocks=num_cpu_blocks,
                block_size=block_size,
                block_ids=cpu_block_ids,
            )
        elif allocator_type == "prefix_caching":
            gpu_allocator = PrefixCachingBlockAllocator(
                num_blocks=num_gpu_blocks,
                block_size=block_size,
                block_ids=gpu_block_ids,
            )

            cpu_allocator = PrefixCachingBlockAllocator(
                num_blocks=num_cpu_blocks,
                block_size=block_size,
                block_ids=cpu_block_ids,
            )
        else:
            raise ValueError(f"Unknown allocator type {allocator_type=}")

        return CpuGpuBlockAllocator(
            cpu_block_allocator=cpu_allocator,
            gpu_block_allocator=gpu_allocator,
        )
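
    # A minimal usage sketch (the sizes below are hypothetical, not taken from
    # any real configuration). Because block IDs are assigned contiguously
    # with GPU IDs first, the two pools never share an ID:
    #
    #     allocator = CpuGpuBlockAllocator.create(
    #         allocator_type="prefix_caching",
    #         num_gpu_blocks=1024,
    #         num_cpu_blocks=256,
    #         block_size=16,
    #     )
    #     # GPU blocks receive ids 0..1023, CPU blocks receive ids 1024..1279.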

    def __init__(self, cpu_block_allocator: BlockAllocator,
                 gpu_block_allocator: BlockAllocator):
        assert not (
            cpu_block_allocator.all_block_ids
            & gpu_block_allocator.all_block_ids
        ), "cpu and gpu block allocators can't have intersection of block ids"

        self._allocators = {
            Device.CPU: cpu_block_allocator,
            Device.GPU: gpu_block_allocator,
        }

        self._swap_mapping: Dict[int, int] = {}
        self._null_block: Optional[Block] = None

        # Map each block id back to the allocator that owns it, so methods
        # such as `free` and `fork` can dispatch on the block alone.
        self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
        for allocator in self._allocators.values():
            for block_id in allocator.all_block_ids:
                self._block_ids_to_allocator[block_id] = allocator

    def allocate_or_get_null_block(self) -> Block:
        """Returns the shared null block, allocating it lazily on first use.

        The null block is a GPU-backed placeholder for KV cache blocks that
        have been dropped (e.g. due to sliding window); see `NullBlock`.
        """
        if self._null_block is None:
            self._null_block = NullBlock(
                self.allocate_mutable_block(None, Device.GPU))
        return self._null_block

    def allocate_mutable_block(self, prev_block: Optional[Block],
                               device: Device) -> Block:
        """Allocates a new mutable block on the specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            device (Device): The device on which to allocate the new block.

        Returns:
            Block: The newly allocated mutable block.
        """
        return self._allocators[device].allocate_mutable_block(prev_block)

    def allocate_immutable_blocks(self, prev_block: Optional[Block],
                                  block_token_ids: List[List[int]],
                                  device: Optional[Device]) -> List[Block]:
        """Allocates a new group of immutable blocks with the provided block
        token IDs on the specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            block_token_ids (List[List[int]]): The per-block lists of token
                IDs to be stored in the new blocks.
            device (Optional[Device]): The device on which to allocate the new
                blocks.

        Returns:
            List[Block]: The newly allocated list of immutable blocks
                containing the provided block token IDs.
        """
        return self._allocators[device].allocate_immutable_blocks(
            prev_block, block_token_ids)

    def allocate_immutable_block(self, prev_block: Optional[Block],
                                 token_ids: List[int],
                                 device: Device) -> Block:
        """Allocates a new immutable block with the provided token IDs on the
        specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            token_ids (List[int]): The list of token IDs to be stored in the
                new block.
            device (Device): The device on which to allocate the new block.

        Returns:
            Block: The newly allocated immutable block containing the provided
                token IDs.
        """
        return self._allocators[device].allocate_immutable_block(
            prev_block, token_ids)
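
    # A rough sketch of how the two allocation paths are typically used (the
    # token ids below are made up for illustration, assuming block_size=4):
    #
    #     # Mutable block: token ids get appended as decoding proceeds.
    #     tail = allocator.allocate_mutable_block(None, Device.GPU)
    #
    #     # Immutable blocks: prompt token ids are known up front, so full
    #     # blocks can be allocated in one call.
    #     prompt_blocks = allocator.allocate_immutable_blocks(
    #         prev_block=None,
    #         block_token_ids=[[1, 2, 3, 4], [5, 6, 7, 8]],
    #         device=Device.GPU,
    #     )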

    def free(self, block: Block) -> None:
        """Frees the memory occupied by the given block.

        Args:
            block (Block): The block to be freed.
        """
        # The null block should never be freed.
        if isinstance(block, NullBlock):
            return
        block_id = block.block_id
        assert block_id is not None
        allocator = self._block_ids_to_allocator[block_id]
        allocator.free(block)

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: A new list of blocks that shares the same memory as
                the original sequence.
        """
        # Do not attempt to fork the null block.
        assert not isinstance(last_block, NullBlock)
        block_id = last_block.block_id
        assert block_id is not None
        allocator = self._block_ids_to_allocator[block_id]
        return allocator.fork(last_block)
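
    # A hedged sketch of forking, assuming `seq_blocks` is the block list of
    # an existing sequence (e.g. one being duplicated for parallel sampling):
    #
    #     forked_blocks = allocator.fork(seq_blocks[-1])
    #     # `forked_blocks` shares the same underlying memory; physical copies
    #     # only happen later via copy-on-write (see clear_copy_on_writes).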

    def get_num_free_blocks(self, device: Device) -> int:
        """Returns the number of free blocks available on the specified
        device.

        Args:
            device (Device): The device for which to query the number of free
                blocks. Must not be None.

        Returns:
            int: The number of free blocks available on the specified device.
        """
        return self._allocators[device].get_num_free_blocks()

    def get_num_total_blocks(self, device: Device) -> int:
        return self._allocators[device].get_num_total_blocks()

    def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
        """Returns the zero-offset block id on the given device for the
        provided absolute block id.

        Args:
            device (Device): The device for which to query the relative block
                id.
            absolute_id (int): The absolute block id of the block within the
                whole allocator.

        Returns:
            int: The zero-offset block id on the given device.
        """
        return self._allocators[device].get_physical_block_id(absolute_id)
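
    # A hedged illustration, reusing the hypothetical sizes from the `create`
    # sketch above (1024 GPU blocks followed by 256 CPU blocks): the CPU block
    # with absolute id 1030 is the seventh block of the CPU pool, so
    # `get_physical_block_id(Device.CPU, 1030)` would return 6, while
    # `get_physical_block_id(Device.GPU, 42)` would return 42 unchanged.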

    def swap(self, blocks: List[Block], src_device: Device,
             dst_device: Device) -> Dict[int, int]:
        """Executes the swap of the given blocks from `src_device` to
        `dst_device`, saves the per-call mapping, and folds it into the
        accumulated `self._swap_mapping` for the current scheduling move.

        Args:
            blocks: List of blocks to be swapped.
            src_device (Device): Device to swap the blocks from.
            dst_device (Device): Device to swap the blocks to.

        Returns:
            Dict[int, int]: Swap mapping from `src_device` block IDs to
                `dst_device` block IDs.
        """
        src_block_ids = [block.block_id for block in blocks]
        self._allocators[src_device].swap_out(blocks)
        self._allocators[dst_device].swap_in(blocks)
        dst_block_ids = [block.block_id for block in blocks]

        current_swap_mapping: Dict[int, int] = {}
        for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids):
            if src_block_id is not None and dst_block_id is not None:
                self._swap_mapping[src_block_id] = dst_block_id
                current_swap_mapping[src_block_id] = dst_block_id
        return current_swap_mapping
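
    # A hedged sketch of swapping a sequence's blocks out to CPU memory
    # (`seq_blocks` is a hypothetical list of GPU-resident blocks):
    #
    #     mapping = allocator.swap(seq_blocks,
    #                              src_device=Device.GPU,
    #                              dst_device=Device.CPU)
    #     # `mapping` pairs each old GPU block id with the CPU block id that
    #     # now holds its contents, e.g. {3: 1027, 4: 1028}. The same pairs
    #     # accumulate in `_swap_mapping` until get_and_reset_swaps() is
    #     # called.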

    def get_num_blocks_touched(self,
                               blocks: List[Block],
                               device: Device,
                               num_lookahead_slots: int = 0) -> int:
        """Returns the number of blocks that will be touched by swapping the
        given blocks in/out on the given device.

        Args:
            blocks: List of blocks to be swapped.
            device (Device): Device onto which the blocks will be swapped.
            num_lookahead_slots (int): Number of lookahead slots used in
                speculative decoding; defaults to 0.

        Returns:
            int: The number of blocks that will be touched by swapping the
                given blocks in/out on the given device.
        """
        return self._allocators[device].get_num_blocks_touched(
            blocks, num_lookahead_slots)

    def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
        """Clears the copy-on-write (CoW) state and returns the mapping of
        source to destination block IDs.

        Returns:
            List[Tuple[int, int]]: A list mapping source block IDs to
                destination block IDs.
        """
        # CoW is only supported on GPU.
        device = Device.GPU
        return self._allocators[device].clear_copy_on_writes()

    def mark_blocks_as_accessed(self, block_ids: List[int],
                                now: float) -> None:
        """Marks blocks as accessed; only used for prefix caching."""
        # Prefix caching is only supported on GPU.
        device = Device.GPU
        return self._allocators[device].mark_blocks_as_accessed(
            block_ids, now)

    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        """Marks blocks as computed; only used for prefix caching."""
        # Prefix caching is only supported on GPU.
        device = Device.GPU
        return self._allocators[device].mark_blocks_as_computed(block_ids)

    def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                               block_ids: List[int],
                               skip_last_block_id: bool) -> List[int]:
        # Prefix caching is only supported on GPU.
        device = Device.GPU
        return self._allocators[device].get_computed_block_ids(
            prev_computed_block_ids, block_ids, skip_last_block_id)

    def get_common_computed_block_ids(
            self, computed_seq_block_ids: List[List[int]]) -> List[int]:
        # Prefix caching is only supported on GPU.
        device = Device.GPU
        return self._allocators[device].get_common_computed_block_ids(
            computed_seq_block_ids)

    @property
    def all_block_ids(self) -> FrozenSet[int]:
        return frozenset(self._block_ids_to_allocator.keys())

    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        """Prefix cache hit rate. -1 means not supported or disabled."""
        assert device in self._allocators
        return self._allocators[device].get_prefix_cache_hit_rate()

    def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
        """Returns and clears the mapping of source to destination block IDs.

        For now this is called after every swap operation; once BlockManagerV2
        becomes the default it will be called after every schedule. The
        returned value is currently unused.

        Returns:
            List[Tuple[int, int]]: A mapping of source to destination block
                IDs.
        """
        mapping = self._swap_mapping.copy()
        self._swap_mapping.clear()
        return list(mapping.items())
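
    # A hedged sketch of the consumer-side pattern, assuming a hypothetical
    # scheduler that issues the actual GPU<->CPU copies itself:
    #
    #     allocator.swap(blocks_a, Device.GPU, Device.CPU)
    #     allocator.swap(blocks_b, Device.GPU, Device.CPU)
    #     swap_pairs = allocator.get_and_reset_swaps()
    #     # `swap_pairs` holds the (src_id, dst_id) pairs from both calls, and
    #     # the internal `_swap_mapping` is empty again.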


class NullBlock(Block):
    """
    Null blocks are used as placeholders for KV cache blocks that have been
    dropped due to sliding window.

    This implementation just wraps an ordinary block and prevents it from
    being modified. It also allows for testing whether a block is a NullBlock
    via isinstance().
    """

    def __init__(self, proxy: Block):
        super().__init__()
        self._proxy = proxy

    def append_token_ids(self, token_ids: List[BlockId]):
        raise ValueError("null block should not be modified")

    @property
    def block_id(self):
        return self._proxy.block_id

    @block_id.setter
    def block_id(self, value: Optional[BlockId]):
        raise ValueError("null block should not be modified")

    @property
    def token_ids(self) -> List[BlockId]:
        return self._proxy.token_ids

    @property
    def num_tokens_total(self) -> int:
        raise NotImplementedError(
            "num_tokens_total is not used for null block")

    @property
    def num_empty_slots(self) -> BlockId:
        return self._proxy.num_empty_slots

    @property
    def is_full(self):
        return self._proxy.is_full

    @property
    def prev_block(self):
        return self._proxy.prev_block

    @property
    def computed(self):
        return self._proxy.computed

    @computed.setter
    def computed(self, value):
        self._proxy.computed = value

    @property
    def last_accessed(self) -> float:
        return self._proxy.last_accessed

    @last_accessed.setter
    def last_accessed(self, last_accessed_ts: float):
        self._proxy.last_accessed = last_accessed_ts

    @property
    def content_hash(self):
        return self._proxy.content_hash
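

# A hedged sketch of how NullBlock behaves, assuming `allocator` is a
# CpuGpuBlockAllocator with at least one free GPU block:
#
#     null = allocator.allocate_or_get_null_block()
#     assert isinstance(null, NullBlock)
#     assert allocator.allocate_or_get_null_block() is null  # shared instance
#     null.append_token_ids([1, 2, 3])  # raises ValueError: read-only proxy
#     allocator.free(null)              # no-op: the null block is never freed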