1
0

cpu_worker.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. """A CPU worker class."""
  2. from typing import Any, Dict, List, Optional, Tuple
  3. import torch
  4. import torch.distributed
  5. from aphrodite.attention import get_attn_backend
  6. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  7. LoRAConfig, ModelConfig, ParallelConfig,
  8. SchedulerConfig, VisionLanguageConfig)
  9. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  10. from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
  11. from aphrodite.distributed import (broadcast_tensor_dict,
  12. ensure_model_parallel_initialized,
  13. init_distributed_environment)
  14. from aphrodite.modeling import set_random_seed
  15. from aphrodite.task_handler.cpu_model_runner import CPUModelRunner
  16. from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
  class CPUCacheEngine:
      """KV-cache owner for the CPU backend.

      Allocates one CPU tensor per model layer to hold the key/value cache and
      exposes block-level cache operations on it. Swapping is not supported by
      this engine; only in-place block copies are.
      """

      def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                   parallel_config: ParallelConfig,
                   device_config: DeviceConfig) -> None:
          assert device_config.device_type == "cpu"
          self.cache_config = cache_config
          self.model_config = model_config
          self.parallel_config = parallel_config

          self.head_size = model_config.get_head_size()
          self.num_layers = model_config.get_num_layers(parallel_config)
          self.num_heads = model_config.get_num_kv_heads(parallel_config)
          self.block_size = cache_config.block_size

          # The scheduler's KV-cache bookkeeping is reused as-is, so on the CPU
          # backend the CPU block count lives in CacheConfig.num_gpu_blocks.
          self.num_cpu_blocks = cache_config.num_gpu_blocks

          self.dtype = self._resolve_dtype(cache_config.cache_dtype,
                                           model_config)

          # The attention backend is chosen from the model dtype (not the
          # cache dtype).
          self.attn_backend = get_attn_backend(model_config.dtype)

          # Allocate the per-layer cache tensors up front.
          self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

      @staticmethod
      def _resolve_dtype(cache_dtype: str,
                         model_config: ModelConfig) -> torch.dtype:
          """Map a cache dtype string to a torch dtype.

          ``"auto"`` means "same as the model"; any other value is looked up in
          STR_DTYPE_TO_TORCH_DTYPE (an unknown key raises KeyError).
          """
          if cache_dtype == "auto":
              return model_config.dtype
          return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]

      def _allocate_kv_cache(
          self,
          num_blocks: int,
      ) -> List[torch.Tensor]:
          """Return ``num_layers`` uninitialised CPU tensors for the KV cache.

          Each tensor has the shape reported by the attention backend for
          ``num_blocks`` blocks of this engine's block size / head geometry.
          """
          layer_shape = self.attn_backend.get_kv_cache_shape(
              num_blocks, self.block_size, self.num_heads, self.head_size)
          return [
              torch.empty(layer_shape, dtype=self.dtype, device="cpu")
              for _ in range(self.num_layers)
          ]

      def swap_in(self, src_to_dst: Dict[int, int]) -> None:
          """Unsupported on CPU; always raises NotImplementedError."""
          raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

      def swap_out(self, src_to_dst: Dict[int, int]) -> None:
          """Unsupported on CPU; always raises NotImplementedError."""
          raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

      def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
          """Copy each source block to its destination blocks in every layer
          via the attention backend's ``copy_blocks``."""
          self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

      @staticmethod
      def get_cache_block_size(
          block_size: int,
          cache_dtype: str,
          model_config: ModelConfig,
          parallel_config: ParallelConfig,
      ) -> int:
          """Return the size in bytes of one KV-cache block across all layers.

          One block holds ``block_size`` tokens of keys plus an equally sized
          set of values, for every KV head of every layer.
          """
          head_size = model_config.get_head_size()
          num_heads = model_config.get_num_kv_heads(parallel_config)
          num_layers = model_config.get_num_layers(parallel_config)

          elems_per_tensor = block_size * num_heads * head_size
          # Keys and values are the same size, hence the factor of two.
          elems_per_block = num_layers * (2 * elems_per_tensor)

          dtype = CPUCacheEngine._resolve_dtype(cache_dtype, model_config)
          bytes_per_elem = torch.tensor([], dtype=dtype).element_size()
          return bytes_per_elem * elems_per_block
  class CPUWorker(LoraNotSupportedWorkerBase):
      """Executes the model, or one partition of it, on a single CPU socket.

      Each worker is tied to one CPU socket, owns that socket's KV cache, and
      drives model execution there. Under distributed inference every worker
      holds its own partition of the model.
      """

      def __init__(
          self,
          model_config: ModelConfig,
          parallel_config: ParallelConfig,
          scheduler_config: SchedulerConfig,
          device_config: DeviceConfig,
          cache_config: CacheConfig,
          load_config: LoadConfig,
          local_rank: int,
          rank: int,
          distributed_init_method: str,
          lora_config: Optional[LoRAConfig] = None,
          vision_language_config: Optional[VisionLanguageConfig] = None,
          kv_cache_dtype: Optional[str] = "auto",
          is_driver_worker: bool = False,
      ) -> None:
          self.model_config = model_config
          self.parallel_config = parallel_config
          self.scheduler_config = scheduler_config
          self.device_config = device_config
          self.cache_config = cache_config
          self.load_config = load_config
          self.local_rank = local_rank
          self.rank = rank
          self.distributed_init_method = distributed_init_method
          self.lora_config = lora_config
          self.vision_language_config = vision_language_config
          self.is_driver_worker = is_driver_worker

          if is_driver_worker:
              assert rank == 0, "The driver worker must have rank 0."

          if model_config.trust_remote_code:
              # Imported lazily so torch is not imported before
              # initialization.
              from aphrodite.common.utils import init_cached_hf_modules
              init_cached_hf_modules()

          self.model_runner = CPUModelRunner(
              model_config,
              parallel_config,
              scheduler_config,
              device_config,
              load_config=load_config,
              lora_config=lora_config,
              vision_language_config=vision_language_config,
              kv_cache_dtype=kv_cache_dtype,
              is_driver_worker=is_driver_worker,
          )

          # Declared only; assigned later by initialize_cache().
          self.cache_engine: CPUCacheEngine
          self.cpu_cache: List[torch.Tensor]

      def init_device(self) -> None:
          """Set up the distributed environment, then seed the RNGs."""
          self.init_distributed_environment()
          set_random_seed(self.model_config.seed)

      def load_model(self):
          """Delegate model loading to the CPU model runner."""
          self.model_runner.load_model()

      def determine_num_available_blocks(self) -> Tuple[int, int]:
          """Return ``(num_gpu_blocks, num_cpu_blocks)`` for this worker.

          The first value is how many KV blocks fit into the configured CPU KV
          cache space (``cpu_kvcache_space_bytes``), never below zero. The
          second value is always 0.

          The scheduler treats a block as device-resident when it can be
          modified. Reporting CPU blocks in the "GPU" slot lets the existing
          scheduler be reused without generalizing it to other devices.
          """
          bytes_per_block = self.get_cache_block_size_bytes()
          fitting_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                               bytes_per_block)
          # CPU blocks are reported as "GPU" blocks; no swap space exists.
          return max(fitting_blocks, 0), 0

      def initialize_cache(self, num_gpu_blocks: int,
                           num_cpu_blocks: int) -> None:
          """Validate the block budget and build the CPU KV cache.

          Swappable CPU memory is unsupported, so ``num_cpu_blocks`` must be 0.
          ``num_gpu_blocks`` is interpreted as the number of non-swappable CPU
          blocks to allocate.

          Raises:
              ValueError: if the block count is non-positive or too small to
                  hold ``max_model_len`` tokens.
          """
          assert (num_cpu_blocks == 0
                  ), f"{type(self)} does not support swappable cache"

          # The "GPU" block count is really the CPU block count here.
          usable_blocks = num_gpu_blocks
          self._validate_num_cpu_blocks(usable_blocks)

          self.cache_config.num_gpu_blocks = usable_blocks
          self.cache_config.num_cpu_blocks = 0

          self._init_cache_engine()

      def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
          """Raise ValueError unless ``num_cpu_blocks`` can serve the model."""
          if num_cpu_blocks <= 0:
              raise ValueError(
                  "No available memory for the cache blocks. "
                  "Try increasing `APHRODITE_CPU_KVCACHE_SPACE` when "
                  "initializing the engine.")

          token_capacity = self.cache_config.block_size * num_cpu_blocks
          model_len = self.model_config.max_model_len
          if model_len > token_capacity:
              raise ValueError(
                  f"The model's max seq len ({model_len}) "
                  "is larger than the maximum number of tokens that can be "
                  f"stored in KV cache ({token_capacity}). Try increasing "
                  "`APHRODITE_CPU_KVCACHE_SPACE` or decreasing `max_model_len` "
                  "when initializing the engine.")

      def _init_cache_engine(self) -> None:
          """Create the cache engine and share its tensors with the runner.

          Also zero-fills the cache.
          """
          engine = CPUCacheEngine(self.cache_config, self.model_config,
                                  self.parallel_config, self.device_config)
          self.cache_engine = engine
          self.cpu_cache = engine.cpu_cache
          self.model_runner.block_size = engine.block_size
          assert self.cpu_cache is not None

          # Zero-fill every layer's cache so the memory is populated (warmed
          # up) before serving.
          for layer_tensor in self.cpu_cache:
              layer_tensor.fill_(0)

      def cache_copy(
          self,
          blocks_to_copy: Dict[int, List[int]],
      ) -> None:
          """Apply the block copies in ``blocks_to_copy``; no-op when empty."""
          if not blocks_to_copy:
              return
          self.cache_engine.copy(blocks_to_copy)

      @torch.inference_mode()
      def execute_model(
          self,
          seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
          blocks_to_swap_in: Optional[Dict[int, int]] = None,
          blocks_to_swap_out: Optional[Dict[int, int]] = None,
          blocks_to_copy: Optional[Dict[int, List[int]]] = None,
      ) -> List[SamplerOutput]:
          """Run one step and return a list holding at most one output.

          The driver (rank 0) broadcasts the sequence-group count and the
          block-copy plan. Other ranks receive both via
          ``broadcast_tensor_dict``. Every rank then applies the copies. When
          there are no sequence groups, an empty list is returned without
          invoking the model.
          """
          if not self.is_driver_worker:
              received = broadcast_tensor_dict(src=0)
              num_seq_groups = received["num_seq_groups"]
              blocks_to_copy = received["blocks_to_copy"]
          else:
              assert seq_group_metadata_list is not None
              num_seq_groups = len(seq_group_metadata_list)
              assert blocks_to_swap_in is not None
              assert blocks_to_swap_out is not None
              assert blocks_to_copy is not None
              # Swapping is unsupported on CPU, so the plans must be empty.
              assert len(blocks_to_swap_in) == 0
              assert len(blocks_to_swap_out) == 0
              payload: Dict[str, Any] = {
                  "num_seq_groups": num_seq_groups,
                  "blocks_to_copy": blocks_to_copy,
              }
              broadcast_tensor_dict(payload, src=0)

          assert blocks_to_copy is not None
          self.cache_copy(blocks_to_copy)

          # Nothing was scheduled, so skip the forward pass.
          if num_seq_groups == 0:
              return []

          # The CPU worker only does single-step execution: one output.
          return [
              self.model_runner.execute_model(seq_group_metadata_list,
                                              self.cpu_cache)
          ]

      def init_distributed_environment(self) -> None:
          """Start torch.distributed on gloo and set up model-parallel groups.

          This calls the module-level ``init_distributed_environment``, not
          this method.
          """
          init_distributed_environment(
              world_size=self.parallel_config.world_size,
              rank=self.rank,
              distributed_init_method=self.distributed_init_method,
              backend="gloo",
          )
          # A tiny all_reduce to warm up the process group.
          torch.distributed.all_reduce(torch.zeros(1).cpu())
          ensure_model_parallel_initialized(
              self.parallel_config.tensor_parallel_size,
              self.parallel_config.pipeline_parallel_size)

      def get_cache_block_size_bytes(self) -> int:
          """Return the size in bytes of one KV-cache block.

          The size is computed by ``CPUCacheEngine.get_cache_block_size``.
          """
          cache_cfg = self.cache_config
          return CPUCacheEngine.get_cache_block_size(
              cache_cfg.block_size, cache_cfg.cache_dtype, self.model_config,
              self.parallel_config)