gpu_executor.py

from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

from loguru import logger

from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
                                        SamplerOutput)
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.task_handler.worker_base import WorkerBase, WorkerWrapperBase


def create_worker(worker_module_name: str, worker_class_name: str,
                  worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
                  **kwargs):
    """Instantiate the requested worker class via WorkerWrapperBase."""
    wrapper = WorkerWrapperBase(
        worker_module_name=worker_module_name,
        worker_class_name=worker_class_name,
        worker_class_fn=worker_class_fn,
    )
    wrapper.init_worker(**kwargs)
    return wrapper.worker


class GPUExecutor(ExecutorBase):

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Initialize the worker and load the model."""
        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        self.driver_worker = self._create_worker()
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """Return worker init args for a given rank."""
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            speculative_config=self.speculative_config,
            prompt_adapter_config=self.prompt_adapter_config,
            # The driver worker is rank 0 of its tensor-parallel group, or the
            # only worker when there is no parallel config.
            is_driver_worker=(not self.parallel_config)
            or (rank % self.parallel_config.tensor_parallel_size == 0),
        )

    def _get_worker_module_and_class(
            self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
        worker_class_fn = None
        if self.scheduler_config.is_multi_step:
            worker_module_name = "aphrodite.task_handler.multi_step_worker"
            worker_class_name = "MultiStepWorker"
        elif self.speculative_config:
            worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
            worker_class_name = "create_spec_worker"
        else:
            worker_module_name = "aphrodite.task_handler.worker"
            worker_class_name = "Worker"
        return (worker_module_name, worker_class_name, worker_class_fn)

    def _get_create_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict:
        worker_kwargs = self._get_worker_kwargs(local_rank, rank,
                                                distributed_init_method)

        (worker_module_name, worker_class_name,
         worker_class_fn) = self._get_worker_module_and_class()
        worker_kwargs.update(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
            worker_class_fn=worker_class_fn,
        )

        return worker_kwargs

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):
        return create_worker(**self._get_create_worker_kwargs(
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method))

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log at the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        logger.info(
            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
        )
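        # Worked example of the figure above (illustrative numbers, not
        # defaults): with 8000 GPU blocks and a block size of 16 tokens the
        # cache holds 128,000 tokens; at max_model_len = 8192 that is
        # 128000 / 8192 = 15.62x, i.e. roughly 15 sequences at full context
        # length can keep their KV cache resident at once.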
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        output = self.driver_worker.execute_model(execute_model_req)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.driver_worker.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.pin_lora(lora_id)

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        assert prompt_adapter_request.prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return self.driver_worker.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        return self.driver_worker.list_prompt_adapters()

    def check_health(self) -> None:
        # GPUExecutor will always be healthy as long as it's running.
        return


class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[Union[SamplerOutput, PoolerOutput]]:
        output = await make_async(self.driver_worker.execute_model
                                  )(execute_model_req=execute_model_req)
        return output
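

# -----------------------------------------------------------------------------
# Illustrative sketch, not part of the executor itself: GPUExecutorAsync relies
# on make_async from aphrodite.common.utils to await the driver worker's
# blocking execute_model call. The self-contained stand-in below shows the
# general pattern such a helper is assumed to follow -- wrap a blocking
# callable so it runs in the event loop's default thread pool -- using only the
# standard library. All underscore-prefixed names here are hypothetical and
# exist only for this demo.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio
    import functools
    import time

    def _make_async_sketch(func):
        """Wrap a blocking function so its result can be awaited (demo only)."""

        def _wrapper(*args, **kwargs):
            loop = asyncio.get_running_loop()
            bound = functools.partial(func, *args, **kwargs)
            # Offload the blocking call to the default ThreadPoolExecutor.
            return loop.run_in_executor(None, bound)

        return _wrapper

    def _blocking_model_step(step_id: int) -> str:
        time.sleep(0.1)  # stand-in for a blocking model forward pass
        return f"step {step_id} done"

    async def _demo() -> None:
        # The two "model steps" overlap because neither blocks the event loop.
        results = await asyncio.gather(
            _make_async_sketch(_blocking_model_step)(1),
            _make_async_sketch(_blocking_model_step)(2),
        )
        print(results)

    asyncio.run(_demo())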