executor_base.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from abc import ABC, abstractmethod
  2. from typing import List, Optional, Set, Tuple
  3. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  4. LoRAConfig, ModelConfig, ParallelConfig,
  5. SchedulerConfig, SpeculativeConfig,
  6. VisionLanguageConfig)
  7. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  8. from aphrodite.lora.request import LoRARequest
  9. class ExecutorBase(ABC):
  10. """Base class for all executors.
  11. An executor is responsible for executing the model on a specific device
  12. type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
  13. that can execute the model on multiple devices.
  14. """
  15. def __init__(
  16. self,
  17. model_config: ModelConfig,
  18. cache_config: CacheConfig,
  19. parallel_config: ParallelConfig,
  20. scheduler_config: SchedulerConfig,
  21. device_config: DeviceConfig,
  22. load_config: LoadConfig,
  23. lora_config: Optional[LoRAConfig],
  24. vision_language_config: Optional[VisionLanguageConfig],
  25. speculative_config: Optional[SpeculativeConfig],
  26. ) -> None:
  27. self.model_config = model_config
  28. self.cache_config = cache_config
  29. self.lora_config = lora_config
  30. self.load_config = load_config
  31. self.parallel_config = parallel_config
  32. self.scheduler_config = scheduler_config
  33. self.device_config = device_config
  34. self.vision_language_config = vision_language_config
  35. self.speculative_config = speculative_config
  36. self._init_executor()
  37. @abstractmethod
  38. def _init_executor(self) -> None:
  39. pass
  40. @abstractmethod
  41. def determine_num_available_blocks(self) -> Tuple[int, int]:
  42. """Determine the number of available blocks for the GPU KV cache and
  43. swappable CPU KV cache.
  44. Normally, this should simply delegate to the underlying Worker. Some
  45. ExecutorBase may require modification of the result, e.g. to ensure the
  46. selected cache sizes are compatible with all workers.
  47. Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
  48. are blocks that are "active" on the device and can be appended to.
  49. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
  50. appended to.
  51. """
  52. raise NotImplementedError
  53. @abstractmethod
  54. def initialize_cache(self, num_gpu_blocks: int,
  55. num_cpu_blocks: int) -> None:
  56. """Initialize the KV cache with the given size in blocks.
  57. """
  58. raise NotImplementedError
  59. @abstractmethod
  60. def execute_model(
  61. self,
  62. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  63. """Executes at least one model step on the given sequences."""
  64. raise NotImplementedError
  65. def stop_remote_worker_execution_loop(self) -> None:
  66. """Releases parallel workers from model loop."""
  67. return
  68. @abstractmethod
  69. def add_lora(self, lora_request: LoRARequest) -> bool:
  70. raise NotImplementedError
  71. @abstractmethod
  72. def remove_lora(self, lora_id: int) -> bool:
  73. raise NotImplementedError
  74. @abstractmethod
  75. def list_loras(self) -> Set[int]:
  76. raise NotImplementedError
  77. @abstractmethod
  78. def check_health(self) -> None:
  79. """Checks if the executor is healthy. If not, it should raise an
  80. exception."""
  81. raise NotImplementedError
  82. def shutdown(self) -> None:
  83. """Shutdown the executor."""
  84. return
  85. def __del__(self):
  86. self.shutdown()
  87. class ExecutorAsyncBase(ExecutorBase):
  88. @abstractmethod
  89. async def execute_model_async(
  90. self,
  91. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  92. """Executes one model step on the given sequences."""
  93. raise NotImplementedError
  94. async def stop_remote_worker_execution_loop_async(self) -> None:
  95. """Releases parallel workers from model loop."""
  96. return
  97. async def check_health_async(self) -> None:
  98. """Checks if the executor is healthy. If not, it should raise an
  99. exception."""
  100. self.check_health()