cpu_executor.py

import os
from typing import List, Set, Tuple

import torch
from loguru import logger

from aphrodite.common.config import CacheConfig, ModelConfig, SchedulerConfig
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest

class CPUExecutor(ExecutorBase):

    uses_ray: bool = False

    def _init_executor(self) -> None:
        assert self.device_config.device_type == "cpu"
        assert self.lora_config is None, "cpu backend doesn't support LoRA"
        self.model_config = _verify_and_get_model_config(self.model_config)
        self.cache_config = _verify_and_get_cache_config(self.cache_config)
        self.scheduler_config = _verify_and_get_scheduler_config(
            self.scheduler_config)

        # Instantiate the worker and load the model to CPU.
        self._init_worker()
    def _init_worker(self):
        from aphrodite.task_handler.cpu_worker import CPUWorker

        assert self.parallel_config.world_size == 1, (
            "CPUExecutor only supports single CPU socket currently.")

        # Rendezvous address for torch.distributed initialization; with a
        # single local worker this never leaves the host.
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        self.driver_worker = CPUWorker(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            multimodal_config=self.multimodal_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            prompt_adapter_config=self.prompt_adapter_config,
            is_driver_worker=True,
        )
        self.driver_worker.init_device()
        self.driver_worker.load_model()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_worker.determine_num_available_blocks()
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
        # NOTE: A `cpu block` for the CPU backend is located in CPU memory but
        # is referred to as a `gpu block`, so that the existing block
        # management procedure can be reused.
        logger.info(f"# CPU blocks: {num_gpu_blocks}")
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        output = self.driver_worker.execute_model(execute_model_req)
        return output
    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.driver_worker.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.driver_worker.list_loras()

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return self.driver_worker.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        return self.driver_worker.list_prompt_adapters()

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.driver_worker.pin_lora(lora_id)

    def check_health(self) -> None:
        # CPUExecutor will always be healthy as long as it's running.
        return
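
# A minimal usage sketch (hypothetical driver code; in practice the engine
# constructs the executor and owns this sequence). `executor` stands in for a
# fully constructed CPUExecutor and `execute_model_req` for an
# ExecuteModelRequest prepared by the scheduler:
#
#   num_gpu_blocks, num_cpu_blocks = executor.determine_num_available_blocks()
#   executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
#   outputs = executor.execute_model(execute_model_req)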

class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        output = await make_async(self.driver_worker.execute_model)(
            execute_model_req=execute_model_req)
        return output

    async def check_health_async(self) -> None:
        # CPUExecutor will always be healthy as long as it's running.
        return
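
# `make_async` (imported above) wraps the blocking `execute_model` call so it
# can be awaited without stalling the event loop. A rough sketch of the idea,
# assuming the usual run-in-executor pattern (not necessarily the exact
# implementation in aphrodite.common.utils):
#
#   def make_async(func):
#       async def _wrapper(*args, **kwargs):
#           loop = asyncio.get_event_loop()
#           return await loop.run_in_executor(
#               None, functools.partial(func, *args, **kwargs))
#       return _wrapper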

def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    if config.dtype == torch.float16:
        logger.warning("float16 is not supported on CPU, casting to bfloat16.")
        config.dtype = torch.bfloat16
    if not config.enforce_eager:
        logger.warning(
            "CUDA graph is not supported on CPU, falling back to eager mode.")
        config.enforce_eager = True
    return config

def _verify_and_get_scheduler_config(
        config: SchedulerConfig) -> SchedulerConfig:
    if config.chunked_prefill_enabled:
        logger.warning("Chunked prefill is not supported on CPU, disabling it.")
        config.chunked_prefill_enabled = False
    return config

def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
    _GB = 1 << 30
    if config.enable_prefix_caching:
        logger.warning("Prefix caching is not supported on CPU, disabling it.")
        config.enable_prefix_caching = False

    kv_cache_space_str = os.getenv("APHRODITE_CPU_KVCACHE_SPACE", "0")
    kv_cache_space = int(kv_cache_space_str)

    if kv_cache_space >= 0:
        if kv_cache_space == 0:
            config.cpu_kvcache_space_bytes = 4 * _GB  # type: ignore
            logger.warning(
                "Environment variable APHRODITE_CPU_KVCACHE_SPACE (GiB) "
                "for the CPU backend is not set, using 4 by default.")
        else:
            config.cpu_kvcache_space_bytes = kv_cache_space * _GB  # type: ignore
    else:
        raise RuntimeError(
            "Invalid environment variable APHRODITE_CPU_KVCACHE_SPACE"
            f" {kv_cache_space}, expected a non-negative integer value.")
    return config
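
if __name__ == "__main__":
    # Hedged smoke test of the env-var convention used above, nothing more:
    # APHRODITE_CPU_KVCACHE_SPACE is read in GiB and converted to bytes.
    # The "8" below is an arbitrary example value, not a recommendation.
    os.environ["APHRODITE_CPU_KVCACHE_SPACE"] = "8"
    space_gib = int(os.getenv("APHRODITE_CPU_KVCACHE_SPACE", "0"))
    print(f"CPU KV cache budget: {space_gib * (1 << 30)} bytes")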