cache_engine.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. """CacheEngine class for managing the KV cache."""
  2. from typing import List
  3. import torch
  4. from aphrodite.attention import get_attn_backend
  5. from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
  6. ParallelConfig)
  7. from aphrodite.common.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
  8. is_pin_memory_available)
  9. class CacheEngine:
  10. """Manages the KV cache.
  11. This class is responsible for initializing and managing the GPU and CPU KV
  12. caches. It also provides methods for performing KV cache operations, such
  13. as swapping and copying.
  14. """
  15. def __init__(
  16. self,
  17. cache_config: CacheConfig,
  18. model_config: ModelConfig,
  19. parallel_config: ParallelConfig,
  20. device_config: DeviceConfig,
  21. ) -> None:
  22. self.cache_config = cache_config
  23. self.model_config = model_config
  24. self.parallel_config = parallel_config
  25. self.device_config = device_config
  26. self.head_size = model_config.get_head_size()
  27. self.num_layers = model_config.get_num_layers(parallel_config)
  28. self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
  29. self.block_size = cache_config.block_size
  30. self.num_gpu_blocks = cache_config.num_gpu_blocks
  31. self.num_cpu_blocks = cache_config.num_cpu_blocks
  32. if cache_config.cache_dtype == "auto":
  33. self.dtype = model_config.dtype
  34. else:
  35. self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
  36. # Get attention backend.
  37. self.attn_backend = get_attn_backend(
  38. model_config.get_num_attention_heads(parallel_config),
  39. self.head_size,
  40. self.num_kv_heads,
  41. model_config.get_sliding_window(),
  42. model_config.dtype,
  43. cache_config.cache_dtype,
  44. self.block_size,
  45. )
  46. # Initialize the cache.
  47. self.gpu_cache = self._allocate_kv_cache(
  48. self.num_gpu_blocks, self.device_config.device_type)
  49. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
  50. def _allocate_kv_cache(
  51. self,
  52. num_blocks: int,
  53. device: str,
  54. ) -> List[torch.Tensor]:
  55. """Allocates KV cache on the specified device."""
  56. kv_cache_shape = self.attn_backend.get_kv_cache_shape(
  57. num_blocks, self.block_size, self.num_kv_heads, self.head_size)
  58. pin_memory = is_pin_memory_available() if device == "cpu" else False
  59. kv_cache: List[torch.Tensor] = []
  60. for _ in range(self.num_layers):
  61. # null block in CpuGpuBlockAllocator requires at least that
  62. # block to be zeroed-out.
  63. # We zero-out everything for simplicity.
  64. kv_cache.append(
  65. torch.zeros(kv_cache_shape,
  66. dtype=self.dtype,
  67. pin_memory=pin_memory,
  68. device=device))
  69. return kv_cache
  70. def swap_in(self, src_to_dst: torch.Tensor) -> None:
  71. for i in range(self.num_layers):
  72. self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
  73. src_to_dst)
  74. def swap_out(self, src_to_dst: torch.Tensor) -> None:
  75. for i in range(self.num_layers):
  76. self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
  77. src_to_dst)
  78. def copy(self, src_to_dsts: torch.Tensor) -> None:
  79. self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
  80. @staticmethod
  81. def get_cache_block_size(
  82. cache_config: CacheConfig,
  83. model_config: ModelConfig,
  84. parallel_config: ParallelConfig,
  85. ) -> int:
  86. head_size = model_config.get_head_size()
  87. num_heads = model_config.get_num_kv_heads(parallel_config)
  88. num_layers = model_config.get_num_layers(parallel_config)
  89. key_cache_block = cache_config.block_size * num_heads * head_size
  90. value_cache_block = key_cache_block
  91. total = num_layers * (key_cache_block + value_cache_block)
  92. if cache_config.cache_dtype == "auto":
  93. dtype = model_config.dtype
  94. else:
  95. dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
  96. dtype_size = get_dtype_size(dtype)
  97. return dtype_size * total