cpu_executor.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import os
  2. from typing import Dict, List, Optional
  3. import torch
  4. from loguru import logger
  5. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoRAConfig,
  6. ModelConfig, ParallelConfig,
  7. SchedulerConfig, SpeculativeConfig)
  8. from aphrodite.executor.executor_base import ExecutorBase
  9. from aphrodite.lora.request import LoRARequest
  10. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  11. from aphrodite.common.utils import (
  12. get_distributed_init_method,
  13. get_ip,
  14. get_open_port,
  15. make_async,
  16. )
  17. class CPUExecutor(ExecutorBase):
  18. def __init__(
  19. self,
  20. model_config: ModelConfig,
  21. cache_config: CacheConfig,
  22. parallel_config: ParallelConfig,
  23. scheduler_config: SchedulerConfig,
  24. device_config: DeviceConfig,
  25. lora_config: Optional[LoRAConfig],
  26. speculative_config: Optional[SpeculativeConfig],
  27. *args,
  28. **kwargs,
  29. ) -> None:
  30. assert device_config.device_type == "cpu"
  31. assert lora_config is None, "cpu backend doesn't support LoRA"
  32. assert (not speculative_config
  33. ), "Speculative decoding not yet supported for CPU backend."
  34. model_config = _verify_and_get_model_config(model_config)
  35. cache_config = _verify_and_get_cache_config(cache_config)
  36. scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
  37. self.model_config = model_config
  38. self.cache_config = cache_config
  39. self.lora_config = lora_config
  40. self.parallel_config = parallel_config
  41. self.scheduler_config = scheduler_config
  42. self.device_config = device_config
  43. # Instantiate the worker and load the model to CPU.
  44. self._init_worker()
  45. def _init_worker(self):
  46. from aphrodite.task_handler.cpu_worker import CPUWorker
  47. assert (self.parallel_config.world_size == 1
  48. ), "CPUExecutor only supports single CPU socket currently."
  49. distributed_init_method = get_distributed_init_method(
  50. get_ip(), get_open_port())
  51. self.driver_worker = CPUWorker(
  52. model_config=self.model_config,
  53. parallel_config=self.parallel_config,
  54. scheduler_config=self.scheduler_config,
  55. device_config=self.device_config,
  56. cache_config=self.cache_config,
  57. local_rank=0,
  58. rank=0,
  59. distributed_init_method=distributed_init_method,
  60. lora_config=self.lora_config,
  61. kv_cache_dtype=self.cache_config.cache_dtype,
  62. is_driver_worker=True,
  63. )
  64. self.driver_worker.init_device()
  65. self.driver_worker.load_model()
  66. def determine_num_available_blocks(self) -> tuple[int, int]:
  67. """Determine the number of available KV blocks by invoking the
  68. underlying worker.
  69. """
  70. return self.driver_worker.determine_num_available_blocks()
  71. def initialize_cache(self, num_gpu_blocks: int,
  72. num_cpu_blocks: int) -> None:
  73. """Initialize the KV cache by invoking the underlying worker.
  74. """
  75. # NOTE: We log here to avoid multiple logs when number of workers is
  76. # greater than one. We could log in the engine, but not all executors
  77. # have GPUs.
  78. logger.info(f"# CPU blocks: {num_cpu_blocks}")
  79. self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
  80. def execute_model(
  81. self,
  82. seq_group_metadata_list: List[SequenceGroupMetadata],
  83. blocks_to_swap_in: Dict[int, int],
  84. blocks_to_swap_out: Dict[int, int],
  85. blocks_to_copy: Dict[int, List[int]],
  86. ) -> SamplerOutput:
  87. output = self.driver_worker.execute_model(
  88. seq_group_metadata_list=seq_group_metadata_list,
  89. blocks_to_swap_in=blocks_to_swap_in,
  90. blocks_to_swap_out=blocks_to_swap_out,
  91. blocks_to_copy=blocks_to_copy,
  92. )
  93. return output
  94. async def execute_model_async(
  95. self,
  96. seq_group_metadata_list: List[SequenceGroupMetadata],
  97. blocks_to_swap_in: Dict[int, int],
  98. blocks_to_swap_out: Dict[int, int],
  99. blocks_to_copy: Dict[int, List[int]],
  100. ) -> SamplerOutput:
  101. output = await make_async(self.driver_worker.execute_model)(
  102. seq_group_metadata_list=seq_group_metadata_list,
  103. blocks_to_swap_in=blocks_to_swap_in,
  104. blocks_to_swap_out=blocks_to_swap_out,
  105. blocks_to_copy=blocks_to_copy)
  106. return output
  107. def add_lora(self, lora_request: LoRARequest) -> bool:
  108. return self.driver_worker.add_lora(lora_request)
  109. def remove_lora(self, lora_id: int) -> bool:
  110. return self.driver_worker.remove_lora(lora_id)
  111. def list_loras(self) -> List[int]:
  112. return self.driver_worker.list_loras()
  113. def check_health(self) -> None:
  114. # CPUExecutor will always be healthy as long as
  115. # it's running.
  116. return
  117. def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
  118. if config.dtype == torch.float16:
  119. logger.warning("float16 is not supported on CPU, casting to bfloat16.")
  120. config.dtype = torch.bfloat16
  121. if not config.enforce_eager:
  122. logger.warning(
  123. "CUDA graph is not supported on CPU, fallback to the eager "
  124. "mode.")
  125. config.enforce_eager = True
  126. return config
  127. def _verify_and_get_scheduler_config(
  128. config: SchedulerConfig) -> SchedulerConfig:
  129. if config.chunked_prefill_enabled:
  130. logger.warning("Chunked prefill is not supported on CPU, disable it.")
  131. config.chunked_prefill_enabled = False
  132. return config
  133. def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
  134. _GB = 1 << 30
  135. if config.context_shift:
  136. logger.warning("Prefix caching is not supported on CPU, disable it.")
  137. config.context_shift = False
  138. kv_cache_space_str = os.getenv("APHRODITE_CPU_KVCACHE_SPACE", "0")
  139. kv_cache_space = int(kv_cache_space_str)
  140. if kv_cache_space >= 0:
  141. if kv_cache_space == 0:
  142. config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore
  143. logger.warning(
  144. "Environment variable APHRODITE_CPU_KVCACHE_SPACE (GB) "
  145. "for CPU backend is not set, using 4 by default.")
  146. else:
  147. config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore
  148. else:
  149. raise RuntimeError(
  150. "Invalid environment variable APHRODITE_CPU_KVCACHE_SPACE"
  151. f" {kv_cache_space}, expect a positive integer value.")
  152. return config