gpu_executor.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
  2. from loguru import logger
  3. from aphrodite.common.sequence import ExecuteModelRequest, PoolerOutput
  4. from aphrodite.common.utils import (get_distributed_init_method, get_ip,
  5. get_open_port, make_async)
  6. from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
  7. from aphrodite.lora.request import LoRARequest
  8. from aphrodite.modeling.layers.sampler import SamplerOutput
  9. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  10. from aphrodite.worker.worker_base import WorkerBase, WorkerWrapperBase
  11. def create_worker(worker_module_name: str, worker_class_name: str,
  12. worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
  13. **kwargs):
  14. wrapper = WorkerWrapperBase(
  15. worker_module_name=worker_module_name,
  16. worker_class_name=worker_class_name,
  17. worker_class_fn=worker_class_fn,
  18. )
  19. wrapper.init_worker(**kwargs)
  20. return wrapper.worker
  21. class GPUExecutor(ExecutorBase):
  22. uses_ray: bool = False
  23. def _init_executor(self) -> None:
  24. """Initialize the worker and load the model.
  25. """
  26. assert self.parallel_config.world_size == 1, (
  27. "GPUExecutor only supports single GPU.")
  28. self.driver_worker = self._create_worker()
  29. self.driver_worker.init_device()
  30. self.driver_worker.load_model()
  31. def _get_worker_kwargs(
  32. self,
  33. local_rank: int = 0,
  34. rank: int = 0,
  35. distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
  36. """Return worker init args for a given rank."""
  37. if distributed_init_method is None:
  38. distributed_init_method = get_distributed_init_method(
  39. get_ip(), get_open_port())
  40. return dict(
  41. model_config=self.model_config,
  42. parallel_config=self.parallel_config,
  43. scheduler_config=self.scheduler_config,
  44. device_config=self.device_config,
  45. cache_config=self.cache_config,
  46. load_config=self.load_config,
  47. local_rank=local_rank,
  48. rank=rank,
  49. distributed_init_method=distributed_init_method,
  50. lora_config=self.lora_config,
  51. speculative_config=self.speculative_config,
  52. prompt_adapter_config=self.prompt_adapter_config,
  53. is_driver_worker=(not self.parallel_config)
  54. or (rank % self.parallel_config.tensor_parallel_size == 0),
  55. )
  56. def _get_worker_module_and_class(
  57. self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
  58. worker_class_fn = None
  59. if self.scheduler_config.is_multi_step:
  60. worker_module_name = "aphrodite.worker.multi_step_worker"
  61. worker_class_name = "MultiStepWorker"
  62. elif self.speculative_config:
  63. worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
  64. worker_class_name = "create_spec_worker"
  65. else:
  66. worker_module_name = "aphrodite.worker.worker"
  67. worker_class_name = "Worker"
  68. return (worker_module_name, worker_class_name, worker_class_fn)
  69. def _get_create_worker_kwargs(
  70. self,
  71. local_rank: int = 0,
  72. rank: int = 0,
  73. distributed_init_method: Optional[str] = None) -> Dict:
  74. worker_kwargs = self._get_worker_kwargs(local_rank, rank,
  75. distributed_init_method)
  76. (worker_module_name, worker_class_name,
  77. worker_class_fn) = self._get_worker_module_and_class()
  78. worker_kwargs.update(
  79. worker_module_name=worker_module_name,
  80. worker_class_name=worker_class_name,
  81. worker_class_fn=worker_class_fn,
  82. )
  83. return worker_kwargs
  84. def _create_worker(self,
  85. local_rank: int = 0,
  86. rank: int = 0,
  87. distributed_init_method: Optional[str] = None):
  88. return create_worker(**self._get_create_worker_kwargs(
  89. local_rank=local_rank,
  90. rank=rank,
  91. distributed_init_method=distributed_init_method))
  92. def determine_num_available_blocks(self) -> Tuple[int, int]:
  93. """Determine the number of available KV blocks by invoking the
  94. underlying worker.
  95. """
  96. return self.driver_worker.determine_num_available_blocks()
  97. def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
  98. """Initialize the KV cache by invoking the underlying worker.
  99. """
  100. # NOTE: This is logged in the executor because there can be >1 worker
  101. # with other executors. We could log in the engine level, but work
  102. # remains to abstract away the device for non-GPU configurations.
  103. logger.info(f"# GPU blocks: {num_gpu_blocks}, "
  104. f"# CPU blocks: {num_cpu_blocks}")
  105. logger.info(
  106. f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x" # noqa: E501
  107. )
  108. self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
  109. def execute_model(
  110. self, execute_model_req: ExecuteModelRequest
  111. ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
  112. output = self.driver_worker.execute_model(execute_model_req)
  113. return output
  114. def add_lora(self, lora_request: LoRARequest) -> bool:
  115. assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
  116. return self.driver_worker.add_lora(lora_request)
  117. def remove_lora(self, lora_id: int) -> bool:
  118. assert lora_id > 0, "lora_id must be greater than 0."
  119. return self.driver_worker.remove_lora(lora_id)
  120. def list_loras(self) -> Set[int]:
  121. return self.driver_worker.list_loras()
  122. def pin_lora(self, lora_id: int) -> bool:
  123. assert lora_id > 0, "lora_id must be greater than 0."
  124. return self.driver_worker.pin_lora(lora_id)
  125. def add_prompt_adapter(
  126. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  127. assert prompt_adapter_request.prompt_adapter_id > 0, \
  128. "prompt_adapter_id must be greater than 0."
  129. return self.driver_worker.add_prompt_adapter(prompt_adapter_request)
  130. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  131. assert prompt_adapter_id > 0, \
  132. "prompt_adapter_id must be greater than 0."
  133. return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)
  134. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  135. assert prompt_adapter_id > 0, \
  136. "prompt_adapter_id must be greater than 0."
  137. return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)
  138. def list_prompt_adapters(self) -> Set[int]:
  139. return self.driver_worker.list_prompt_adapters()
  140. def check_health(self) -> None:
  141. # GPUExecutor will always be healthy as long as
  142. # it's running.
  143. return
  144. class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
  145. async def execute_model_async(
  146. self,
  147. execute_model_req: ExecuteModelRequest,
  148. ) -> List[Union[SamplerOutput, PoolerOutput]]:
  149. output = await make_async(self.driver_worker.execute_model
  150. )(execute_model_req=execute_model_req, )
  151. return output