1
0

executor_base.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from abc import ABC, abstractmethod
  2. from typing import List, Optional, Set, Tuple
  3. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  4. LoRAConfig, ModelConfig, ParallelConfig,
  5. PromptAdapterConfig, SchedulerConfig,
  6. SpeculativeConfig)
  7. from aphrodite.common.sequence import ExecuteModelRequest
  8. from aphrodite.lora.request import LoRARequest
  9. from aphrodite.modeling.layers.sampler import SamplerOutput
  10. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  class ExecutorBase(ABC):
      """Base class for all executors.

      An executor is responsible for executing the model on a specific device
      type (e.g., CPU, GPU, Neuron, etc.), or it can be a distributed executor
      that executes the model on multiple devices.

      The constructor stores every configuration object on ``self`` and only
      then calls ``_init_executor()``, so subclass initialization can read all
      of the ``*_config`` attributes.
      """

      # Whether the executor uses Ray for orchestration. It is only annotated
      # here (no value is assigned), so reading it raises AttributeError unless
      # a subclass or the instance sets it.
      uses_ray: bool

      def __init__(
          self,
          model_config: ModelConfig,
          cache_config: CacheConfig,
          parallel_config: ParallelConfig,
          scheduler_config: SchedulerConfig,
          device_config: DeviceConfig,
          load_config: LoadConfig,
          lora_config: Optional[LoRAConfig],
          speculative_config: Optional[SpeculativeConfig],
          prompt_adapter_config: Optional[PromptAdapterConfig],
      ) -> None:
          """Store the given configs on the instance, then initialize it.

          Every argument is stored unchanged on an attribute with the same
          name (``self.model_config``, ``self.cache_config``, ...). The
          ``lora_config``, ``speculative_config`` and
          ``prompt_adapter_config`` arguments are typed as Optional and may
          be None.
          """
          self.model_config = model_config
          self.cache_config = cache_config
          self.lora_config = lora_config
          self.load_config = load_config
          self.parallel_config = parallel_config
          self.scheduler_config = scheduler_config
          self.device_config = device_config
          self.speculative_config = speculative_config
          self.prompt_adapter_config = prompt_adapter_config
          # Called last so the subclass hook sees every attribute set above.
          self._init_executor()

      @abstractmethod
      def _init_executor(self) -> None:
          """Perform executor-specific initialization.

          Invoked at the end of ``__init__``, after all config attributes
          have been assigned.
          """

      @abstractmethod
      def determine_num_available_blocks(self) -> Tuple[int, int]:
          """Determine the number of available blocks for the GPU KV cache and
          swappable CPU KV cache.

          Normally, this should simply delegate to the underlying Worker. Some
          ExecutorBase may require modification of the result, e.g. to ensure
          the selected cache sizes are compatible with all workers.

          Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
          are blocks that are "active" on the device and can be appended to.
          num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
          appended to.
          """
          raise NotImplementedError

      @abstractmethod
      def initialize_cache(self, num_gpu_blocks: int,
                           num_cpu_blocks: int) -> None:
          """Initialize the KV cache with the given size in blocks."""
          raise NotImplementedError

      @abstractmethod
      def execute_model(
          self, execute_model_req: ExecuteModelRequest
      ) -> Optional[List[SamplerOutput]]:
          """Executes at least one model step on the given sequences.

          The return type allows None; which implementations return None,
          and when, is not specified here.
          """
          raise NotImplementedError

      def stop_remote_worker_execution_loop(self) -> None:
          """Releases parallel workers from model loop.

          The base implementation is a no-op; executors that drive remote
          workers may override it.
          """
          return

      @abstractmethod
      def add_lora(self, lora_request: LoRARequest) -> bool:
          """Add the LoRA described by ``lora_request``; returns a bool."""
          raise NotImplementedError

      @abstractmethod
      def remove_lora(self, lora_id: int) -> bool:
          """Remove the LoRA with id ``lora_id``; returns a bool."""
          raise NotImplementedError

      @abstractmethod
      def list_loras(self) -> Set[int]:
          """Return the set of LoRA ids known to the executor."""
          raise NotImplementedError

      @abstractmethod
      def pin_lora(self, lora_id: int) -> bool:
          """Pin the LoRA with id ``lora_id``; returns a bool."""
          raise NotImplementedError

      @abstractmethod
      def add_prompt_adapter(
              self, prompt_adapter_request: PromptAdapterRequest) -> bool:
          """Add the prompt adapter described by the request; returns a bool."""
          raise NotImplementedError

      @abstractmethod
      def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
          """Remove the prompt adapter with the given id; returns a bool."""
          raise NotImplementedError

      @abstractmethod
      def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
          """Pin the prompt adapter with the given id; returns a bool."""
          raise NotImplementedError

      @abstractmethod
      def list_prompt_adapters(self) -> Set[int]:
          """Return the set of prompt adapter ids known to the executor."""
          raise NotImplementedError

      @abstractmethod
      def check_health(self) -> None:
          """Checks if the executor is healthy. If not, it should raise an
          exception."""
          raise NotImplementedError

      def shutdown(self) -> None:
          """Shutdown the executor. The base implementation is a no-op."""
          return

      def __del__(self) -> None:
          """Call ``shutdown()`` when the object is finalized.

          Python does not propagate exceptions raised from ``__del__``; it
          reports them as "ignored" instead.
          """
          self.shutdown()
  class ExecutorAsyncBase(ExecutorBase):
      """Executor base class that additionally exposes async entry points.

      Adds coroutine counterparts of model execution, remote-worker loop
      release, and the health check.
      """

      @abstractmethod
      async def execute_model_async(
              self,
              execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
          """Executes one model step on the given sequences."""
          raise NotImplementedError

      async def stop_remote_worker_execution_loop_async(self) -> None:
          """Releases parallel workers from model loop.

          The base implementation is a no-op; subclasses may override it.
          """
          return

      async def check_health_async(self) -> None:
          """Checks if the executor is healthy. If not, it should raise an
          exception."""
          # Default: run the synchronous check inline. It is called directly
          # (not offloaded), so it executes in the event loop's thread.
          self.check_health()