cpu_executor.py

import os
from typing import Dict, List, Set, Tuple

import torch
from loguru import logger

from aphrodite.common.config import CacheConfig, ModelConfig, SchedulerConfig
from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.lora.request import LoRARequest


class CPUExecutor(ExecutorBase):

    def _init_executor(self) -> None:
        assert self.device_config.device_type == "cpu"
        assert self.lora_config is None, "cpu backend doesn't support LoRA"
        self.model_config = _verify_and_get_model_config(self.model_config)
        self.cache_config = _verify_and_get_cache_config(self.cache_config)
        self.scheduler_config = _verify_and_get_scheduler_config(
            self.scheduler_config)

        # Instantiate the worker and load the model to CPU.
        self._init_worker()

    def _init_worker(self):
        from aphrodite.task_handler.cpu_worker import CPUWorker

        assert self.parallel_config.world_size == 1, (
            "CPUExecutor only supports single CPU socket currently.")

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        self.driver_worker = CPUWorker(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=True,
        )
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: We log here to avoid multiple logs when the number of
        # workers is greater than one. We could log in the engine, but
        # not all executors have GPUs.
        # NOTE: A `cpu block` in the CPU backend is located in CPU memory
        # but is referred to as a `gpu block`, so that the existing block
        # management procedure can be reused.
        logger.info(f"# CPU blocks: {num_gpu_blocks}")
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]],
                      num_lookahead_slots: int) -> List[SamplerOutput]:
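        # `blocks_to_swap_in` / `blocks_to_swap_out` map source block
        # numbers to destination block numbers; `blocks_to_copy` maps a
        # source block to the blocks it is copied into (copy-on-write).
        # For this backend, all of these blocks reside in host memory.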
        output = self.driver_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=num_lookahead_slots,
        )
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.driver_worker.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.driver_worker.list_loras()

    def check_health(self) -> None:
        # CPUExecutor will always be healthy as long as it's running.
        return


class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int,
    ) -> List[SamplerOutput]:
        # `make_async` runs the blocking driver call in a background thread
        # so the event loop is not blocked while the model executes.
        output = await make_async(self.driver_worker.execute_model)(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=num_lookahead_slots)
        return output

    async def check_health_async(self) -> None:
        # CPUExecutor will always be healthy as long as it's running.
        return


def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    if config.dtype == torch.float16:
        logger.warning("float16 is not supported on CPU, casting to bfloat16.")
        config.dtype = torch.bfloat16
    if not config.enforce_eager:
        logger.warning(
            "CUDA graph is not supported on CPU, falling back to eager mode.")
        config.enforce_eager = True
    return config


def _verify_and_get_scheduler_config(
        config: SchedulerConfig) -> SchedulerConfig:
    if config.chunked_prefill_enabled:
        logger.warning("Chunked prefill is not supported on CPU, disabling it.")
        config.chunked_prefill_enabled = False
    return config


def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
    _GB = 1 << 30
    if config.enable_prefix_caching:
        logger.warning("Prefix caching is not supported on CPU, disabling it.")
        config.enable_prefix_caching = False
    kv_cache_space_str = os.getenv("APHRODITE_CPU_KVCACHE_SPACE", "0")
    kv_cache_space = int(kv_cache_space_str)
    if kv_cache_space >= 0:
        if kv_cache_space == 0:
            config.cpu_kvcache_space_bytes = 4 * _GB  # type: ignore
            logger.warning(
                "Environment variable APHRODITE_CPU_KVCACHE_SPACE (GB) "
                "for the CPU backend is not set, using 4 GB by default.")
        else:
            config.cpu_kvcache_space_bytes = kv_cache_space * _GB  # type: ignore
    else:
        raise RuntimeError(
            "Invalid environment variable APHRODITE_CPU_KVCACHE_SPACE"
            f" {kv_cache_space}, expected a positive integer value.")
    return config
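

# Illustrative usage (a sketch, not part of the original module): how the
# cache-config helper interprets APHRODITE_CPU_KVCACHE_SPACE. `cache_config`
# below stands for an already-constructed CacheConfig instance.
#
#   os.environ["APHRODITE_CPU_KVCACHE_SPACE"] = "8"
#   cache_config = _verify_and_get_cache_config(cache_config)
#   assert cache_config.cpu_kvcache_space_bytes == 8 * (1 << 30)
#
#   # Unset (or "0") falls back to the 4 GB default with a warning:
#   os.environ["APHRODITE_CPU_KVCACHE_SPACE"] = "0"
#   cache_config = _verify_and_get_cache_config(cache_config)
#   assert cache_config.cpu_kvcache_space_bytes == 4 * (1 << 30)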