worker_base.py

import datetime
import importlib
import os
import tempfile
import threading
from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple

from loguru import logger

from aphrodite.common.logger import enable_trace_function_call
from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.common.utils import (get_aphrodite_instance_id,
                                    update_environment_variables)
from aphrodite.lora.request import LoRARequest
class WorkerBase(ABC):
    """Worker interface that allows Aphrodite to cleanly separate
    implementations for different hardware.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other
        on-device memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks."""
        raise NotImplementedError

    @abstractmethod
    def execute_model(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            blocks_to_swap_in: Dict[int, int],
            blocks_to_swap_out: Dict[int, int],
            blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided."""
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        raise NotImplementedError
class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")
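

# Illustrative sketch (not part of the upstream module): a minimal concrete
# worker showing how the WorkerBase interface might be implemented. The class
# name and the constant block counts below are hypothetical, chosen only to
# make the shape of each required method concrete. It extends
# LoraNotSupportedWorkerBase, so only the non-LoRA abstract methods remain.
class _ExampleNoOpWorker(LoraNotSupportedWorkerBase):
    """A do-nothing worker that satisfies the WorkerBase contract."""

    def init_device(self) -> None:
        # A real worker would select the device and load model weights here.
        pass

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        # A real worker would profile a forward pass to size the caches;
        # these constants are placeholders.
        return 1024, 256

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        # A real worker would allocate the GPU/CPU KV cache tensors here.
        self._num_gpu_blocks = num_gpu_blocks
        self._num_cpu_blocks = num_cpu_blocks

    def execute_model(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            blocks_to_swap_in: Dict[int, int],
            blocks_to_swap_out: Dict[int, int],
            blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
        # A real worker would perform the swaps/copies, run the model, and
        # sample tokens; this sketch produces no output.
        return []

    def get_cache_block_size_bytes(self) -> int:
        # A real worker would derive this from the model config (roughly
        # 2 * block_size * num_layers * num_kv_heads * head_size * dtype
        # size, for key and value tensors).
        return 0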
class WorkerWrapperBase:
    """
    The whole point of this class is to lazily initialize the worker.
    We first instantiate the WorkerWrapperBase, which remembers the worker
    module and class name. Then we can call `update_environment_variables`
    to set up the environment, and the real initialization happens later in
    `init_worker`.
    """

    def __init__(self,
                 worker_module_name=None,
                 worker_class_name=None,
                 trust_remote_code: bool = False) -> None:
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker = None
        if trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from aphrodite.common.utils import init_cached_hf_modules
            init_cached_hf_modules()

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior,
            # so suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)
    def init_worker(self, *args, **kwargs):
        """
        Actual initialization of the worker class; also sets up function
        tracing if required.
        Arguments are passed to the worker class constructor.
        """
        if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
            tmp_dir = tempfile.gettempdir()
            filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
                        f"_thread_{threading.get_ident()}_"
                        f"at_{datetime.datetime.now()}.log").replace(" ", "_")
            log_path = os.path.join(tmp_dir, "aphrodite",
                                    get_aphrodite_instance_id(), filename)
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            enable_trace_function_call(log_path)

        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)
    def execute_method(self, method, *args, **kwargs):
        try:
            # Dispatch to the wrapped worker once it exists; before
            # `init_worker` runs, look up methods on the wrapper itself.
            target = self if self.worker is None else self.worker
            executor = getattr(target, method)
            return executor(*args, **kwargs)
        except Exception:
            # If the driver worker also executes methods, an exception in a
            # remote worker may cause a deadlock in RPC frameworks like Ray.
            # Log the error and inform the user so they can resolve it.
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            raise
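

# Usage sketch (illustrative, not part of the upstream module). It shows the
# lazy-initialization flow the WorkerWrapperBase docstring describes. The
# module and class names passed to the wrapper are hypothetical placeholders
# for a concrete worker implementation such as the one sketched above.
if __name__ == "__main__":
    wrapper = WorkerWrapperBase(
        worker_module_name="my_package.my_worker",  # hypothetical module
        worker_class_name="MyWorker")               # hypothetical class
    # Environment variables are applied before the worker (and hence torch)
    # is imported, so settings like CUDA_VISIBLE_DEVICES take effect.
    WorkerWrapperBase.update_environment_variables(
        {"CUDA_VISIBLE_DEVICES": "0"})
    # Deferred construction: the worker module is imported only now.
    wrapper.init_worker()
    # After init_worker, execute_method dispatches to the wrapped worker.
    wrapper.execute_method("init_device")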