cache_engine.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. """CacheEngine class for managing the KV cache."""
  2. from typing import Dict, List, Tuple
  3. import torch
  4. from loguru import logger
  5. from aphrodite._C import cache_ops
  6. from aphrodite.common.config import CacheConfig, ModelConfig, ParallelConfig
  7. from aphrodite.common.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
  8. KVCache = Tuple[torch.Tensor, torch.Tensor]
  9. class CacheEngine:
  10. """Manages the KV cache.
  11. This class is responsible for initializing and managing the GPU and CPU KV
  12. caches. It also provides methods for performing KV cache operations, such
  13. as swapping and copying.
  14. """
  15. def __init__(
  16. self,
  17. cache_config: CacheConfig,
  18. model_config: ModelConfig,
  19. parallel_config: ParallelConfig,
  20. ) -> None:
  21. self.cache_config = cache_config
  22. self.model_config = model_config
  23. self.parallel_config = parallel_config
  24. self.head_size = model_config.get_head_size()
  25. self.num_layers = model_config.get_num_layers(parallel_config)
  26. self.num_heads = model_config.get_num_kv_heads(parallel_config)
  27. self.block_size = cache_config.block_size
  28. self.num_gpu_blocks = cache_config.num_gpu_blocks
  29. self.num_cpu_blocks = cache_config.num_cpu_blocks
  30. if cache_config.cache_dtype == "auto":
  31. self.dtype = model_config.dtype
  32. else:
  33. self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
  34. # Initialize the cache.
  35. self.gpu_cache = self.allocate_gpu_cache()
  36. self.cpu_cache = self.allocate_cpu_cache()
  37. # Initialize the stream for caching operations.
  38. self.cache_stream = torch.cuda.Stream()
  39. assert self.cache_stream != torch.cuda.current_stream()
  40. # Initialize the events for stream synchronization.
  41. self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
  42. def get_key_block_shape(self) -> Tuple[int, int, int, int]:
  43. element_size = torch.tensor([], dtype=self.dtype).element_size()
  44. x = 16 // element_size
  45. return (
  46. self.num_heads,
  47. self.head_size // x,
  48. self.block_size,
  49. x,
  50. )
  51. def get_value_block_shape(self) -> Tuple[int, int, int]:
  52. return (
  53. self.num_heads,
  54. self.head_size,
  55. self.block_size,
  56. )
  57. def allocate_gpu_cache(self) -> List[KVCache]:
  58. gpu_cache: List[KVCache] = []
  59. key_block_shape = self.get_key_block_shape()
  60. value_block_shape = self.get_value_block_shape()
  61. for _ in range(self.num_layers):
  62. key_blocks = torch.empty(
  63. size=(self.num_gpu_blocks, *key_block_shape),
  64. dtype=self.dtype,
  65. device="cuda",
  66. )
  67. value_blocks = torch.empty(
  68. size=(self.num_gpu_blocks, *value_block_shape),
  69. dtype=self.dtype,
  70. device="cuda",
  71. )
  72. gpu_cache.append((key_blocks, value_blocks))
  73. return gpu_cache
  74. def allocate_cpu_cache(self) -> List[KVCache]:
  75. cpu_cache: List[KVCache] = []
  76. key_block_shape = self.get_key_block_shape()
  77. value_block_shape = self.get_value_block_shape()
  78. pin_memory = not in_wsl()
  79. if not pin_memory:
  80. # Pinning memory in WSL is not supported.
  81. # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
  82. logger.warning("Using 'pin_memory=False' as WSL is detected. "
  83. "This may slow down the performance.")
  84. for _ in range(self.num_layers):
  85. key_blocks = torch.empty(
  86. size=(self.num_cpu_blocks, *key_block_shape),
  87. dtype=self.dtype,
  88. pin_memory=pin_memory,
  89. device="cpu",
  90. )
  91. value_blocks = torch.empty(
  92. size=(self.num_cpu_blocks, *value_block_shape),
  93. dtype=self.dtype,
  94. pin_memory=pin_memory,
  95. device="cpu",
  96. )
  97. cpu_cache.append((key_blocks, value_blocks))
  98. return cpu_cache
  99. def _swap(
  100. self,
  101. src: List[KVCache],
  102. dst: List[KVCache],
  103. src_to_dst: Dict[int, int],
  104. ) -> None:
  105. with torch.cuda.stream(self.cache_stream):
  106. for i in range(self.num_layers):
  107. src_key_cache, src_value_cache = src[i]
  108. dst_key_cache, dst_value_cache = dst[i]
  109. # Copy the key blocks.
  110. cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
  111. # Copy the value blocks.
  112. cache_ops.swap_blocks(src_value_cache, dst_value_cache,
  113. src_to_dst)
  114. event = self.events[i]
  115. event.record(stream=self.cache_stream)
  116. def swap_in(self, src_to_dst: Dict[int, int]) -> None:
  117. self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)
  118. def swap_out(self, src_to_dst: Dict[int, int]) -> None:
  119. self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
  120. def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
  121. key_caches = [key_cache for key_cache, _ in self.gpu_cache]
  122. value_caches = [value_cache for _, value_cache in self.gpu_cache]
  123. # NOTE: This operation implicitly synchronizes the CPU and GPU.
  124. cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
  125. @staticmethod
  126. def get_cache_block_size(
  127. block_size: int,
  128. cache_dtype: str,
  129. model_config: ModelConfig,
  130. parallel_config: ParallelConfig,
  131. ) -> int:
  132. head_size = model_config.get_head_size()
  133. num_heads = model_config.get_num_kv_heads(parallel_config)
  134. num_layers = model_config.get_num_layers(parallel_config)
  135. key_cache_block = block_size * num_heads * head_size
  136. value_cache_block = key_cache_block
  137. total = num_layers * (key_cache_block + value_cache_block)
  138. if cache_dtype == "auto":
  139. dtype = model_config.dtype
  140. else:
  141. dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
  142. dtype_size = _get_dtype_size(dtype)
  143. return dtype_size * total
  144. def _get_dtype_size(dtype: torch.dtype) -> int:
  145. return torch.tensor([], dtype=dtype).element_size()