- """CacheEngine class for managing the KV cache."""
- from typing import Dict, List, Tuple
- import torch
- from loguru import logger
- from aphrodite._C import cache_ops
- from aphrodite.common.config import CacheConfig, ModelConfig, ParallelConfig
- from aphrodite.common.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
- KVCache = Tuple[torch.Tensor, torch.Tensor]
- class CacheEngine:
- """Manages the KV cache.
- This class is responsible for initializing and managing the GPU and CPU KV
- caches. It also provides methods for performing KV cache operations, such
- as swapping and copying.
- """
- def __init__(
- self,
- cache_config: CacheConfig,
- model_config: ModelConfig,
- parallel_config: ParallelConfig,
- ) -> None:
- self.cache_config = cache_config
- self.model_config = model_config
- self.parallel_config = parallel_config
- self.head_size = model_config.get_head_size()
- self.num_layers = model_config.get_num_layers(parallel_config)
- self.num_heads = model_config.get_num_kv_heads(parallel_config)
- self.block_size = cache_config.block_size
- self.num_gpu_blocks = cache_config.num_gpu_blocks
- self.num_cpu_blocks = cache_config.num_cpu_blocks
- if cache_config.cache_dtype == "auto":
- self.dtype = model_config.dtype
- else:
- self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
- # Initialize the cache.
- self.gpu_cache = self.allocate_gpu_cache()
- self.cpu_cache = self.allocate_cpu_cache()
- # Initialize the stream for caching operations.
- self.cache_stream = torch.cuda.Stream()
- assert self.cache_stream != torch.cuda.current_stream()
- # Initialize the events for stream synchronization.
- self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
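    # NOTE: The key cache splits the head dimension into chunks of x elements,
    # where x is the number of cache elements that fit in 16 bytes; this
    # matches the vectorized access pattern of the paged-attention kernel.
    # Illustrative example (hypothetical dimensions, not from the source):
    # for float16 (2-byte elements) x == 8, so with 8 KV heads and
    # head_size 128 each key block has shape (8, 16, block_size, 8), while
    # the value block keeps the plain (8, 128, block_size) layout.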
    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        return (
            self.num_heads,
            self.head_size,
            self.block_size,
        )

    def allocate_gpu_cache(self) -> List[KVCache]:
        gpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_gpu_blocks, *key_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            value_blocks = torch.empty(
                size=(self.num_gpu_blocks, *value_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            gpu_cache.append((key_blocks, value_blocks))
        return gpu_cache
    def allocate_cpu_cache(self) -> List[KVCache]:
        cpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        pin_memory = not in_wsl()
        if not pin_memory:
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down performance.")
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_cpu_blocks, *key_block_shape),
                dtype=self.dtype,
                pin_memory=pin_memory,
                device="cpu",
            )
            value_blocks = torch.empty(
                size=(self.num_cpu_blocks, *value_block_shape),
                dtype=self.dtype,
                pin_memory=pin_memory,
                device="cpu",
            )
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache
    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                cache_ops.swap_blocks(src_value_cache, dst_value_cache,
                                      src_to_dst)
                event = self.events[i]
                event.record(stream=self.cache_stream)
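    # The mappings below are keyed by source block number and give the
    # destination block number: swap_in moves blocks from the CPU cache into
    # the GPU cache, and swap_out moves them back. Both are issued on the
    # dedicated cache stream set up in __init__.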
    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
        value_caches = [value_cache for _, value_cache in self.gpu_cache]
        # NOTE: This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
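    # Illustrative sizing (hypothetical model dimensions, not taken from the
    # source): with block_size=16, 32 KV heads, head_size=128, 32 layers, and
    # a 2-byte cache dtype such as float16, get_cache_block_size() below
    # returns 32 * 2 * (16 * 32 * 128) * 2 bytes = 8,388,608 bytes (8 MiB)
    # per KV-cache block.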
    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: str,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        if cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        dtype_size = _get_dtype_size(dtype)
        return dtype_size * total


def _get_dtype_size(dtype: torch.dtype) -> int:
    return torch.tensor([], dtype=dtype).element_size()
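# Usage sketch (illustrative only; the config objects are assumed to be built
# by the worker that owns the CacheEngine):
#
#     cache_engine = CacheEngine(cache_config, model_config, parallel_config)
#     # Bring CPU blocks 0 and 1 into GPU blocks 5 and 6.
#     cache_engine.swap_in({0: 5, 1: 6})
#     # Evict GPU block 3 to CPU block 2.
#     cache_engine.swap_out({3: 2})
#     # Duplicate GPU block 2 into blocks 7 and 8 (e.g. for copy-on-write).
#     cache_engine.copy({2: [7, 8]})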