  1. """A CPU worker class."""
  2. import os
  3. from typing import Dict, List, Optional, Tuple
  4. import torch
  5. import torch.distributed
  6. from aphrodite.attention import get_attn_backend
  7. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  8. LoRAConfig, ModelConfig, MultiModalConfig,
  9. ParallelConfig, PromptAdapterConfig,
  10. SchedulerConfig)
  11. from aphrodite.common.sequence import ExecuteModelRequest
  12. from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
  13. from aphrodite.distributed import (ensure_model_parallel_initialized,
  14. init_distributed_environment)
  15. from aphrodite.modeling import set_random_seed
  16. from aphrodite.task_handler.cpu_model_runner import CPUModelRunner
  17. from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
  18. LoraNotSupportedWorkerBase,
  19. WorkerInput)
  20. APHRODITE_CPU_OMP_THREADS_BIND = os.getenv("APHRODITE_CPU_OMP_THREADS_BIND",
  21. "all")


class CPUCacheEngine:
    """Manages the KV cache for CPU backend.

    This class is responsible for initializing and managing CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as copying.
    """

    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 device_config: DeviceConfig) -> None:
        assert device_config.device_type == "cpu"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks actually is num_cpu_blocks
        # for the CPU backend, because we want to reuse the KV cache
        # management in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        if cache_config.cache_dtype == "auto":
            self.dtype = model_config.dtype
        else:
            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
        )

        # Initialize the cache.
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on CPU."""
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_heads, self.head_size)
        kv_cache: List[torch.Tensor] = []
        for _ in range(self.num_layers):
            kv_cache.append(
                torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
        return kv_cache

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: str,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        if cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        return dtype_size * total
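    # Illustrative sizing sketch (an addition, not from the original source):
    # for a hypothetical model with 32 layers, 8 KV heads, head size 128,
    # block_size 16 and a 2-byte cache dtype (e.g. fp16), one cache block is
    #   32 * (16 * 8 * 128 + 16 * 8 * 128) * 2 = 2,097,152 bytes (2 MiB),
    # i.e. num_layers * (key_cache_block + value_cache_block) * dtype_size.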


class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a CPU socket.

    Each worker is associated with a single CPU socket. The worker is
    responsible for maintaining the KV cache and executing the model on the
    CPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.prompt_adapter_config = prompt_adapter_config
        self.multimodal_config = multimodal_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from aphrodite.common.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Setup OpenMP threads affinity.
        omp_cpuids = APHRODITE_CPU_OMP_THREADS_BIND
        if omp_cpuids == "all":
            self.local_omp_cpuid = "all"
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[rank]

        self.model_runner: CPUModelRunner = CPUModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            multimodal_config=self.multimodal_config,
            kv_cache_dtype=kv_cache_dtype,
            prompt_adapter_config=self.prompt_adapter_config,
            is_driver_worker=is_driver_worker)
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CPUCacheEngine]
        self.cpu_cache: List[List[torch.Tensor]]

    def init_device(self) -> None:
        if self.local_omp_cpuid != "all":
            torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured CPU
        KV cache space.

        Note that since Aphrodite assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of Aphrodite without generalizing
        it to different devices.
        """
        # For CPU device, the block number will be calculated based on the
        # cpu_kvcache_space.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # NOTE: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks
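    # Illustrative sketch (an addition, not from the original source): with
    # 4 GiB of cpu_kvcache_space_bytes and a 2 MiB cache block (see the
    # sizing example in CPUCacheEngine above), this would report
    #   num_gpu_blocks = 4096 MiB // 2 MiB = 2048,  num_cpu_blocks = 0,
    # i.e. the scheduler is handed 2048 "GPU" blocks that actually live in
    # CPU memory.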
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.

        Since this worker does not support GPUs, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # NOTE: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_cpu_blocks = num_gpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0

        # Initialize the cache.
        self._init_cache_engine()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise errors if the num_cpu_blocks is invalid."""
        if num_cpu_blocks <= 0:
            raise ValueError(
                "No available memory for the cache blocks. "
                "Try increasing `APHRODITE_CPU_KVCACHE_SPACE` when "
                "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`APHRODITE_CPU_KVCACHE_SPACE` or decreasing `max_model_len` "
                "when initializing the engine.")

    def _init_cache_engine(self) -> None:
        self.cache_engine = [
            CPUCacheEngine(self.cache_config, self.model_config,
                           self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.cpu_cache = [
            self.cache_engine[ve].cpu_cache
            for ve in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.model_runner.block_size = self.cache_engine[0].block_size

        assert all(
            self.cpu_cache[ve] is not None
            for ve in range(self.parallel_config.pipeline_parallel_size))

        # Populate the cache to warmup the memory
        for ve in range(self.parallel_config.pipeline_parallel_size):
            for layer_cache in self.cpu_cache[ve]:
                layer_cache.fill_(0)

    @property
    def do_metadata_broadcast(self) -> bool:
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        return self.cpu_cache

    def execute_worker(
        self,
        worker_input: WorkerInput,
    ) -> None:
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[worker_input.virtual_engine].copy(
                worker_input.blocks_to_copy)

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        assert execute_model_req is not None
        virtual_engine = execute_model_req.virtual_engine
        num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device="cpu",
                                      dtype=torch.int64).view(-1, 2)
        assert len(execute_model_req.blocks_to_swap_in) == 0
        assert len(execute_model_req.blocks_to_swap_out) == 0
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )
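    # Illustrative note (an addition, not from the original source): the
    # blocks_to_copy tensor built above has shape (N, 2); each row is assumed
    # to be a (source_block, destination_block) pair, so copy requests such
    # as [(0, 3), (1, 5)] arrive as tensor([[0, 3], [1, 5]]) and are applied
    # by execute_worker() via CPUCacheEngine.copy().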
    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )
        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())
        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block."""
        return CPUCacheEngine.get_cache_block_size(
            self.cache_config.block_size, self.cache_config.cache_dtype,
            self.model_config, self.parallel_config)