cache_engine.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. """CacheEngine class for managing the KV cache."""
  2. from typing import Dict, List
  3. import torch
  4. from aphrodite.attention import get_attn_backend
  5. from aphrodite.common.config import CacheConfig, ModelConfig, ParallelConfig
  6. from aphrodite.common.utils import (
  7. is_pin_memory_available,
  8. STR_DTYPE_TO_TORCH_DTYPE,
  9. )
  10. class CacheEngine:
  11. """Manages the KV cache.
  12. This class is responsible for initializing and managing the GPU and CPU KV
  13. caches. It also provides methods for performing KV cache operations, such
  14. as swapping and copying.
  15. """
  16. def __init__(
  17. self,
  18. cache_config: CacheConfig,
  19. model_config: ModelConfig,
  20. parallel_config: ParallelConfig,
  21. ) -> None:
  22. self.cache_config = cache_config
  23. self.model_config = model_config
  24. self.parallel_config = parallel_config
  25. self.head_size = model_config.get_head_size()
  26. self.num_layers = model_config.get_num_layers(parallel_config)
  27. self.num_heads = model_config.get_num_kv_heads(parallel_config)
  28. self.block_size = cache_config.block_size
  29. self.num_gpu_blocks = cache_config.num_gpu_blocks
  30. self.num_cpu_blocks = cache_config.num_cpu_blocks
  31. if cache_config.cache_dtype == "auto":
  32. self.dtype = model_config.dtype
  33. else:
  34. self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
  35. # Get attention backend.
  36. self.attn_backend = get_attn_backend(model_config.dtype)
  37. # Initialize the cache.
  38. # Get attention backend.
  39. self.attn_backend = get_attn_backend(model_config.dtype)
  40. # Initialize the cache.
  41. self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
  42. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
  43. def _allocate_kv_cache(
  44. self,
  45. num_blocks: int,
  46. device: str,
  47. ) -> List[torch.Tensor]:
  48. """Allocates KV cache on the specified device."""
  49. kv_cache_shape = self.attn_backend.get_kv_cache_shape(
  50. num_blocks, self.block_size, self.num_heads, self.head_size)
  51. pin_memory = is_pin_memory_available() if device == "cpu" else False
  52. kv_cache: List[torch.Tensor] = []
  53. for _ in range(self.num_layers):
  54. kv_cache.append(
  55. torch.empty(kv_cache_shape,
  56. dtype=self.dtype,
  57. pin_memory=pin_memory,
  58. device=device))
  59. return kv_cache
  60. def swap_in(self, src_to_dst: Dict[int, int]) -> None:
  61. for i in range(self.num_layers):
  62. self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
  63. src_to_dst)
  64. def swap_out(self, src_to_dst: Dict[int, int]) -> None:
  65. for i in range(self.num_layers):
  66. self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
  67. src_to_dst)
  68. def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
  69. self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
  70. @staticmethod
  71. def get_cache_block_size(
  72. cache_config: CacheConfig,
  73. model_config: ModelConfig,
  74. parallel_config: ParallelConfig,
  75. ) -> int:
  76. head_size = model_config.get_head_size()
  77. num_heads = model_config.get_num_kv_heads(parallel_config)
  78. num_layers = model_config.get_num_layers(parallel_config)
  79. key_cache_block = cache_config.block_size * num_heads * head_size
  80. value_cache_block = key_cache_block
  81. total = num_layers * (key_cache_block + value_cache_block)
  82. if cache_config.cache_dtype == "auto":
  83. dtype = model_config.dtype
  84. else:
  85. dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
  86. dtype_size = _get_dtype_size(dtype)
  87. return dtype_size * total
  88. def _get_dtype_size(dtype: torch.dtype) -> int:
  89. return torch.tensor([], dtype=dtype).element_size()