cpu_worker.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
"""A CPU worker class.

Defines ``CPUCacheEngine``, which allocates and manages the per-layer KV
cache tensors on the CPU, and ``CPUWorker``, which loads and executes (a
partition of) the model on the CPU and owns one cache engine per pipeline
virtual engine.
"""
  2. from typing import Dict, List, Optional, Tuple
  3. import torch
  4. import torch.distributed
  5. from loguru import logger
  6. import aphrodite.common.envs as envs
  7. from aphrodite.attention import get_attn_backend
  8. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  9. LoRAConfig, ModelConfig, ParallelConfig,
  10. PromptAdapterConfig, SchedulerConfig)
  11. from aphrodite.common.sequence import ExecuteModelRequest
  12. from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
  13. from aphrodite.distributed import (ensure_model_parallel_initialized,
  14. init_distributed_environment)
  15. from aphrodite.modeling import set_random_seed
  16. from aphrodite.worker.cpu_model_runner import CPUModelRunner
  17. from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
  18. LoraNotSupportedWorkerBase,
  19. WorkerInput)
# OpenMP thread-binding setting, captured once when this module is imported.
# CPUWorker treats it as either the literal "all" (no explicit binding) or a
# "|"-separated list of CPU-id specs indexed by the worker's rank.
APHRODITE_CPU_OMP_THREADS_BIND = envs.APHRODITE_CPU_OMP_THREADS_BIND
  21. class CPUCacheEngine:
  22. """Manages the KV cache for CPU backend.
  23. This class is responsible for initializing and managing CPU KV
  24. caches. It also provides methods for performing KV cache operations, such
  25. as copying.
  26. """
  27. def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
  28. parallel_config: ParallelConfig,
  29. device_config: DeviceConfig) -> None:
  30. assert device_config.device_type == "cpu"
  31. self.cache_config = cache_config
  32. self.model_config = model_config
  33. self.parallel_config = parallel_config
  34. self.head_size = model_config.get_head_size()
  35. self.num_layers = model_config.get_num_layers(parallel_config)
  36. self.num_heads = model_config.get_num_kv_heads(parallel_config)
  37. self.block_size = cache_config.block_size
  38. # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks
  39. # for CPU backend, because we want to reuse KV cache management
  40. # in the scheduler.
  41. self.num_cpu_blocks = cache_config.num_gpu_blocks
  42. if cache_config.cache_dtype == "auto":
  43. self.dtype = model_config.dtype
  44. else:
  45. self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
  46. # Get attention backend.
  47. self.attn_backend = get_attn_backend(
  48. self.model_config.get_head_size(),
  49. self.model_config.get_sliding_window(),
  50. self.model_config.dtype,
  51. cache_config.cache_dtype,
  52. self.block_size,
  53. self.model_config.is_attention_free(),
  54. )
  55. # Initialize the cache.
  56. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
  57. def _allocate_kv_cache(
  58. self,
  59. num_blocks: int,
  60. ) -> List[torch.Tensor]:
  61. """Allocates KV cache on CPU."""
  62. kv_cache_shape = self.attn_backend.get_kv_cache_shape(
  63. num_blocks, self.block_size, self.num_heads, self.head_size)
  64. kv_cache: List[torch.Tensor] = []
  65. for _ in range(self.num_layers):
  66. kv_cache.append(
  67. torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
  68. return kv_cache
  69. def swap_in(self, src_to_dst: Dict[int, int]) -> None:
  70. raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
  71. def swap_out(self, src_to_dst: Dict[int, int]) -> None:
  72. raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
  73. def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
  74. self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
  75. @staticmethod
  76. def get_cache_block_size(
  77. block_size: int,
  78. cache_dtype: str,
  79. model_config: ModelConfig,
  80. parallel_config: ParallelConfig,
  81. ) -> int:
  82. head_size = model_config.get_head_size()
  83. num_heads = model_config.get_num_kv_heads(parallel_config)
  84. num_layers = model_config.get_num_layers(parallel_config)
  85. key_cache_block = block_size * num_heads * head_size
  86. value_cache_block = key_cache_block
  87. total = num_layers * (key_cache_block + value_cache_block)
  88. if cache_dtype == "auto":
  89. dtype = model_config.dtype
  90. else:
  91. dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
  92. dtype_size = torch.tensor([], dtype=dtype).element_size()
  93. return dtype_size * total
  94. class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
  95. """A worker class that executes (a partition of) the model on a CPU socket.
  96. Each worker is associated with a single CPU socket. The worker is
  97. responsible for maintaining the KV cache and executing the model on the
  98. CPU. In case of distributed inference, each worker is assigned a partition
  99. of the model.
  100. """
  101. def __init__(
  102. self,
  103. model_config: ModelConfig,
  104. parallel_config: ParallelConfig,
  105. scheduler_config: SchedulerConfig,
  106. device_config: DeviceConfig,
  107. cache_config: CacheConfig,
  108. load_config: LoadConfig,
  109. local_rank: int,
  110. rank: int,
  111. distributed_init_method: str,
  112. lora_config: Optional[LoRAConfig] = None,
  113. kv_cache_dtype: Optional[str] = "auto",
  114. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  115. is_driver_worker: bool = False,
  116. ) -> None:
  117. self.model_config = model_config
  118. self.parallel_config = parallel_config
  119. self.scheduler_config = scheduler_config
  120. self.device_config = device_config
  121. self.cache_config = cache_config
  122. self.load_config = load_config
  123. self.local_rank = local_rank
  124. self.rank = rank
  125. self.distributed_init_method = distributed_init_method
  126. self.lora_config = lora_config
  127. self.prompt_adapter_config = prompt_adapter_config
  128. self.is_driver_worker = is_driver_worker
  129. if self.is_driver_worker:
  130. assert self.rank == 0, "The driver worker must have rank 0."
  131. if self.model_config.trust_remote_code:
  132. # note: lazy import to avoid importing torch before initializing
  133. from aphrodite.common.utils import init_cached_hf_modules
  134. init_cached_hf_modules()
  135. # Setup OpenMP threads affinity.
  136. omp_cpuids = APHRODITE_CPU_OMP_THREADS_BIND
  137. if omp_cpuids == "all":
  138. self.local_omp_cpuid = "all"
  139. else:
  140. self.local_omp_cpuid = omp_cpuids.split("|")[rank]
  141. self.model_runner: CPUModelRunner = CPUModelRunner(
  142. model_config,
  143. parallel_config,
  144. scheduler_config,
  145. device_config,
  146. cache_config,
  147. load_config=self.load_config,
  148. lora_config=self.lora_config,
  149. kv_cache_dtype=kv_cache_dtype,
  150. prompt_adapter_config=self.prompt_adapter_config,
  151. is_driver_worker=is_driver_worker)
  152. # Uninitialized cache engine. Will be initialized by
  153. # initialize_cache.
  154. self.cache_engine: List[CPUCacheEngine]
  155. self.cpu_cache: List[List[torch.Tensor]]
  156. def init_device(self) -> None:
  157. if self.local_omp_cpuid != "all":
  158. ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
  159. logger.info(ret)
  160. self.init_distributed_environment()
  161. # Set random seed.
  162. set_random_seed(self.model_config.seed)
  163. def load_model(self):
  164. self.model_runner.load_model()
  165. def determine_num_available_blocks(self) -> Tuple[int, int]:
  166. """Determine the number of blocks available for the KV cache.
  167. This determines how many KV blocks can fit into the configured CPU
  168. KV cache space.
  169. Note that since Aphrodite assumes a block resides on GPU if it can be
  170. modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
  171. This allows us to reuse the scheduler of Aphrodite without generalizing
  172. it to different devices.
  173. """
  174. # For CPU device, the block number will be calculated based on the
  175. # cpu_kvcache_space.
  176. cache_block_size = self.get_cache_block_size_bytes()
  177. num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
  178. cache_block_size)
  179. num_cpu_blocks = max(num_cpu_blocks, 0)
  180. # NOTE: To reuse the cache management procedure,
  181. # use cpu cache as 'gpu cache'.
  182. num_gpu_blocks = num_cpu_blocks
  183. num_cpu_blocks = 0
  184. return num_gpu_blocks, num_cpu_blocks
  185. def initialize_cache(self, num_gpu_blocks: int,
  186. num_cpu_blocks: int) -> None:
  187. """Initialize the KV cache. Currently, swappable CPU memory is not
  188. supported.
  189. Since this worker does not support GPUs, we use the num_gpu_blocks to
  190. determine how many non-swappable CPU blocks to allocate.
  191. """
  192. assert (num_cpu_blocks == 0
  193. ), f"{type(self)} does not support swappable cache"
  194. # NOTE: To reuse the cache management procedure,
  195. # use cpu cache as 'gpu cache'.
  196. num_cpu_blocks = num_gpu_blocks
  197. self._validate_num_cpu_blocks(num_cpu_blocks)
  198. self.cache_config.num_gpu_blocks = num_cpu_blocks
  199. self.cache_config.num_cpu_blocks = 0
  200. # Initialize the cache.
  201. self._init_cache_engine()
  202. def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
  203. """Raise errors if the num_cpu_blocks is invalid.
  204. """
  205. if num_cpu_blocks <= 0:
  206. raise ValueError(
  207. "No available memory for the cache blocks. "
  208. "Try increasing `APHRODITE_CPU_KVCACHE_SPACE` when "
  209. "initializing the engine.")
  210. max_seq_len = self.cache_config.block_size * num_cpu_blocks
  211. if self.model_config.max_model_len > max_seq_len:
  212. raise ValueError(
  213. f"The model's max seq len ({self.model_config.max_model_len}) "
  214. "is larger than the maximum number of tokens that can be "
  215. f"stored in KV cache ({max_seq_len}). Try increasing "
  216. "`APHRODITE_CPU_KVCACHE_SPACE` or decreasing `max_model_len` "
  217. "when initializing the engine.")
  218. def _init_cache_engine(self) -> None:
  219. self.cache_engine = [
  220. CPUCacheEngine(self.cache_config, self.model_config,
  221. self.parallel_config, self.device_config)
  222. for _ in range(self.parallel_config.pipeline_parallel_size)
  223. ]
  224. self.cpu_cache = [
  225. self.cache_engine[ve].cpu_cache
  226. for ve in range(self.parallel_config.pipeline_parallel_size)
  227. ]
  228. self.model_runner.block_size = self.cache_engine[0].block_size
  229. assert all(
  230. self.cpu_cache[ve] is not None
  231. for ve in range(self.parallel_config.pipeline_parallel_size))
  232. # Populate the cache to warmup the memory
  233. for ve in range(self.parallel_config.pipeline_parallel_size):
  234. for layer_cache in self.cpu_cache[ve]:
  235. layer_cache.fill_(0)
  236. @property
  237. def do_metadata_broadcast(self) -> bool:
  238. return self.parallel_config.tensor_parallel_size > 1
  239. @property
  240. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  241. return self.cpu_cache
  242. def execute_worker(
  243. self,
  244. worker_input: WorkerInput,
  245. ) -> None:
  246. if (worker_input.blocks_to_copy is not None
  247. and worker_input.blocks_to_copy.numel() > 0):
  248. self.cache_engine[worker_input.virtual_engine].copy(
  249. worker_input.blocks_to_copy)
  250. @torch.inference_mode()
  251. def prepare_worker_input(
  252. self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
  253. assert execute_model_req is not None
  254. virtual_engine = execute_model_req.virtual_engine
  255. num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
  256. blocks_to_copy = execute_model_req.blocks_to_copy
  257. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
  258. device="cpu",
  259. dtype=torch.int64).view(-1, 2)
  260. assert len(execute_model_req.blocks_to_swap_in) == 0
  261. assert len(execute_model_req.blocks_to_swap_out) == 0
  262. return WorkerInput(
  263. num_seq_groups=num_seq_groups,
  264. blocks_to_copy=blocks_to_copy,
  265. virtual_engine=virtual_engine,
  266. )
  267. def init_distributed_environment(self) -> None:
  268. """Initialize the distributed environment."""
  269. parallel_config = self.parallel_config
  270. rank = self.rank
  271. distributed_init_method = self.distributed_init_method
  272. init_distributed_environment(
  273. world_size=parallel_config.world_size,
  274. rank=rank,
  275. distributed_init_method=distributed_init_method,
  276. backend="gloo",
  277. )
  278. # A small all_reduce for warmup.
  279. torch.distributed.all_reduce(torch.zeros(1).cpu())
  280. ensure_model_parallel_initialized(
  281. parallel_config.tensor_parallel_size,
  282. parallel_config.pipeline_parallel_size)
  283. def get_cache_block_size_bytes(self) -> int:
  284. """Return the size in bytes of a single KV cache block.
  285. """
  286. return CPUCacheEngine.get_cache_block_size(
  287. self.cache_config.block_size, self.cache_config.cache_dtype,
  288. self.model_config, self.parallel_config)