# worker_base.py
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, List

from loguru import logger

from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.common.utils import update_environment_variables
from aphrodite.lora.request import LoRARequest
  9. class WorkerBase(ABC):
  10. """Worker interface that allows Aphrodite to cleanly separate
  11. implementations for different hardware.
  12. """
  13. @abstractmethod
  14. def init_device(self) -> None:
  15. """Initialize device state, such as loading the model or other on-device
  16. memory allocations.
  17. """
  18. raise NotImplementedError
  19. @abstractmethod
  20. def determine_num_available_blocks(self) -> tuple[int, int]:
  21. """Determine the number of available blocks for the GPU KV cache and
  22. swappable CPU KV cache.
  23. The implementation may run profiling or other heuristics to determine
  24. the size of caches.
  25. Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
  26. are blocks that are "active" on the device and can be appended to.
  27. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
  28. appended to.
  29. """
  30. raise NotImplementedError
  31. @abstractmethod
  32. def initialize_cache(self, num_gpu_blocks: int,
  33. num_cpu_blocks: int) -> None:
  34. """Initialize the KV cache with the given size in blocks.
  35. """
  36. raise NotImplementedError
  37. @abstractmethod
  38. def execute_model(
  39. self, seq_group_metadata_list: List[SequenceGroupMetadata],
  40. blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int,
  41. int],
  42. blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
  43. """Executes at least one model step on the given sequences, unless no
  44. sequences are provided."""
  45. raise NotImplementedError
  46. @abstractmethod
  47. def get_cache_block_size_bytes() -> int:
  48. """Return the size of a single cache block, in bytes. Used in
  49. speculative decoding.
  50. """
  51. raise NotImplementedError
  52. @abstractmethod
  53. def add_lora(self, lora_request: LoRARequest) -> bool:
  54. raise NotImplementedError
  55. @abstractmethod
  56. def remove_lora(self, lora_id: int) -> bool:
  57. raise NotImplementedError
  58. @abstractmethod
  59. def list_loras(self) -> List[int]:
  60. raise NotImplementedError
  61. class LoraNotSupportedWorkerBase(WorkerBase):
  62. """Partial implementation of WorkerBase that raises exceptions when LoRA
  63. methods are invoked.
  64. """
  65. def add_lora(self, lora_request: LoRARequest) -> bool:
  66. raise ValueError(f"{type(self)} does not support LoRA")
  67. def remove_lora(self, lora_id: int) -> bool:
  68. raise ValueError(f"{type(self)} does not support LoRA")
  69. def list_loras(self) -> List[int]:
  70. raise ValueError(f"{type(self)} does not support LoRA")
  71. class WorkerWrapperBase:
  72. def __init__(self,
  73. worker_module_name=None,
  74. worker_class_name=None) -> None:
  75. self.worker_module_name = worker_module_name
  76. self.worker_class_name = worker_class_name
  77. self.worker = None
  78. def update_environment_variables(self, envs: Dict[str, str]) -> None:
  79. """Update environment variables for the worker."""
  80. key = "CUDA_VISIBLE_DEVICES"
  81. if key in envs and key in os.environ:
  82. del os.environ[key]
  83. update_environment_variables(envs)
  84. def init_worker(self, *args, **kwargs):
  85. mod = importlib.import_module(self.worker_module_name)
  86. worker_class = getattr(mod, self.worker_class_name)
  87. self.worker = worker_class(*args, **kwargs)
  88. def execute_method(self, method, *args, **kwargs):
  89. try:
  90. if hasattr(self, method):
  91. executor = getattr(self, method)
  92. else:
  93. executor = getattr(self.worker, method)
  94. return executor(*args, **kwargs)
  95. except Exception as e:
  96. # exceptions in ray worker may cause deadlock
  97. # print the error and inform the user to solve the error
  98. msg = (f"Error executing method {method}. "
  99. "This might cause deadlock in distributed execution.")
  100. logger.exception(msg)
  101. raise e