  1. """A CPU worker class."""
  2. from typing import Dict, List, Optional
  3. import torch
  4. import torch.distributed
  5. from aphrodite.attention import get_attn_backend
  6. from aphrodite.common.config import (
  7. CacheConfig,
  8. DeviceConfig,
  9. LoRAConfig,
  10. ModelConfig,
  11. ParallelConfig,
  12. SchedulerConfig,
  13. )
  14. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  15. from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
  16. from aphrodite.distributed import (
  17. broadcast_tensor_dict,
  18. ensure_model_parallel_initialized,
  19. init_distributed_environment,
  20. )
  21. from aphrodite.modeling import set_random_seed
  22. from aphrodite.task_handler.cpu_model_runner import CPUModelRunner
  23. from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase


class CPUCacheEngine:
    """Manages the KV cache for the CPU backend.

    This class is responsible for initializing and managing CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as copying.
    """

    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 device_config: DeviceConfig) -> None:
        assert device_config.device_type == "cpu"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks is actually num_cpu_blocks
        # for the CPU backend, because we want to reuse the KV cache
        # management in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        if cache_config.cache_dtype == "auto":
            self.dtype = model_config.dtype
        else:
            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # Get attention backend.
        self.attn_backend = get_attn_backend(model_config.dtype)

        # Initialize the cache.
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on CPU."""
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_heads, self.head_size)
        kv_cache: List[torch.Tensor] = []
        for _ in range(self.num_layers):
            kv_cache.append(
                torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
        return kv_cache

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: str,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
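        # Illustrative arithmetic (hypothetical numbers, not tied to any real
        # model config): with block_size=16, num_heads=8, head_size=128, and
        # num_layers=32, a block spans 32 * 2 * (16 * 8 * 128) = 1,048,576
        # cache elements, i.e. 2 MiB per block at 2 bytes per element (fp16).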
        if cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        return dtype_size * total


class CPUWorker(LoraNotSupportedWorkerBase):
    """A worker class that executes (a partition of) the model on a CPU socket.

    Each worker is associated with a single CPU socket. The worker is
    responsible for maintaining the KV cache and executing the model on the
    CPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        self.model_runner = CPUModelRunner(model_config,
                                           parallel_config,
                                           scheduler_config,
                                           device_config,
                                           lora_config=self.lora_config,
                                           kv_cache_dtype=kv_cache_dtype,
                                           is_driver_worker=is_driver_worker)
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache().
        self.cache_engine = None
        self.cpu_cache = None

    def init_device(self) -> None:
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured CPU
        KV cache space.

        Note that since Aphrodite assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of Aphrodite without generalizing
        it to different devices.
        """
        # For the CPU device, the number of blocks is calculated based on the
        # cpu_kvcache_space.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)
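        # Illustrative arithmetic (hypothetical numbers): with 4 GiB of CPU
        # KV cache space and 2 MiB per cache block, this yields 2048 blocks;
        # at a 16-token block size that is room for 32,768 cached tokens.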

        # NOTE: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.

        Since this worker does not support GPUs, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_cpu_blocks = num_gpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0

        # Initialize the cache.
        self._init_cache_engine()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise errors if the num_cpu_blocks is invalid."""
        if num_cpu_blocks <= 0:
            raise ValueError(
                "No available memory for the cache blocks. "
                "Try increasing `APHRODITE_CPU_KVCACHE_SPACE` when "
                "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`APHRODITE_CPU_KVCACHE_SPACE` or decreasing `max_model_len` "
                "when initializing the engine.")

    def _init_cache_engine(self) -> None:
        self.cache_engine = CPUCacheEngine(self.cache_config,
                                           self.model_config,
                                           self.parallel_config,
                                           self.device_config)
        self.cpu_cache = self.cache_engine.cpu_cache
        self.model_runner.block_size = self.cache_engine.block_size

        assert self.cpu_cache is not None

        # Populate the cache to warm up the memory.
        for layer_cache in self.cpu_cache:
            layer_cache.fill_(0)

    def cache_copy(
        self,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> Optional[SamplerOutput]:
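        # The driver worker (rank 0) receives the scheduler's inputs directly
        # and broadcasts the control-plane metadata (number of sequence
        # groups and the block-copy map) to the other workers, so that every
        # rank runs the same steps. The swap maps are asserted to be empty
        # because the CPU backend has no swappable cache.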
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            assert len(blocks_to_swap_in) == 0
            assert len(blocks_to_swap_out) == 0
            data = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_copy(blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return {}

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.cpu_cache)
        return output

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )
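        # NOTE: "gloo" is used because it is the torch.distributed backend
        # that operates on CPU tensors; GPU-oriented backends such as NCCL
        # are not applicable to this worker.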

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block."""
        return CPUCacheEngine.get_cache_block_size(
            self.cache_config.block_size, self.cache_config.cache_dtype,
            self.model_config, self.parallel_config)
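

# Rough sketch (an assumption for illustration, not part of this module) of
# the call order an executor is expected to follow when driving a CPUWorker;
# the engine normally performs these steps, not user code, and the init
# method below is just a placeholder value:
#
#     worker = CPUWorker(model_config, parallel_config, scheduler_config,
#                        device_config, cache_config, local_rank=0, rank=0,
#                        distributed_init_method="tcp://127.0.0.1:29500",
#                        is_driver_worker=True)
#     worker.init_device()
#     worker.load_model()
#     num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
#     worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
#     output = worker.execute_model(seq_group_metadata_list,
#                                   blocks_to_swap_in={},
#                                   blocks_to_swap_out={},
#                                   blocks_to_copy={})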