# worker_base.py

import dataclasses
import importlib
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import torch
from loguru import logger

from aphrodite.common.sequence import (ExecuteModelRequest,
                                       IntermediateTensors, SamplerOutput)
from aphrodite.common.utils import (enable_trace_function_call_for_thread,
                                    update_environment_variables)
from aphrodite.distributed import broadcast_tensor_dict, get_pp_group
from aphrodite.lora.request import LoRARequest
from aphrodite.task_handler.model_runner_base import (ModelRunnerBase,
                                                      ModelRunnerInputBase)


class WorkerBase(ABC):
    """Worker interface that allows Aphrodite to cleanly separate
    implementations for different hardware. Also abstracts control plane
    communication, e.g., to communicate request metadata to other workers.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other
        on-device memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory that cannot be
        appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.
        """
        raise NotImplementedError

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Execute model loop in parallel worker.

        You can stop the loop by executing a driver worker with an empty
        output. See `stop_remote_worker_execution_loop` for more details.
        """
        while True:
            output = self.execute_model(execute_model_req=None)
            if output is None:
                return None

    @abstractmethod
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        raise NotImplementedError
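

# Illustrative sketch (not part of the original module) of the lifecycle the
# WorkerBase interface implies: initialize the device, profile the KV cache,
# allocate it, then run the model. `worker` and `request` stand in for a
# concrete WorkerBase subclass instance and a scheduled ExecuteModelRequest.
def _example_worker_lifecycle(
        worker: WorkerBase,
        request: ExecuteModelRequest) -> Optional[List[SamplerOutput]]:
    worker.init_device()
    num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
    worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
    # A driver would feed scheduled requests through execute_model; parallel
    # workers instead sit in start_worker_execution_loop until the driver
    # broadcasts an empty input.
    return worker.execute_model(execute_model_req=request)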


class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def pin_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")


@dataclasses.dataclass(frozen=True)
class WorkerInput:
    """Local inputs to each worker. May contain device-specific data. These
    fields should be broadcastable to other workers.
    """

    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["WorkerInput"],
        tensor_dict: Dict[str, Any],
    ) -> "WorkerInput":
        """
        Pop fields from the given tensor_dict and populate a new instance of
        WorkerInput.
        """
        return cls(
            num_seq_groups=tensor_dict.pop("num_seq_groups"),
            blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
            blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
            blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
            virtual_engine=tensor_dict["virtual_engine"],
        )

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        Extract broadcastable fields.
        """
        tensor_dict = {
            "num_seq_groups": self.num_seq_groups,
            "blocks_to_swap_in": self.blocks_to_swap_in,
            "blocks_to_swap_out": self.blocks_to_swap_out,
            "blocks_to_copy": self.blocks_to_copy,
            "virtual_engine": self.virtual_engine,
        }
        return tensor_dict
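

# Minimal sketch (not part of the original module) of how a WorkerInput is
# intended to round-trip through its broadcast helpers. The actual call to
# broadcast_tensor_dict is elided, and the empty block tensors are
# illustrative placeholders.
def _example_worker_input_roundtrip() -> WorkerInput:
    original = WorkerInput(
        num_seq_groups=2,
        blocks_to_swap_in=torch.empty((0, 2), dtype=torch.int64),
        blocks_to_swap_out=torch.empty((0, 2), dtype=torch.int64),
        blocks_to_copy=torch.empty((0, 2), dtype=torch.int64),
        virtual_engine=0,
    )
    tensor_dict = original.as_broadcastable_tensor_dict()
    # The driver would broadcast tensor_dict to the rest of the TP group;
    # each receiver rebuilds an equivalent WorkerInput from it.
    return WorkerInput.from_broadcasted_tensor_dict(dict(tensor_dict))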


class LocalOrDistributedWorkerBase(WorkerBase):
    """
    Partial implementation of WorkerBase that has a default `execute_model`
    definition to perform metadata transfer between workers when in
    distributed mode. Subclasses of this interface should use model runners
    that inherit from ModelRunnerBase, and should only need to implement
    worker-local logic. If custom control plane logic is needed to transfer
    metadata, or if the model runner cannot inherit from ModelRunnerBase, use
    WorkerBase instead.
    """
    is_driver_worker: bool
    model_runner: ModelRunnerBase

    @property
    @abstractmethod
    def do_metadata_broadcast(self) -> bool:
        """
        Used by the default `execute_model` to check whether a broadcast is
        needed to transfer request inputs from the driver worker to the other
        workers in the TP group. If a WorkerBase subclass only supports
        single-worker execution, this method should return False.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular
        virtual engine (PP stream). Used by the default `execute_model`. If
        the worker's model runner does not follow the ModelRunnerBase
        interface, then inherit from WorkerBase instead.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """
        Prepare the inputs to WorkerBase.execute_worker from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """
        Process an execution request.
        """
        raise NotImplementedError

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided."""
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there are no more requests to process
                    # for now. All workers run an infinite loop around
                    # broadcast_tensor_dict, and the loop stops when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None

            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            model_input: ModelRunnerInputBase = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))
            num_steps = execute_model_req.num_steps

            if self.do_metadata_broadcast:
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(
                    model_input.as_broadcastable_tensor_dict())
                broadcast_data["num_steps"] = num_steps
                broadcast_tensor_dict(broadcast_data, src=0)
        else:
            assert self.do_metadata_broadcast
            broadcast_data = broadcast_tensor_dict(src=0)
            if not broadcast_data:
                return None

            num_steps = broadcast_data.pop("num_steps")
            worker_input = WorkerInput.from_broadcasted_tensor_dict(
                broadcast_data)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(broadcast_data))

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict())

        output = self.model_runner.execute_model(
            model_input, self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None, intermediate_tensors,
            num_steps)

        if not get_pp_group().is_last_rank:
            get_pp_group().send_tensor_dict(output.tensors)
            return [None]

        # Worker only supports single-step execution. Wrap the output in a
        # list to conform to interface.
        return output
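

# Hypothetical sketch (not part of Aphrodite) of the worker-local pieces a
# concrete LocalOrDistributedWorkerBase subclass supplies; the inherited
# `execute_model` above then handles metadata broadcast and pipeline-parallel
# I/O. The device, cache, and LoRA methods required by WorkerBase are omitted
# for brevity, so this skeleton is illustrative rather than instantiable.
class _ExampleSingleWorkerSketch(LocalOrDistributedWorkerBase):
    is_driver_worker = True

    @property
    def do_metadata_broadcast(self) -> bool:
        # Single-worker setup: there is no TP group to broadcast metadata to.
        return False

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        # No KV cache is allocated in this sketch.
        return None

    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        return WorkerInput(
            num_seq_groups=len(execute_model_req.seq_group_metadata_list),
            virtual_engine=execute_model_req.virtual_engine)

    def execute_worker(self, worker_input: WorkerInput) -> None:
        # A real worker would issue cache swap/copy operations here.
        pass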


class WorkerWrapperBase:
    """
    The whole point of this class is to lazily initialize the worker.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. Environment variables can then be set up via
    `update_environment_variables`, and the real initialization happens later
    in `init_worker`.
    """

    def __init__(self,
                 worker_module_name: str,
                 worker_class_name: str,
                 trust_remote_code: bool = False) -> None:
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker = None
        if trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from aphrodite.common.utils import init_cached_hf_modules
            init_cached_hf_modules()

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.
        """
        enable_trace_function_call_for_thread()

        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method, *args, **kwargs):
        try:
            target = self if self.worker is None else self.worker
            executor = getattr(target, method)
            return executor(*args, **kwargs)
        except Exception as e:
            # If the driver worker also executes methods, an exception in one
            # of the other workers may cause a deadlock in RPC frameworks
            # such as Ray. Log the error and tell the user how to resolve it.
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            raise e
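

# Hypothetical usage sketch (not part of the original file) of the lazy
# initialization flow described in the WorkerWrapperBase docstring. The
# worker module/class names are placeholders, and the real executor passes
# the worker's actual constructor arguments to `init_worker`.
def _example_worker_wrapper_usage() -> None:
    wrapper = WorkerWrapperBase(
        worker_module_name="aphrodite.task_handler.worker",
        worker_class_name="Worker")
    # Environment changes are applied before any heavy initialization runs.
    wrapper.update_environment_variables({"CUDA_VISIBLE_DEVICES": "0"})
    # Only now is the worker module imported and the worker constructed.
    wrapper.init_worker()
    # Before init_worker, execute_method falls back to methods on the wrapper
    # itself; afterwards it forwards everything to the real worker.
    wrapper.execute_method("init_device")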