1
0

gpu_executor.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from typing import Any, Dict, List, Optional, Set, Tuple
  2. from loguru import logger
  3. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  4. from aphrodite.common.utils import (get_distributed_init_method, get_ip,
  5. get_open_port, make_async)
  6. from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
  7. from aphrodite.lora.request import LoRARequest
  8. from aphrodite.task_handler.worker_base import WorkerWrapperBase
  9. class GPUExecutor(ExecutorBase):
  10. def _init_executor(self) -> None:
  11. """Initialize the worker and load the model.
  12. If speculative decoding is enabled, we instead create the speculative
  13. worker.
  14. """
  15. if self.speculative_config is None:
  16. self._init_non_spec_worker()
  17. else:
  18. self._init_spec_worker()
  19. def _get_worker_kwargs(
  20. self,
  21. local_rank: int = 0,
  22. rank: int = 0,
  23. distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
  24. """Return worker init args for a given rank."""
  25. if distributed_init_method is None:
  26. distributed_init_method = get_distributed_init_method(
  27. get_ip(), get_open_port())
  28. return dict(
  29. model_config=self.model_config,
  30. parallel_config=self.parallel_config,
  31. scheduler_config=self.scheduler_config,
  32. device_config=self.device_config,
  33. cache_config=self.cache_config,
  34. load_config=self.load_config,
  35. local_rank=local_rank,
  36. rank=rank,
  37. distributed_init_method=distributed_init_method,
  38. lora_config=self.lora_config,
  39. vision_language_config=self.vision_language_config,
  40. is_driver_worker=rank == 0,
  41. )
  42. def _create_worker(self,
  43. local_rank: int = 0,
  44. rank: int = 0,
  45. distributed_init_method: Optional[str] = None):
  46. wrapper = WorkerWrapperBase(
  47. worker_module_name="aphrodite.task_handler.worker",
  48. worker_class_name="Worker",
  49. )
  50. wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
  51. distributed_init_method))
  52. return wrapper.worker
  53. def _init_non_spec_worker(self):
  54. assert self.parallel_config.world_size == 1, (
  55. "GPUExecutor only supports single GPU.")
  56. self.driver_worker = self._create_worker()
  57. self.driver_worker.init_device()
  58. self.driver_worker.load_model()
  59. def _init_spec_worker(self):
  60. """Initialize a SpecDecodeWorker, using a draft model for proposals.
  61. """
  62. assert self.speculative_config is not None
  63. from aphrodite.spec_decode.spec_decode_worker import SpecDecodeWorker
  64. target_worker = self._create_worker()
  65. draft_worker_kwargs = self._get_worker_kwargs()
  66. # Override draft-model specific worker args.
  67. draft_worker_kwargs.update(
  68. model_config=self.speculative_config.draft_model_config,
  69. parallel_config=self.speculative_config.draft_parallel_config,
  70. ngram_prompt_lookup_max=self.speculative_config.
  71. ngram_prompt_lookup_max,
  72. ngram_prompt_lookup_min=self.speculative_config.
  73. ngram_prompt_lookup_min,
  74. # TODO allow draft-model specific load config.
  75. #load_config=self.load_config,
  76. )
  77. spec_decode_worker = SpecDecodeWorker.create_worker(
  78. scorer_worker=target_worker,
  79. draft_worker_kwargs=draft_worker_kwargs,
  80. )
  81. assert self.parallel_config.world_size == 1, (
  82. "GPUExecutor only supports single GPU.")
  83. self.driver_worker = spec_decode_worker
  84. # Load model handled in spec decode worker.
  85. self.driver_worker.init_device()
  86. def determine_num_available_blocks(self) -> Tuple[int, int]:
  87. """Determine the number of available KV blocks by invoking the
  88. underlying worker.
  89. """
  90. return self.driver_worker.determine_num_available_blocks()
  91. def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
  92. """Initialize the KV cache by invoking the underlying worker.
  93. """
  94. # NOTE: This is logged in the executor because there can be >1 worker
  95. # with other executors. We could log in the engine level, but work
  96. # remains to abstract away the device for non-GPU configurations.
  97. logger.info(f"# GPU blocks: {num_gpu_blocks}, "
  98. f"# CPU blocks: {num_cpu_blocks}")
  99. logger.info(
  100. f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x" # noqa: E501
  101. )
  102. self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
  103. def execute_model(
  104. self,
  105. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  106. output = self.driver_worker.execute_model(execute_model_req)
  107. return output
  108. def add_lora(self, lora_request: LoRARequest) -> bool:
  109. assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
  110. return self.driver_worker.add_lora(lora_request)
  111. def remove_lora(self, lora_id: int) -> bool:
  112. assert lora_id > 0, "lora_id must be greater than 0."
  113. return self.driver_worker.remove_lora(lora_id)
  114. def list_loras(self) -> Set[int]:
  115. return self.driver_worker.list_loras()
  116. def check_health(self) -> None:
  117. # GPUExecutor will always be healthy as long as
  118. # it's running.
  119. return
  120. class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
  121. async def execute_model_async(
  122. self,
  123. execute_model_req: ExecuteModelRequest,
  124. ) -> List[SamplerOutput]:
  125. output = await make_async(self.driver_worker.execute_model
  126. )(execute_model_req=execute_model_req, )
  127. return output