executor_base.py

from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple

from aphrodite.common.config import (CacheConfig, ControlVectorConfig,
                                     DeviceConfig, LoadConfig, LoRAConfig,
                                     ModelConfig, MultiModalConfig,
                                     ParallelConfig, PromptAdapterConfig,
                                     SchedulerConfig, SpeculativeConfig)
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.control_vectors.request import ControlVectorRequest
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest


class ExecutorBase(ABC):
    """Base class for all executors.

    An executor is responsible for executing the model on a specific device
    type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
    that can execute the model on multiple devices.
    """

    uses_ray: bool  # whether the executor uses Ray for orchestration.

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        speculative_config: Optional[SpeculativeConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        control_vector_config: Optional[ControlVectorConfig],
    ) -> None:
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.multimodal_config = multimodal_config
        self.speculative_config = speculative_config
        self.prompt_adapter_config = prompt_adapter_config
        self.control_vector_config = control_vector_config

        self._init_executor()

    @abstractmethod
    def _init_executor(self) -> None:
        pass

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        Normally, this should simply delegate to the underlying Worker. Some
        ExecutorBase may require modification of the result, e.g. to ensure the
        selected cache sizes are compatible with all workers.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences."""
        raise NotImplementedError

    def stop_remote_worker_execution_loop(self) -> None:
        """Releases parallel workers from model loop."""
        return

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        raise NotImplementedError  # type: ignore

    @abstractmethod
    def list_prompt_adapters(self) -> Set[int]:
        raise NotImplementedError

    @abstractmethod
    def add_control_vector(
            self, control_vector_request: ControlVectorRequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_control_vector(self, cv_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shutdown the executor."""
        return

    def __del__(self):
        self.shutdown()


class ExecutorAsyncBase(ExecutorBase):

    @abstractmethod
    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Executes one model step on the given sequences."""
        raise NotImplementedError

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Releases parallel workers from model loop."""
        return

    async def check_health_async(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        self.check_health()