cpu_worker.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. """A CPU worker class."""
  2. from typing import Any, Dict, List, Optional, Tuple
  3. import torch
  4. import torch.distributed
  5. from aphrodite.attention import get_attn_backend
  6. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  7. LoRAConfig, ModelConfig, ParallelConfig,
  8. SchedulerConfig, VisionLanguageConfig)
  9. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  10. from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
  11. from aphrodite.distributed import (broadcast_tensor_dict,
  12. ensure_model_parallel_initialized,
  13. init_distributed_environment)
  14. from aphrodite.modeling import set_random_seed
  15. from aphrodite.task_handler.cpu_model_runner import CPUModelRunner
  16. from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
  17. class CPUCacheEngine:
  18. """Manages the KV cache for CPU backend.
  19. This class is responsible for initializing and managing CPU KV
  20. caches. It also provides methods for performing KV cache operations, such
  21. as copying.
  22. """
  23. def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
  24. parallel_config: ParallelConfig,
  25. device_config: DeviceConfig) -> None:
  26. assert device_config.device_type == "cpu"
  27. self.cache_config = cache_config
  28. self.model_config = model_config
  29. self.parallel_config = parallel_config
  30. self.head_size = model_config.get_head_size()
  31. self.num_layers = model_config.get_num_layers(parallel_config)
  32. self.num_heads = model_config.get_num_kv_heads(parallel_config)
  33. self.block_size = cache_config.block_size
  34. # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks
  35. # for CPU backend, because we want to reuse KV cache management
  36. # in the scheduler.
  37. self.num_cpu_blocks = cache_config.num_gpu_blocks
  38. if cache_config.cache_dtype == "auto":
  39. self.dtype = model_config.dtype
  40. else:
  41. self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
  42. # Get attention backend.
  43. self.attn_backend = get_attn_backend(
  44. self.model_config.get_num_attention_heads(self.parallel_config),
  45. self.model_config.get_head_size(),
  46. self.model_config.get_num_kv_heads(self.parallel_config),
  47. self.model_config.get_sliding_window(),
  48. self.model_config.dtype,
  49. cache_config.cache_dtype,
  50. self.block_size,
  51. )
  52. # Initialize the cache.
  53. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
  54. def _allocate_kv_cache(
  55. self,
  56. num_blocks: int,
  57. ) -> List[torch.Tensor]:
  58. """Allocates KV cache on CPU."""
  59. kv_cache_shape = self.attn_backend.get_kv_cache_shape(
  60. num_blocks, self.block_size, self.num_heads, self.head_size)
  61. kv_cache: List[torch.Tensor] = []
  62. for _ in range(self.num_layers):
  63. kv_cache.append(
  64. torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
  65. return kv_cache
  66. def swap_in(self, src_to_dst: Dict[int, int]) -> None:
  67. raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
  68. def swap_out(self, src_to_dst: Dict[int, int]) -> None:
  69. raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
  70. def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
  71. self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
  72. @staticmethod
  73. def get_cache_block_size(
  74. block_size: int,
  75. cache_dtype: str,
  76. model_config: ModelConfig,
  77. parallel_config: ParallelConfig,
  78. ) -> int:
  79. head_size = model_config.get_head_size()
  80. num_heads = model_config.get_num_kv_heads(parallel_config)
  81. num_layers = model_config.get_num_layers(parallel_config)
  82. key_cache_block = block_size * num_heads * head_size
  83. value_cache_block = key_cache_block
  84. total = num_layers * (key_cache_block + value_cache_block)
  85. if cache_dtype == "auto":
  86. dtype = model_config.dtype
  87. else:
  88. dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
  89. dtype_size = torch.tensor([], dtype=dtype).element_size()
  90. return dtype_size * total
  91. class CPUWorker(LoraNotSupportedWorkerBase):
  92. """A worker class that executes (a partition of) the model on a CPU socket.
  93. Each worker is associated with a single CPU socket. The worker is
  94. responsible for maintaining the KV cache and executing the model on the
  95. CPU. In case of distributed inference, each worker is assigned a partition
  96. of the model.
  97. """
  98. def __init__(
  99. self,
  100. model_config: ModelConfig,
  101. parallel_config: ParallelConfig,
  102. scheduler_config: SchedulerConfig,
  103. device_config: DeviceConfig,
  104. cache_config: CacheConfig,
  105. load_config: LoadConfig,
  106. local_rank: int,
  107. rank: int,
  108. distributed_init_method: str,
  109. lora_config: Optional[LoRAConfig] = None,
  110. vision_language_config: Optional[VisionLanguageConfig] = None,
  111. kv_cache_dtype: Optional[str] = "auto",
  112. is_driver_worker: bool = False,
  113. ) -> None:
  114. self.model_config = model_config
  115. self.parallel_config = parallel_config
  116. self.scheduler_config = scheduler_config
  117. self.device_config = device_config
  118. self.cache_config = cache_config
  119. self.load_config = load_config
  120. self.local_rank = local_rank
  121. self.rank = rank
  122. self.distributed_init_method = distributed_init_method
  123. self.lora_config = lora_config
  124. self.vision_language_config = vision_language_config
  125. self.is_driver_worker = is_driver_worker
  126. if self.is_driver_worker:
  127. assert self.rank == 0, "The driver worker must have rank 0."
  128. if self.model_config.trust_remote_code:
  129. # note: lazy import to avoid importing torch before initializing
  130. from aphrodite.common.utils import init_cached_hf_modules
  131. init_cached_hf_modules()
  132. self.model_runner = CPUModelRunner(
  133. model_config,
  134. parallel_config,
  135. scheduler_config,
  136. device_config,
  137. cache_config,
  138. load_config=self.load_config,
  139. lora_config=self.lora_config,
  140. vision_language_config=self.vision_language_config,
  141. kv_cache_dtype=kv_cache_dtype,
  142. is_driver_worker=is_driver_worker)
  143. # Uninitialized cache engine. Will be initialized by
  144. # initialize_cache.
  145. self.cache_engine: CPUCacheEngine
  146. self.cpu_cache: List[torch.Tensor]
  147. def init_device(self) -> None:
  148. self.init_distributed_environment()
  149. # Set random seed.
  150. set_random_seed(self.model_config.seed)
  151. def load_model(self):
  152. self.model_runner.load_model()
  153. def determine_num_available_blocks(self) -> Tuple[int, int]:
  154. """Determine the number of blocks available for the KV cache.
  155. This determines how many KV blocks can fit into the configured CPU
  156. KV cache space.
  157. Note that since Aphrodite assumes a block resides on GPU if it can be
  158. modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
  159. This allows us to reuse the scheduler of Aphrodite without generalizing
  160. it to different devices.
  161. """
  162. # For CPU device, the block number will be calculated based on the
  163. # cpu_kvcache_space.
  164. cache_block_size = self.get_cache_block_size_bytes()
  165. num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
  166. cache_block_size)
  167. num_cpu_blocks = max(num_cpu_blocks, 0)
  168. # NOTE: To reuse the cache management procedure,
  169. # use cpu cache as 'gpu cache'.
  170. num_gpu_blocks = num_cpu_blocks
  171. num_cpu_blocks = 0
  172. return num_gpu_blocks, num_cpu_blocks
  173. def initialize_cache(self, num_gpu_blocks: int,
  174. num_cpu_blocks: int) -> None:
  175. """Initialize the KV cache. Currently, swappable CPU memory is not
  176. supported.
  177. Since this worker does not support GPUs, we use the num_gpu_blocks to
  178. determine how many non-swappable CPU blocks to allocate.
  179. """
  180. assert (num_cpu_blocks == 0
  181. ), f"{type(self)} does not support swappable cache"
  182. # Note: To reuse the cache management procedure,
  183. # use cpu cache as 'gpu cache'.
  184. num_cpu_blocks = num_gpu_blocks
  185. self._validate_num_cpu_blocks(num_cpu_blocks)
  186. self.cache_config.num_gpu_blocks = num_cpu_blocks
  187. self.cache_config.num_cpu_blocks = 0
  188. # Initialize the cache.
  189. self._init_cache_engine()
  190. def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
  191. """Raise errors if the num_cpu_blocks is invalid.
  192. """
  193. if num_cpu_blocks <= 0:
  194. raise ValueError(
  195. "No available memory for the cache blocks. "
  196. "Try increasing `APHRODITE_CPU_KVCACHE_SPACE` when "
  197. "initializing the engine.")
  198. max_seq_len = self.cache_config.block_size * num_cpu_blocks
  199. if self.model_config.max_model_len > max_seq_len:
  200. raise ValueError(
  201. f"The model's max seq len ({self.model_config.max_model_len}) "
  202. "is larger than the maximum number of tokens that can be "
  203. f"stored in KV cache ({max_seq_len}). Try increasing "
  204. "`APHRODITE_CPU_KVCACHE_SPACE` or decreasing `max_model_len` "
  205. "when initializing the engine.")
  206. def _init_cache_engine(self) -> None:
  207. self.cache_engine = CPUCacheEngine(self.cache_config,
  208. self.model_config,
  209. self.parallel_config,
  210. self.device_config)
  211. self.cpu_cache = self.cache_engine.cpu_cache
  212. self.model_runner.block_size = self.cache_engine.block_size
  213. assert self.cpu_cache is not None
  214. # Populate the cache to warmup the memory
  215. for layer_cache in self.cpu_cache:
  216. layer_cache.fill_(0)
  217. def cache_copy(
  218. self,
  219. blocks_to_copy: torch.Tensor,
  220. ) -> None:
  221. if blocks_to_copy.numel() > 0:
  222. self.cache_engine.copy(blocks_to_copy)
  223. @torch.inference_mode()
  224. def execute_model(
  225. self,
  226. execute_model_req: Optional[ExecuteModelRequest] = None,
  227. ) -> List[SamplerOutput]:
  228. if execute_model_req is None:
  229. seq_group_metadata_list = None
  230. else:
  231. seq_group_metadata_list = execute_model_req.seq_group_metadata_listv
  232. if self.is_driver_worker:
  233. assert seq_group_metadata_list is not None
  234. num_seq_groups: int = len(seq_group_metadata_list)
  235. assert execute_model_req is not None
  236. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
  237. device="cpu",
  238. dtype=torch.int64).view(-1, 2)
  239. assert len(execute_model_req.blocks_to_swap_in) == 0
  240. assert len(execute_model_req.blocks_to_swap_out) == 0
  241. data: Dict[str, Any] = {
  242. "num_seq_groups": num_seq_groups,
  243. "blocks_to_copy": execute_model_req.blocks_to_copy,
  244. }
  245. broadcast_tensor_dict(data, src=0)
  246. else:
  247. data = broadcast_tensor_dict(src=0)
  248. num_seq_groups = data["num_seq_groups"]
  249. blocks_to_copy = data["blocks_to_copy"]
  250. self.cache_copy(blocks_to_copy)
  251. # If there is no input, we don't need to execute the model.
  252. if num_seq_groups == 0:
  253. return []
  254. output = self.model_runner.execute_model(seq_group_metadata_list,
  255. self.cpu_cache)
  256. # CPU worker only supports single-step execution.
  257. return [output]
  258. def init_distributed_environment(self) -> None:
  259. """Initialize the distributed environment."""
  260. parallel_config = self.parallel_config
  261. rank = self.rank
  262. distributed_init_method = self.distributed_init_method
  263. init_distributed_environment(
  264. world_size=parallel_config.world_size,
  265. rank=rank,
  266. distributed_init_method=distributed_init_method,
  267. backend="gloo",
  268. )
  269. # A small all_reduce for warmup.
  270. torch.distributed.all_reduce(torch.zeros(1).cpu())
  271. ensure_model_parallel_initialized(
  272. parallel_config.tensor_parallel_size,
  273. parallel_config.pipeline_parallel_size)
  274. def get_cache_block_size_bytes(self) -> int:
  275. """Return the size in bytes of a single KV cache block.
  276. """
  277. return CPUCacheEngine.get_cache_block_size(
  278. self.cache_config.block_size, self.cache_config.cache_dtype,
  279. self.model_config, self.parallel_config)