# executor_base.py
import asyncio
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig,
                                     MultiModalConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig,
                                     SpeculativeConfig)
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest
  11. class ExecutorBase(ABC):
  12. """Base class for all executors.
  13. An executor is responsible for executing the model on a specific device
  14. type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
  15. that can execute the model on multiple devices.
  16. """
  17. def __init__(
  18. self,
  19. model_config: ModelConfig,
  20. cache_config: CacheConfig,
  21. parallel_config: ParallelConfig,
  22. scheduler_config: SchedulerConfig,
  23. device_config: DeviceConfig,
  24. load_config: LoadConfig,
  25. lora_config: Optional[LoRAConfig],
  26. multimodal_config: Optional[MultiModalConfig],
  27. speculative_config: Optional[SpeculativeConfig],
  28. prompt_adapter_config: Optional[PromptAdapterConfig],
  29. ) -> None:
  30. self.model_config = model_config
  31. self.cache_config = cache_config
  32. self.lora_config = lora_config
  33. self.load_config = load_config
  34. self.parallel_config = parallel_config
  35. self.scheduler_config = scheduler_config
  36. self.device_config = device_config
  37. self.multimodal_config = multimodal_config
  38. self.speculative_config = speculative_config
  39. self.prompt_adapter_config = prompt_adapter_config
  40. self._init_executor()
  41. @abstractmethod
  42. def _init_executor(self) -> None:
  43. pass
  44. @abstractmethod
  45. def determine_num_available_blocks(self) -> Tuple[int, int]:
  46. """Determine the number of available blocks for the GPU KV cache and
  47. swappable CPU KV cache.
  48. Normally, this should simply delegate to the underlying Worker. Some
  49. ExecutorBase may require modification of the result, e.g. to ensure the
  50. selected cache sizes are compatible with all workers.
  51. Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
  52. are blocks that are "active" on the device and can be appended to.
  53. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
  54. appended to.
  55. """
  56. raise NotImplementedError
  57. @abstractmethod
  58. def initialize_cache(self, num_gpu_blocks: int,
  59. num_cpu_blocks: int) -> None:
  60. """Initialize the KV cache with the given size in blocks.
  61. """
  62. raise NotImplementedError
  63. @abstractmethod
  64. def execute_model(
  65. self, execute_model_req: ExecuteModelRequest
  66. ) -> Optional[List[SamplerOutput]]:
  67. """Executes at least one model step on the given sequences."""
  68. raise NotImplementedError
  69. def stop_remote_worker_execution_loop(self) -> None:
  70. """Releases parallel workers from model loop."""
  71. return
  72. @abstractmethod
  73. def add_lora(self, lora_request: LoRARequest) -> bool:
  74. raise NotImplementedError
  75. @abstractmethod
  76. def remove_lora(self, lora_id: int) -> bool:
  77. raise NotImplementedError
  78. @abstractmethod
  79. def list_loras(self) -> Set[int]:
  80. raise NotImplementedError
  81. @abstractmethod
  82. def pin_lora(self, lora_id: int) -> bool:
  83. raise NotImplementedError
  84. @abstractmethod
  85. def add_prompt_adapter(
  86. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  87. raise NotImplementedError
  88. @abstractmethod
  89. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  90. raise NotImplementedError
  91. @abstractmethod
  92. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  93. raise NotImplementedError # type: ignore
  94. @abstractmethod
  95. def list_prompt_adapters(self) -> Set[int]:
  96. raise NotImplementedError
  97. @abstractmethod
  98. def check_health(self) -> None:
  99. """Checks if the executor is healthy. If not, it should raise an
  100. exception."""
  101. raise NotImplementedError
  102. def shutdown(self) -> None:
  103. """Shutdown the executor."""
  104. return
  105. def __del__(self):
  106. self.shutdown()
  107. class ExecutorAsyncBase(ExecutorBase):
  108. def __init__(
  109. self,
  110. model_config: ModelConfig,
  111. cache_config: CacheConfig,
  112. parallel_config: ParallelConfig,
  113. scheduler_config: SchedulerConfig,
  114. device_config: DeviceConfig,
  115. load_config: LoadConfig,
  116. lora_config: Optional[LoRAConfig],
  117. multimodal_config: Optional[MultiModalConfig],
  118. speculative_config: Optional[SpeculativeConfig],
  119. prompt_adapter_config: Optional[PromptAdapterConfig],
  120. ) -> None:
  121. self.pp_locks: Optional[List[asyncio.Lock]] = None
  122. super().__init__(model_config, cache_config, parallel_config,
  123. scheduler_config, device_config, load_config,
  124. lora_config, multimodal_config, speculative_config,
  125. prompt_adapter_config)
  126. @abstractmethod
  127. async def execute_model_async(
  128. self,
  129. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  130. """Executes one model step on the given sequences."""
  131. raise NotImplementedError
  132. async def stop_remote_worker_execution_loop_async(self) -> None:
  133. """Releases parallel workers from model loop."""
  134. return
  135. async def check_health_async(self) -> None:
  136. """Checks if the executor is healthy. If not, it should raise an
  137. exception."""
  138. self.check_health()