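"""Multiprocessing-based executor that runs Aphrodite models on CPU.

All workers live on a single node: the driver worker handles model
execution requests directly, while any additional tensor-parallel
workers run as monitored subprocesses.
"""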
import os
from functools import partial
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union

import torch
from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.config import CacheConfig, ModelConfig, SchedulerConfig
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import (GiB_bytes, get_aphrodite_instance_id,
                                    get_distributed_init_method, get_open_port,
                                    make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                       ResultHandler,
                                                       WorkerMonitor)
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.task_handler.worker_base import WorkerWrapperBase


class CPUExecutor(ExecutorBase):
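    """ExecutorBase implementation for the CPU backend. The driver worker
    and any tensor-parallel workers all run on a single node."""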

    uses_ray: bool = False

    def _init_executor(self) -> None:
        assert self.device_config.device_type == "cpu"
        assert self.lora_config is None, "cpu backend doesn't support LoRA"

        #
        # Environment variables for CPU executor
        #

        # Ensure that APHRODITE_INSTANCE_ID is set, to be inherited by workers
        os.environ["APHRODITE_INSTANCE_ID"] = get_aphrodite_instance_id()

        # Disable torch async compiling, which doesn't work with daemonic
        # processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP settings
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time (in milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ['KMP_BLOCKTIME'] = "1"
            # Prevents the CPU from entering a low-performance state
            os.environ['KMP_TPAUSE'] = "0"
            # Provides fine-granularity parallelism
            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"

        # Hint IPEX to use the shared-memory-based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            self.parallel_config.tensor_parallel_size)

        self.model_config = _verify_and_get_model_config(self.model_config)
        self.cache_config = _verify_and_get_cache_config(self.cache_config)
        self.scheduler_config = _verify_and_get_scheduler_config(
            self.scheduler_config)

        # The multiprocessing-based executor does not support multi-node
        # settings. Since it only works on a single node, we can use the
        # loopback address 127.0.0.1 for communication.
        ip = "127.0.0.1"
        port = get_open_port()
        self.distributed_init_method = get_distributed_init_method(ip, port)

        is_async = isinstance(self, CPUExecutorAsync)

        world_size = self.parallel_config.tensor_parallel_size
        result_handler = ResultHandler()
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        self.workers = []

        if is_async:
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_worker,
                        rank=rank,
                        local_rank=rank,
                    )) for rank in range(0, world_size)
            ]
            self.driver_worker = self.workers[0]
            self.workers = self.workers[1:]
            self.driver_method_invoker = _async_driver_method_invoker
        else:
            self.driver_worker = self._create_worker()
            self.driver_method_invoker = _driver_method_invoker

            if world_size != 1:
                self.workers = [
                    ProcessWorkerWrapper(
                        result_handler,
                        partial(
                            self._create_worker,
                            rank=rank,
                            local_rank=rank,
                        )) for rank in range(1, world_size)
                ]

        if world_size != 1 or is_async:
            if is_async:
                async_worker_list = self.workers + [self.driver_worker]
            else:
                async_worker_list = self.workers
            self.worker_monitor = WorkerMonitor(async_worker_list,
                                                result_handler)
            result_handler.start()
            self.worker_monitor.start()

        self._run_workers("init_device")
        self._run_workers("load_model")

    def _create_worker(
        self,
        local_rank: int = 0,
        rank: int = 0,
    ):
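        """Create a CPUWorker for the given rank.

        The worker class is resolved by name through WorkerWrapperBase, so
        the import happens in the process that actually runs the worker.
        """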
        worker_module_name = "aphrodite.task_handler.cpu_worker"
        worker_class_name = "CPUWorker"
        wrapper = WorkerWrapperBase(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
        )

        assert self.distributed_init_method is not None

        kwargs = dict(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=self.distributed_init_method,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            prompt_adapter_config=self.prompt_adapter_config,
            is_driver_worker=rank == 0,
        )
        wrapper.init_worker(**kwargs)

        return wrapper.worker

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_remote_workers_only: If True, the method will be run
                only on the remote workers, not the driver worker. It will
                also be run asynchronously and return a list of futures
                rather than blocking on the results.
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the workers first.
        worker_outputs = [
            worker.execute_method(method, *args, **kwargs)
            for worker in self.workers
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return worker_outputs

        driver_worker_output = self.driver_method_invoker(
            self.driver_worker, method, *args, **kwargs)

        # Get the results of the workers.
        return [driver_worker_output
                ] + [output.get() for output in worker_outputs]
  159. def determine_num_available_blocks(self) -> Tuple[int, int]:
  160. """Determine the number of available KV blocks by invoking the
  161. underlying worker.
  162. """
  163. return self.driver_method_invoker(self.driver_worker,
  164. "determine_num_available_blocks")

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
        # NOTE: A `cpu block` for the CPU backend is located in CPU memory,
        # but is referred to as a `gpu block` so that the existing block
        # management procedure can be reused.
        logger.info("# CPU blocks: {}", num_gpu_blocks)
        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
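        """Execute the model on the driver worker.

        With tensor parallelism, the remote workers' execution loop is
        started lazily on the first call and keeps running until
        stop_remote_worker_execution_loop() is invoked.
        """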
        if (self.parallel_config.tensor_parallel_size > 1
                and self.parallel_worker_tasks is None):
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
            )
        output = self.driver_method_invoker(self.driver_worker,
                                            "execute_model", execute_model_req)
        return output

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        # Passing None will cause the driver to stop the model execution
        # loop running in each of the remote workers.
        self.driver_method_invoker(self.driver_worker, "execute_model", None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that the workers exit the model loop cleanly
        # (this will raise otherwise).
        self._wait_for_tasks_completion(parallel_worker_tasks)
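
    # NOTE: _init_executor asserts that lora_config is None for the CPU
    # backend; the LoRA methods below still delegate to the workers to
    # satisfy the ExecutorBase interface.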
    def add_lora(self, lora_request: LoRARequest) -> bool:
        return all(self._run_workers("add_lora", lora_request))

    def remove_lora(self, lora_id: int) -> bool:
        return all(self._run_workers("remove_lora", lora_id))

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self._run_workers(
            "pin_lora",
            lora_id=lora_id,
        ))

    def list_loras(self) -> Set[int]:
        return self.driver_method_invoker(self.driver_worker, "list_loras")

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return all(
            self._run_workers(
                "add_prompt_adapter",
                prompt_adapter_request,
            ))

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return all(
            self._run_workers(
                "remove_prompt_adapter",
                prompt_adapter_id,
            ))

    def list_prompt_adapters(self) -> Set[int]:
        return self.driver_method_invoker(self.driver_worker,
                                          "list_prompt_adapters")

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return all(self._run_workers(
            "pin_prompt_adapter",
            prompt_adapter_id,
        ))

    def check_health(self) -> None:
        """Raises an error if the engine is unhealthy."""
        if (self.worker_monitor is not None
                and not self.worker_monitor.is_alive()):
            raise RuntimeError("Worker processes are not running")

    def shutdown(self):
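        """Close the worker monitor, if one was started."""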
        if (worker_monitor := getattr(self, "worker_monitor",
                                      None)) is not None:
            worker_monitor.close()

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        for result in parallel_worker_tasks:
            result.get()


class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
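    """Async variant of CPUExecutor: model execution is wrapped with
    make_async so it can be awaited."""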

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        output = await make_async(self.execute_model)(
            execute_model_req=execute_model_req)
        return output

    async def check_health_async(self) -> None:
        self.check_health()
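

# Helpers that normalize engine configs for the CPU backend, downgrading
# unsupported options with a warning.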
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    if config.dtype == torch.float16:
        logger.warning("float16 is not supported on CPU, casting to bfloat16.")
        config.dtype = torch.bfloat16
    if not config.enforce_eager:
        logger.warning(
            "CUDA graph is not supported on CPU, falling back to eager "
            "mode.")
        config.enforce_eager = True
    return config


def _verify_and_get_scheduler_config(
        config: SchedulerConfig) -> SchedulerConfig:
    if config.chunked_prefill_enabled:
        logger.warning("Chunked prefill is not supported on CPU, disabling "
                       "it.")
        config.chunked_prefill_enabled = False
    return config


def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
    if config.enable_prefix_caching:
        logger.warning("Prefix caching is not supported on CPU, disabling "
                       "it.")
        config.enable_prefix_caching = False

    kv_cache_space_str = envs.APHRODITE_CPU_KVCACHE_SPACE
    kv_cache_space = int(kv_cache_space_str)

    if kv_cache_space >= 0:
        if kv_cache_space == 0:
            config.cpu_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
            logger.warning(
                "Environment variable APHRODITE_CPU_KVCACHE_SPACE (GiB) "
                "for the CPU backend is not set, using 4 GiB by default.")
        else:
            config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes  # type: ignore
    else:
        raise RuntimeError(
            "Invalid environment variable APHRODITE_CPU_KVCACHE_SPACE"
            f" {kv_cache_space}, expect a non-negative integer value.")
    return config
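

# In the sync executor the driver worker runs in-process, so its methods are
# invoked directly; in the async executor the driver is a ProcessWorkerWrapper,
# so calls go through execute_method() and block on the returned future.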
def _driver_method_invoker(driver, method: str, *args, **kwargs):
    return getattr(driver, method)(*args, **kwargs)


def _async_driver_method_invoker(driver, method: str, *args, **kwargs):
    return driver.execute_method(method, *args, **kwargs).get()