# cpu_executor.py
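"""CPU executor for the Aphrodite engine.

Runs the model on CPU workers: a single in-process driver worker on the
synchronous path, plus multiprocessing-based workers for tensor parallelism
and for the asynchronous executor.
"""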

import os
from functools import partial
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union

import torch
from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig)
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (GiB_bytes, get_aphrodite_instance_id,
                                    get_distributed_init_method, get_open_port,
                                    make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                       ResultHandler,
                                                       WorkerMonitor)
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.worker.worker_base import WorkerWrapperBase


class CPUExecutor(ExecutorBase):

    uses_ray: bool = False

    def _init_executor(self) -> None:
        assert self.device_config.device_type == "cpu"
        assert self.lora_config is None, "cpu backend doesn't support LoRA"

        #
        # Environment variables for CPU executor
        #

        # Ensure that APHRODITE_INSTANCE_ID is set, to be inherited by workers
        os.environ["APHRODITE_INSTANCE_ID"] = get_aphrodite_instance_id()

        # Disable torch async compiling, which does not work with daemonic
        # processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP settings
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time (in milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ['KMP_BLOCKTIME'] = "1"
            # Prevents the CPU from entering a low-performance state
            os.environ['KMP_TPAUSE'] = "0"
            # Provides fine-grained parallelism
            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"

        # Hint IPEX to use the shared-memory-based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            self.parallel_config.tensor_parallel_size)

        self.model_config = _verify_and_get_model_config(self.model_config)
        self.cache_config = _verify_and_get_cache_config(self.cache_config)
        self.scheduler_config = _verify_and_get_scheduler_config(
            self.scheduler_config)
        self.parallel_config = _verify_and_get_parallel_config(
            self.parallel_config)

        # The multiprocessing-based executor does not support multi-node
        # settings. Since it only works on a single node, we can use the
        # loopback address 127.0.0.1 for communication.
        ip = "127.0.0.1"
        port = get_open_port()
        self.distributed_init_method = get_distributed_init_method(ip, port)

        is_async = isinstance(self, CPUExecutorAsync)

        world_size = self.parallel_config.tensor_parallel_size
        result_handler = ResultHandler()
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        self.workers = []
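
        # In the async case every rank, including the driver (rank 0), runs
        # in its own worker process; in the sync case the driver runs
        # in-process and only ranks 1..world_size-1 get separate processes.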
        if is_async:
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_worker,
                        rank=rank,
                        local_rank=rank,
                    )) for rank in range(0, world_size)
            ]
            self.driver_worker = self.workers[0]
            self.workers = self.workers[1:]
            self.driver_method_invoker = _async_driver_method_invoker
        else:
            self.driver_worker = self._create_worker()
            self.driver_method_invoker = _driver_method_invoker

            if world_size != 1:
                self.workers = [
                    ProcessWorkerWrapper(
                        result_handler,
                        partial(
                            self._create_worker,
                            rank=rank,
                            local_rank=rank,
                        )) for rank in range(1, world_size)
                ]

        self.worker_monitor = None
        if world_size != 1 or is_async:
            if is_async:
                async_worker_list = self.workers + [self.driver_worker]
            else:
                async_worker_list = self.workers
            self.worker_monitor = WorkerMonitor(async_worker_list,
                                                result_handler)
            result_handler.start()
            self.worker_monitor.start()

        self._run_workers("init_device")
        self._run_workers("load_model")

    def _create_worker(
        self,
        local_rank: int = 0,
        rank: int = 0,
    ):
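        """Create a CPUWorker instance (wrapped in a WorkerWrapperBase) using
        this executor's configs; rank 0 is the driver worker."""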
        worker_module_name = "aphrodite.worker.cpu_worker"
        worker_class_name = "CPUWorker"
        wrapper = WorkerWrapperBase(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
        )

        assert self.distributed_init_method is not None

        kwargs = dict(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=self.distributed_init_method,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            prompt_adapter_config=self.prompt_adapter_config,
            is_driver_worker=rank == 0,
        )
        wrapper.init_worker(**kwargs)

        return wrapper.worker

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_remote_workers_only: If True, the method will be run
                only on the remote workers, not the driver worker. It will
                also be run asynchronously and return a list of futures
                rather than blocking on the results.
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the workers first.
        worker_outputs = [
            worker.execute_method(method, *args, **kwargs)
            for worker in self.workers
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return worker_outputs

        driver_worker_output = self.driver_method_invoker(
            self.driver_worker, method, *args, **kwargs)

        # Get the results of the workers.
        return [driver_worker_output
                ] + [output.get() for output in worker_outputs]

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_method_invoker(self.driver_worker,
                                          "determine_num_available_blocks")

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
        # NOTE: A `cpu block` in the CPU backend lives in CPU memory but is
        # referred to as a `gpu block`, so that the existing block management
        # procedure can be reused.
        logger.info(f"# CPU blocks: {num_gpu_blocks}")
        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
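        # With tensor parallelism (> 1), the remote workers are put into a
        # busy execution loop once, on the first call; afterwards only the
        # driver needs to be invoked directly for each step.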
        if (self.parallel_config.tensor_parallel_size > 1
                and self.parallel_worker_tasks is None):
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
            )
        output = self.driver_method_invoker(self.driver_worker,
                                            "execute_model", execute_model_req)
        return output

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return
        # Passing None will cause the driver to stop the model execution
        # loop running in each of the remote workers.
        self.driver_method_invoker(self.driver_worker, "execute_model", None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit the model loop cleanly
        # (this will raise otherwise).
        self._wait_for_tasks_completion(parallel_worker_tasks)
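
    # The LoRA and prompt-adapter methods below satisfy the ExecutorBase
    # interface; note that _init_executor currently rejects a LoRA config on
    # the CPU backend.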
    def add_lora(self, lora_request: LoRARequest) -> bool:
        return all(self._run_workers("add_lora", lora_request))

    def remove_lora(self, lora_id: int) -> bool:
        return all(self._run_workers("remove_lora", lora_id))

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self._run_workers(
            "pin_lora",
            lora_id=lora_id,
        ))

    def list_loras(self) -> Set[int]:
        return self.driver_method_invoker(self.driver_worker, "list_loras")

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return all(
            self._run_workers(
                "add_prompt_adapter",
                prompt_adapter_request,
            ))

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return all(
            self._run_workers(
                "remove_prompt_adapter",
                prompt_adapter_id,
            ))

    def list_prompt_adapters(self) -> Set[int]:
        return self.driver_method_invoker(self.driver_worker,
                                          "list_prompt_adapters")

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return all(self._run_workers(
            "pin_prompt_adapter",
            prompt_adapter_id,
        ))

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        if (self.worker_monitor is not None
                and not self.worker_monitor.is_alive()):
            raise RuntimeError("Worker processes are not running")

    def shutdown(self):
        if (worker_monitor := getattr(self, "worker_monitor",
                                      None)) is not None:
            worker_monitor.close()

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        for result in parallel_worker_tasks:
            result.get()


class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
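        # make_async wraps the blocking execute_model call so that it can be
        # awaited without stalling the event loop.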
        output = await make_async(self.execute_model)(
            execute_model_req=execute_model_req)
        return output

    async def check_health_async(self) -> None:
        self.check_health()


def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    if config.dtype == torch.float16:
        logger.warning("float16 is not supported on CPU, casting to bfloat16.")
        config.dtype = torch.bfloat16
    if not config.enforce_eager:
        logger.warning(
            "CUDA graph is not supported on CPU, falling back to eager mode.")
        config.enforce_eager = True
    return config


def _verify_and_get_scheduler_config(
        config: SchedulerConfig) -> SchedulerConfig:
    if config.chunked_prefill_enabled:
        logger.warning("Chunked prefill is not supported on CPU, "
                       "disabling it.")
        config.chunked_prefill_enabled = False
    return config


def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
    if config.enable_prefix_caching:
        logger.warning("Prefix caching is not supported on CPU, disabling it.")
        config.enable_prefix_caching = False

    kv_cache_space_str = envs.APHRODITE_CPU_KVCACHE_SPACE
    kv_cache_space = int(kv_cache_space_str)

    if kv_cache_space >= 0:
        if kv_cache_space == 0:
            config.cpu_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
            logger.warning(
                "Environment variable APHRODITE_CPU_KVCACHE_SPACE (GiB) "
                "for the CPU backend is not set, using 4 GiB by default.")
        else:
            config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes  # type: ignore
    else:
        raise RuntimeError(
            "Invalid environment variable APHRODITE_CPU_KVCACHE_SPACE"
            f" {kv_cache_space}, expect a non-negative integer value.")

    return config
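

# For illustration: the CPU KV-cache budget comes from the
# APHRODITE_CPU_KVCACHE_SPACE environment variable (in GiB), set before
# engine startup, e.g.:
#
#     os.environ["APHRODITE_CPU_KVCACHE_SPACE"] = "8"  # reserve 8 GiB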


def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
    if (config.distributed_executor_backend is not None
            and config.distributed_executor_backend != "mp"):
        logger.warning(
            f"{config.distributed_executor_backend} is not supported on CPU, "
            "falling back to the mp distributed executor backend.")
        config.distributed_executor_backend = "mp"
    return config
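

# The sync executor calls methods on the in-process driver directly; the
# async executor's driver is a ProcessWorkerWrapper, so the call is
# dispatched to its process and the result is fetched with .get().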
def _driver_method_invoker(driver, method: str, *args, **kwargs):
    return getattr(driver, method)(*args, **kwargs)


def _async_driver_method_invoker(driver, method: str, *args, **kwargs):
    return driver.execute_method(method, *args, **kwargs).get()
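

# Rough lifecycle sketch (illustrative, not an API): the engine constructs
# CPUExecutor with verified configs; _init_executor spawns worker processes
# and runs "init_device" and "load_model" on every worker; the engine then
# sizes the KV cache via determine_num_available_blocks()/initialize_cache();
# after that, execute_model() serves requests until shutdown() is called.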