123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- """CacheEngine class for managing the KV cache."""
- from typing import Dict, List
- import torch
- from aphrodite.attention import get_attn_backend
- from aphrodite.common.config import CacheConfig, ModelConfig, ParallelConfig
- from aphrodite.common.utils import (
- is_pin_memory_available,
- STR_DTYPE_TO_TORCH_DTYPE,
- )
class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
        """Read the cache geometry from the configs and allocate both caches.

        Args:
            cache_config: Supplies ``block_size``, ``num_gpu_blocks``,
                ``num_cpu_blocks`` and ``cache_dtype``.
            model_config: Supplies the head size, KV head count, layer count
                and model dtype.
            parallel_config: Passed to the ``model_config`` getters that
                compute the KV head and layer counts.
        """
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = CacheEngine.get_num_attention_layers(
            model_config, parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        # "auto" means the cache is stored in the model's own dtype; any other
        # value is looked up in the string -> torch dtype table.
        if cache_config.cache_dtype == "auto":
            self.dtype = model_config.dtype
        else:
            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # Get attention backend (selected once, from the model dtype rather
        # than the cache dtype).
        self.attn_backend = get_attn_backend(model_config.dtype)

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device.

        Args:
            num_blocks: Number of cache blocks, forwarded to the backend's
                ``get_kv_cache_shape``.
            device: Torch device string; ``"cpu"`` additionally enables pinned
                memory when ``is_pin_memory_available()`` reports it.

        Returns:
            A list with one uninitialized tensor per attention layer, each of
            the shape reported by the attention backend.
        """
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_heads, self.head_size)
        # Pinned memory is only requested for the CPU-side cache.
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        return [
            torch.empty(kv_cache_shape,
                        dtype=self.dtype,
                        pin_memory=pin_memory,
                        device=device)
            for _ in range(self.num_layers)
        ]

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Run the backend's ``swap_blocks`` from CPU to GPU cache per layer.

        Args:
            src_to_dst: Block mapping handed unchanged to the backend
                (by its name, source block -> destination block).
        """
        for i in range(self.num_layers):
            self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
                                          src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Run the backend's ``swap_blocks`` from GPU to CPU cache per layer.

        Args:
            src_to_dst: Block mapping handed unchanged to the backend
                (by its name, source block -> destination block).
        """
        for i in range(self.num_layers):
            self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                          src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        """Run the backend's ``copy_blocks`` on the GPU cache.

        Args:
            src_to_dsts: Mapping handed unchanged to the backend (by its
                name, source block -> list of destination blocks).
        """
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    @staticmethod
    def get_num_attention_layers(model_config: ModelConfig,
                                 parallel_config: ParallelConfig) -> int:
        """Return the number of layers that get a KV cache tensor.

        For every model type except ``"jamba"`` this is
        ``model_config.get_num_layers(parallel_config)``. For ``"jamba"`` the
        count is floor-divided by ``hf_config.attn_layer_period`` and clamped
        to a minimum of 1.
        """
        num_layers = model_config.get_num_layers(parallel_config)
        is_mamba = model_config.hf_config.model_type == "jamba"
        if is_mamba:
            attention_period = model_config.hf_config.attn_layer_period
            num_layers = max(num_layers // attention_period, 1)
        return num_layers

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one cache block across all layers.

        Computed as ``element_size * num_layers * 2 * block_size * num_heads
        * head_size``, where the factor 2 accounts for the key and the value
        block being the same size.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = CacheEngine.get_num_attention_layers(
            model_config, parallel_config)

        key_cache_block = cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)

        # Same dtype resolution as in __init__: "auto" follows the model.
        if cache_config.cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
        dtype_size = _get_dtype_size(dtype)
        return dtype_size * total
- def _get_dtype_size(dtype: torch.dtype) -> int:
- return torch.tensor([], dtype=dtype).element_size()
|