cpu_worker.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. """A CPU worker class."""
  2. from typing import Dict, List, Optional, Tuple
  3. import torch
  4. import torch.distributed
  5. from aphrodite.attention import get_attn_backend
  6. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  7. LoRAConfig, ModelConfig, ParallelConfig,
  8. SchedulerConfig, VisionLanguageConfig)
  9. from aphrodite.common.sequence import ExecuteModelRequest
  10. from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
  11. from aphrodite.distributed import (ensure_model_parallel_initialized,
  12. init_distributed_environment)
  13. from aphrodite.modeling import set_random_seed
  14. from aphrodite.task_handler.cpu_model_runner import CPUModelRunner
  15. from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
  16. LoraNotSupportedWorkerBase,
  17. WorkerInput)
  18. class CPUCacheEngine:
  19. """Manages the KV cache for CPU backend.
  20. This class is responsible for initializing and managing CPU KV
  21. caches. It also provides methods for performing KV cache operations, such
  22. as copying.
  23. """
  24. def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
  25. parallel_config: ParallelConfig,
  26. device_config: DeviceConfig) -> None:
  27. assert device_config.device_type == "cpu"
  28. self.cache_config = cache_config
  29. self.model_config = model_config
  30. self.parallel_config = parallel_config
  31. self.head_size = model_config.get_head_size()
  32. self.num_layers = model_config.get_num_layers(parallel_config)
  33. self.num_heads = model_config.get_num_kv_heads(parallel_config)
  34. self.block_size = cache_config.block_size
  35. # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks
  36. # for CPU backend, because we want to reuse KV cache management
  37. # in the scheduler.
  38. self.num_cpu_blocks = cache_config.num_gpu_blocks
  39. if cache_config.cache_dtype == "auto":
  40. self.dtype = model_config.dtype
  41. else:
  42. self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
  43. # Get attention backend.
  44. self.attn_backend = get_attn_backend(
  45. self.model_config.get_num_attention_heads(self.parallel_config),
  46. self.model_config.get_head_size(),
  47. self.model_config.get_num_kv_heads(self.parallel_config),
  48. self.model_config.get_sliding_window(),
  49. self.model_config.dtype,
  50. cache_config.cache_dtype,
  51. self.block_size,
  52. )
  53. # Initialize the cache.
  54. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
  55. def _allocate_kv_cache(
  56. self,
  57. num_blocks: int,
  58. ) -> List[torch.Tensor]:
  59. """Allocates KV cache on CPU."""
  60. kv_cache_shape = self.attn_backend.get_kv_cache_shape(
  61. num_blocks, self.block_size, self.num_heads, self.head_size)
  62. kv_cache: List[torch.Tensor] = []
  63. for _ in range(self.num_layers):
  64. kv_cache.append(
  65. torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
  66. return kv_cache
  67. def swap_in(self, src_to_dst: Dict[int, int]) -> None:
  68. raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
  69. def swap_out(self, src_to_dst: Dict[int, int]) -> None:
  70. raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
  71. def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
  72. self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
  73. @staticmethod
  74. def get_cache_block_size(
  75. block_size: int,
  76. cache_dtype: str,
  77. model_config: ModelConfig,
  78. parallel_config: ParallelConfig,
  79. ) -> int:
  80. head_size = model_config.get_head_size()
  81. num_heads = model_config.get_num_kv_heads(parallel_config)
  82. num_layers = model_config.get_num_layers(parallel_config)
  83. key_cache_block = block_size * num_heads * head_size
  84. value_cache_block = key_cache_block
  85. total = num_layers * (key_cache_block + value_cache_block)
  86. if cache_dtype == "auto":
  87. dtype = model_config.dtype
  88. else:
  89. dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
  90. dtype_size = torch.tensor([], dtype=dtype).element_size()
  91. return dtype_size * total
  92. class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
  93. """A worker class that executes (a partition of) the model on a CPU socket.
  94. Each worker is associated with a single CPU socket. The worker is
  95. responsible for maintaining the KV cache and executing the model on the
  96. CPU. In case of distributed inference, each worker is assigned a partition
  97. of the model.
  98. """
  99. def __init__(
  100. self,
  101. model_config: ModelConfig,
  102. parallel_config: ParallelConfig,
  103. scheduler_config: SchedulerConfig,
  104. device_config: DeviceConfig,
  105. cache_config: CacheConfig,
  106. load_config: LoadConfig,
  107. local_rank: int,
  108. rank: int,
  109. distributed_init_method: str,
  110. lora_config: Optional[LoRAConfig] = None,
  111. vision_language_config: Optional[VisionLanguageConfig] = None,
  112. kv_cache_dtype: Optional[str] = "auto",
  113. is_driver_worker: bool = False,
  114. ) -> None:
  115. self.model_config = model_config
  116. self.parallel_config = parallel_config
  117. self.scheduler_config = scheduler_config
  118. self.device_config = device_config
  119. self.cache_config = cache_config
  120. self.load_config = load_config
  121. self.local_rank = local_rank
  122. self.rank = rank
  123. self.distributed_init_method = distributed_init_method
  124. self.lora_config = lora_config
  125. self.vision_language_config = vision_language_config
  126. self.is_driver_worker = is_driver_worker
  127. if self.is_driver_worker:
  128. assert self.rank == 0, "The driver worker must have rank 0."
  129. if self.model_config.trust_remote_code:
  130. # note: lazy import to avoid importing torch before initializing
  131. from aphrodite.common.utils import init_cached_hf_modules
  132. init_cached_hf_modules()
  133. self.model_runner: CPUModelRunner = CPUModelRunner(
  134. model_config,
  135. parallel_config,
  136. scheduler_config,
  137. device_config,
  138. cache_config,
  139. load_config=self.load_config,
  140. lora_config=self.lora_config,
  141. vision_language_config=self.vision_language_config,
  142. kv_cache_dtype=kv_cache_dtype,
  143. is_driver_worker=is_driver_worker)
  144. # Uninitialized cache engine. Will be initialized by
  145. # initialize_cache.
  146. self.cache_engine: List[CPUCacheEngine]
  147. self.cpu_cache: List[List[torch.Tensor]]
  148. def init_device(self) -> None:
  149. self.init_distributed_environment()
  150. # Set random seed.
  151. set_random_seed(self.model_config.seed)
  152. def load_model(self):
  153. self.model_runner.load_model()
  154. def determine_num_available_blocks(self) -> Tuple[int, int]:
  155. """Determine the number of blocks available for the KV cache.
  156. This determines how many KV blocks can fit into the configured CPU
  157. KV cache space.
  158. Note that since Aphrodite assumes a block resides on GPU if it can be
  159. modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
  160. This allows us to reuse the scheduler of Aphrodite without generalizing
  161. it to different devices.
  162. """
  163. # For CPU device, the block number will be calculated based on the
  164. # cpu_kvcache_space.
  165. cache_block_size = self.get_cache_block_size_bytes()
  166. num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
  167. cache_block_size)
  168. num_cpu_blocks = max(num_cpu_blocks, 0)
  169. # NOTE: To reuse the cache management procedure,
  170. # use cpu cache as 'gpu cache'.
  171. num_gpu_blocks = num_cpu_blocks
  172. num_cpu_blocks = 0
  173. return num_gpu_blocks, num_cpu_blocks
  174. def initialize_cache(self, num_gpu_blocks: int,
  175. num_cpu_blocks: int) -> None:
  176. """Initialize the KV cache. Currently, swappable CPU memory is not
  177. supported.
  178. Since this worker does not support GPUs, we use the num_gpu_blocks to
  179. determine how many non-swappable CPU blocks to allocate.
  180. """
  181. assert (num_cpu_blocks == 0
  182. ), f"{type(self)} does not support swappable cache"
  183. # NOTE: To reuse the cache management procedure,
  184. # use cpu cache as 'gpu cache'.
  185. num_cpu_blocks = num_gpu_blocks
  186. self._validate_num_cpu_blocks(num_cpu_blocks)
  187. self.cache_config.num_gpu_blocks = num_cpu_blocks
  188. self.cache_config.num_cpu_blocks = 0
  189. # Initialize the cache.
  190. self._init_cache_engine()
  191. def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
  192. """Raise errors if the num_cpu_blocks is invalid.
  193. """
  194. if num_cpu_blocks <= 0:
  195. raise ValueError(
  196. "No available memory for the cache blocks. "
  197. "Try increasing `APHRODITE_CPU_KVCACHE_SPACE` when "
  198. "initializing the engine.")
  199. max_seq_len = self.cache_config.block_size * num_cpu_blocks
  200. if self.model_config.max_model_len > max_seq_len:
  201. raise ValueError(
  202. f"The model's max seq len ({self.model_config.max_model_len}) "
  203. "is larger than the maximum number of tokens that can be "
  204. f"stored in KV cache ({max_seq_len}). Try increasing "
  205. "`APHRODITE_CPU_KVCACHE_SPACE` or decreasing `max_model_len` "
  206. "when initializing the engine.")
  207. def _init_cache_engine(self) -> None:
  208. self.cache_engine = [
  209. CPUCacheEngine(self.cache_config, self.model_config,
  210. self.parallel_config, self.device_config)
  211. for _ in range(self.parallel_config.pipeline_parallel_size)
  212. ]
  213. self.cpu_cache = [
  214. self.cache_engine[ve].cpu_cache
  215. for ve in range(self.parallel_config.pipeline_parallel_size)
  216. ]
  217. self.model_runner.block_size = self.cache_engine[0].block_size
  218. assert all(
  219. self.cpu_cache[ve] is not None
  220. for ve in range(self.parallel_config.pipeline_parallel_size))
  221. # Populate the cache to warmup the memory
  222. for ve in range(self.parallel_config.pipeline_parallel_size):
  223. for layer_cache in self.cpu_cache[ve]:
  224. layer_cache.fill_(0)
  225. @property
  226. def do_metadata_broadcast(self) -> bool:
  227. return self.parallel_config.tensor_parallel_size > 1
  228. @property
  229. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  230. return self.cpu_cache
  231. def execute_worker(
  232. self,
  233. worker_input: WorkerInput,
  234. ) -> None:
  235. if (worker_input.blocks_to_copy is not None
  236. and worker_input.blocks_to_copy.numel() > 0):
  237. self.cache_engine[worker_input.virtual_engine].copy(
  238. worker_input.blocks_to_copy)
  239. @torch.inference_mode()
  240. def prepare_worker_input(
  241. self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
  242. assert execute_model_req is not None
  243. virtual_engine = execute_model_req.virtual_engine
  244. num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
  245. blocks_to_copy = execute_model_req.blocks_to_copy
  246. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
  247. device="cpu",
  248. dtype=torch.int64).view(-1, 2)
  249. assert len(execute_model_req.blocks_to_swap_in) == 0
  250. assert len(execute_model_req.blocks_to_swap_out) == 0
  251. return WorkerInput(
  252. num_seq_groups=num_seq_groups,
  253. blocks_to_copy=blocks_to_copy,
  254. virtual_engine=virtual_engine,
  255. )
  256. def init_distributed_environment(self) -> None:
  257. """Initialize the distributed environment."""
  258. parallel_config = self.parallel_config
  259. rank = self.rank
  260. distributed_init_method = self.distributed_init_method
  261. init_distributed_environment(
  262. world_size=parallel_config.world_size,
  263. rank=rank,
  264. distributed_init_method=distributed_init_method,
  265. backend="gloo",
  266. )
  267. # A small all_reduce for warmup.
  268. torch.distributed.all_reduce(torch.zeros(1).cpu())
  269. ensure_model_parallel_initialized(
  270. parallel_config.tensor_parallel_size,
  271. parallel_config.pipeline_parallel_size)
  272. def get_cache_block_size_bytes(self) -> int:
  273. """Return the size in bytes of a single KV cache block.
  274. """
  275. return CPUCacheEngine.get_cache_block_size(
  276. self.cache_config.block_size, self.cache_config.cache_dtype,
  277. self.model_config, self.parallel_config)