executor_base.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from abc import ABC, abstractmethod
  2. from typing import Dict, List, Optional
  3. from aphrodite.common.config import (
  4. CacheConfig,
  5. DeviceConfig,
  6. ModelConfig,
  7. ParallelConfig,
  8. SchedulerConfig,
  9. LoRAConfig,
  10. )
  11. from aphrodite.lora.request import LoRARequest
  12. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  13. class ExecutorBase(ABC):
  14. """Base class for all executors.
  15. An executor is responsible for executing the model on a specific device
  16. type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
  17. that can execute the model on multiple devices.
  18. """
  19. @abstractmethod
  20. def __init__(
  21. self,
  22. model_config: ModelConfig,
  23. cache_config: CacheConfig,
  24. parallel_config: ParallelConfig,
  25. scheduler_config: SchedulerConfig,
  26. device_config: DeviceConfig,
  27. lora_config: Optional[LoRAConfig],
  28. ) -> None:
  29. raise NotImplementedError
  30. @abstractmethod
  31. def execute_model(
  32. self,
  33. seq_group_metadata_list: List[SequenceGroupMetadata],
  34. blocks_to_swap_in: Dict[int, int],
  35. blocks_to_swap_out: Dict[int, int],
  36. blocks_to_copy: Dict[int, List[int]],
  37. ) -> SamplerOutput:
  38. """Executes one model step on the given sequences."""
  39. raise NotImplementedError
  40. @abstractmethod
  41. def add_lora(self, lora_request: LoRARequest) -> bool:
  42. raise NotImplementedError
  43. @abstractmethod
  44. def remove_lora(self, lora_id: int) -> bool:
  45. raise NotImplementedError
  46. @abstractmethod
  47. def list_loras(self) -> List[int]:
  48. raise NotImplementedError
  49. @abstractmethod
  50. def check_health(self) -> None:
  51. """Checks if the executor is healthy. If not, it should raise an
  52. exception."""
  53. raise NotImplementedError
  54. class ExecutorAsyncBase(ExecutorBase):
  55. @abstractmethod
  56. async def execute_model_async(
  57. self,
  58. seq_group_metadata_list: List[SequenceGroupMetadata],
  59. blocks_to_swap_in: Dict[int, int],
  60. blocks_to_swap_out: Dict[int, int],
  61. blocks_to_copy: Dict[int, List[int]],
  62. ) -> SamplerOutput:
  63. """Executes one model step on the given sequences."""
  64. raise NotImplementedError
  65. @abstractmethod
  66. async def check_health_async(self) -> None:
  67. """Checks if the executor is healthy. If not, it should raise an
  68. exception."""
  69. raise NotImplementedError