worker_base.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. import dataclasses
  2. import importlib
  3. import os
  4. from abc import ABC, abstractmethod
  5. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
  6. import torch
  7. from loguru import logger
  8. from aphrodite.common.sequence import (ExecuteModelRequest,
  9. IntermediateTensors, SamplerOutput)
  10. from aphrodite.common.utils import (enable_trace_function_call_for_thread,
  11. update_environment_variables)
  12. from aphrodite.distributed import (broadcast_tensor_dict, get_pp_group,
  13. get_tp_group)
  14. from aphrodite.lora.request import LoRARequest
  15. from aphrodite.platforms import current_platform
  16. from aphrodite.task_handler.model_runner_base import (BroadcastableModelInput,
  17. ModelRunnerBase,
  18. ModelRunnerInputBase)
  19. class WorkerBase(ABC):
  20. """Worker interface that allows Aphrodite to cleanly separate
  21. implementations for different hardware. Also abstracts control plane
  22. communication, e.g., to communicate request metadata to other workers.
  23. """
  24. @abstractmethod
  25. def init_device(self) -> None:
  26. """Initialize device state, such as loading the model or other on-device
  27. memory allocations.
  28. """
  29. raise NotImplementedError
  30. @abstractmethod
  31. def determine_num_available_blocks(self) -> Tuple[int, int]:
  32. """Determine the number of available blocks for the GPU KV cache and
  33. swappable CPU KV cache.
  34. The implementation may run profiling or other heuristics to determine
  35. the size of caches.
  36. Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
  37. are blocks that are "active" on the device and can be appended to.
  38. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
  39. appended to.
  40. """
  41. raise NotImplementedError
  42. @abstractmethod
  43. def initialize_cache(self, num_gpu_blocks: int,
  44. num_cpu_blocks: int) -> None:
  45. """Initialize the KV cache with the given size in blocks.
  46. """
  47. raise NotImplementedError
  48. @current_platform.inference_mode()
  49. def start_worker_execution_loop(self) -> None:
  50. """Execute model loop in parallel worker.
  51. You can stop the loop by executing a driver worker with an empty output.
  52. See `stop_remote_worker_execution_loop` for more details.
  53. """
  54. while True:
  55. output = self.execute_model(execute_model_req=None)
  56. if output is None:
  57. return None
  58. @abstractmethod
  59. def execute_model(
  60. self,
  61. execute_model_req: Optional[ExecuteModelRequest] = None
  62. ) -> Optional[List[SamplerOutput]]:
  63. raise NotImplementedError
  64. @abstractmethod
  65. def get_cache_block_size_bytes(self) -> int:
  66. """Return the size of a single cache block, in bytes. Used in
  67. speculative decoding.
  68. """
  69. raise NotImplementedError
  70. @abstractmethod
  71. def add_lora(self, lora_request: LoRARequest) -> bool:
  72. raise NotImplementedError
  73. @abstractmethod
  74. def remove_lora(self, lora_id: int) -> bool:
  75. raise NotImplementedError
  76. @abstractmethod
  77. def pin_lora(self, lora_id: int) -> bool:
  78. raise NotImplementedError
  79. @abstractmethod
  80. def list_loras(self) -> Set[int]:
  81. raise NotImplementedError
  82. class LoraNotSupportedWorkerBase(WorkerBase):
  83. """Partial implementation of WorkerBase that raises exceptions when LoRA
  84. methods are invoked.
  85. """
  86. def add_lora(self, lora_request: LoRARequest) -> bool:
  87. raise ValueError(f"{type(self)} does not support LoRA")
  88. def remove_lora(self, lora_id: int) -> bool:
  89. raise ValueError(f"{type(self)} does not support LoRA")
  90. def pin_lora(self, lora_id: int) -> bool:
  91. return ValueError(
  92. f"{type(self)} does not support LoRA") # type: ignore
  93. def list_loras(self) -> Set[int]:
  94. raise ValueError(f"{type(self)} does not support LoRA")
  95. @dataclasses.dataclass(frozen=True)
  96. class WorkerInput:
  97. """Local inputs to each worker. May contain device-specific data. These
  98. fields should be broadcastable to other workers.
  99. """
  100. num_seq_groups: Optional[int] = None
  101. blocks_to_swap_in: Optional[torch.Tensor] = None
  102. blocks_to_swap_out: Optional[torch.Tensor] = None
  103. blocks_to_copy: Optional[torch.Tensor] = None
  104. virtual_engine: int = 0
  105. num_steps: int = 1
  106. @classmethod
  107. def from_broadcasted_tensor_dict(
  108. cls: Type["WorkerInput"],
  109. tensor_dict: Dict[str, Any],
  110. ) -> "WorkerInput":
  111. """
  112. Pop fields from the given tensor_dict and populate a new instance of
  113. WorkerInput.
  114. """
  115. return cls(
  116. num_seq_groups=tensor_dict.pop("num_seq_groups"),
  117. blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
  118. blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
  119. blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
  120. virtual_engine=tensor_dict["virtual_engine"],
  121. num_steps=tensor_dict.pop("num_steps"),
  122. )
  123. def as_broadcastable_tensor_dict(
  124. self) -> Dict[str, Union[int, torch.Tensor]]:
  125. """
  126. Extract broadcastable fields.
  127. """
  128. tensor_dict = {
  129. "num_seq_groups": self.num_seq_groups,
  130. "blocks_to_swap_in": self.blocks_to_swap_in,
  131. "blocks_to_swap_out": self.blocks_to_swap_out,
  132. "blocks_to_copy": self.blocks_to_copy,
  133. "virtual_engine": self.virtual_engine,
  134. "num_steps": self.num_steps,
  135. }
  136. return tensor_dict
  137. class LocalOrDistributedWorkerBase(WorkerBase):
  138. """
  139. Partial implementation of WorkerBase that has a default `execute_model`
  140. definition to perform metadata transfer between workers when in distributed
  141. mode. Subclasses of this interface should use model runners that inherit
  142. from ModelRunnerBase, and should only need to implement worker-local logic.
  143. If custom control plane logic is needed to transfer metadata, or if the
  144. model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
  145. """
  146. is_driver_worker: bool
  147. model_runner: ModelRunnerBase
  148. @property
  149. @abstractmethod
  150. def do_metadata_broadcast(self) -> bool:
  151. """
  152. Used by the default `execute_model` to check whether broadcast is
  153. needed to transfer request inputs from the driver worker to other
  154. workers in the TP group. If WorkerBase subclass only supports
  155. single-worker execution, then this method should return False.
  156. """
  157. raise NotImplementedError
  158. @property
  159. @abstractmethod
  160. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  161. """
  162. Gets the list of kv caches to pass to the worker's model runner. Each
  163. element in the list is a kv cache corresponding to a particular virtual
  164. engine (PP stream). Used by the default `execute_model`. If the worker's
  165. model runner does not follow the ModelRunnerBase interface, then inherit
  166. from WorkerBase instead.
  167. """
  168. raise NotImplementedError
  169. @abstractmethod
  170. def prepare_worker_input(
  171. self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
  172. """
  173. Prepare the inputs to WorkerBase.execute_worker from an execution
  174. request. This method may move data to the worker's local device. It is
  175. not allowed to communicate with other workers or devices.
  176. """
  177. raise NotImplementedError
  178. @abstractmethod
  179. def execute_worker(self, worker_input: WorkerInput) -> None:
  180. """
  181. Process an execution request.
  182. """
  183. raise NotImplementedError
  184. def _get_worker_input_from_broadcast(
  185. self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
  186. """ Get the worker input from the broadcasted tensor dict. """
  187. assert self.do_metadata_broadcast
  188. assert not self.is_driver_worker
  189. broadcast_data = broadcast_tensor_dict(src=0)
  190. if not broadcast_data:
  191. return None
  192. worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
  193. model_input = (
  194. self.model_runner.make_model_input_from_broadcasted_tensor_dict(
  195. broadcast_data))
  196. return model_input, worker_input
  197. def _get_driver_input_and_broadcast(
  198. self, execute_model_req: ExecuteModelRequest
  199. ) -> Tuple[BroadcastableModelInput, WorkerInput]:
  200. """ Get the driver input and broadcast it to other workers. """
  201. assert self.is_driver_worker
  202. worker_input: WorkerInput = self.prepare_worker_input(
  203. execute_model_req=execute_model_req)
  204. model_input: ModelRunnerInputBase = (
  205. self.model_runner.prepare_model_input(
  206. execute_model_req.seq_group_metadata_list,
  207. execute_model_req.virtual_engine,
  208. execute_model_req.finished_requests_ids))
  209. if self.do_metadata_broadcast:
  210. broadcast_data = worker_input.as_broadcastable_tensor_dict()
  211. broadcast_data.update(model_input.as_broadcastable_tensor_dict())
  212. broadcast_tensor_dict(broadcast_data, src=0)
  213. return model_input, worker_input
  214. def prepare_input(
  215. self,
  216. execute_model_req: Optional[ExecuteModelRequest] = None
  217. ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
  218. """
  219. Prepare the inputs to ModelRunner and workers.
  220. """
  221. if self.is_driver_worker:
  222. if execute_model_req is None:
  223. if self.do_metadata_broadcast:
  224. # This signals that there's no more requests to process for
  225. # now. All workers are running infinite loop with
  226. # broadcast_tensor_dict, and it stops the loop when the
  227. # driver broadcasts an empty input. Send an empty input to
  228. # notify all other workers to stop their execution loop.
  229. broadcast_tensor_dict({}, src=0)
  230. return None
  231. return self._get_driver_input_and_broadcast(execute_model_req)
  232. else:
  233. return self._get_worker_input_from_broadcast()
  234. def execute_model(
  235. self,
  236. execute_model_req: Optional[ExecuteModelRequest] = None
  237. ) -> Optional[List[SamplerOutput]]:
  238. """Executes at least one model step on the given sequences, unless no
  239. sequences are provided."""
  240. inputs = self.prepare_input(execute_model_req)
  241. if inputs is None:
  242. return None
  243. model_input, worker_input = inputs
  244. num_steps = worker_input.num_steps
  245. self.execute_worker(worker_input)
  246. # If there is no input, we don't need to execute the model.
  247. if worker_input.num_seq_groups == 0:
  248. return []
  249. intermediate_tensors = None
  250. if not get_pp_group().is_first_rank:
  251. intermediate_tensors = IntermediateTensors(
  252. get_pp_group().recv_tensor_dict(
  253. all_gather_group=get_tp_group()))
  254. output = self.model_runner.execute_model(
  255. model_input, self.kv_cache[worker_input.virtual_engine]
  256. if self.kv_cache is not None else None, intermediate_tensors,
  257. num_steps)
  258. if not get_pp_group().is_last_rank:
  259. # output is IntermediateTensors
  260. get_pp_group().send_tensor_dict(output.tensors,
  261. all_gather_group=get_tp_group())
  262. return [None]
  263. # output is List[SamplerOutput]
  264. return output
  265. def _execute_model_spmd(
  266. self,
  267. execute_model_req: ExecuteModelRequest,
  268. intermediate_tensors: Optional[IntermediateTensors] = None
  269. ) -> Optional[List[SamplerOutput]]:
  270. """
  271. Execute model in Single Program Multiple Data (SPMD) fashion.
  272. All workers take the same request, prepare the input and
  273. execute the model.
  274. """
  275. assert execute_model_req is not None, (
  276. "_execute_model_spmd() requires each worker to take in an "
  277. "ExecuteModelRequest")
  278. worker_input: WorkerInput = self.prepare_worker_input(
  279. execute_model_req=execute_model_req)
  280. model_input: ModelRunnerInputBase = (
  281. self.model_runner.prepare_model_input(
  282. execute_model_req.seq_group_metadata_list))
  283. self.execute_worker(worker_input)
  284. # If there is no input, we don't need to execute the model.
  285. if worker_input.num_seq_groups == 0:
  286. return []
  287. return self.model_runner.execute_model(
  288. model_input, self.kv_cache[worker_input.virtual_engine]
  289. if self.kv_cache is not None else None, intermediate_tensors)
  290. class WorkerWrapperBase:
  291. """
  292. The whole point of this class is to lazily initialize the worker.
  293. We first instantiate the WorkerWrapper, which remembers the worker module
  294. and class name. Then, when we call `update_environment_variables`, and the
  295. real initialization happens in `init_worker`.
  296. If worker_class_fn is specified, it will be executed to get the worker
  297. class.
  298. Otherwise, the worker class will be obtained by dynamically importing it
  299. using worker_module_name and worker_class_name.
  300. """
  301. def __init__(
  302. self,
  303. worker_module_name: str,
  304. worker_class_name: str,
  305. trust_remote_code: bool = False,
  306. worker_class_fn: Optional[Callable[[],
  307. Type[WorkerBase]]] = None) -> None:
  308. self.worker_module_name = worker_module_name
  309. self.worker_class_name = worker_class_name
  310. self.worker_class_fn = worker_class_fn
  311. self.worker: Optional[WorkerBase] = None
  312. if trust_remote_code:
  313. # note: lazy import to avoid importing torch before initializing
  314. from aphrodite.common.utils import init_cached_hf_modules
  315. init_cached_hf_modules()
  316. @staticmethod
  317. def update_environment_variables(envs: Dict[str, str]) -> None:
  318. key = 'CUDA_VISIBLE_DEVICES'
  319. if key in envs and key in os.environ:
  320. # overwriting CUDA_VISIBLE_DEVICES is desired behavior
  321. # suppress the warning in `update_environment_variables`
  322. del os.environ[key]
  323. update_environment_variables(envs)
  324. def init_worker(self, *args, **kwargs):
  325. """
  326. Here we inject some common logic before initializing the worker.
  327. Arguments are passed to the worker class constructor.
  328. """
  329. enable_trace_function_call_for_thread()
  330. # see https://github.com/NVIDIA/nccl/issues/1234
  331. os.environ['NCCL_CUMEM_ENABLE'] = '0'
  332. from aphrodite.plugins import load_general_plugins
  333. load_general_plugins()
  334. if self.worker_class_fn:
  335. worker_class = self.worker_class_fn()
  336. else:
  337. mod = importlib.import_module(self.worker_module_name)
  338. worker_class = getattr(mod, self.worker_class_name)
  339. self.worker = worker_class(*args, **kwargs)
  340. assert self.worker is not None
  341. def execute_method(self, method, *args, **kwargs):
  342. try:
  343. target = self if self.worker is None else self.worker
  344. executor = getattr(target, method)
  345. return executor(*args, **kwargs)
  346. except Exception as e:
  347. # if the driver worker also execute methods,
  348. # exceptions in the rest worker may cause deadlock in rpc like ray
  349. # print the error and inform the user to solve the error
  350. msg = (f"Error executing method {method}. "
  351. "This might cause deadlock in distributed execution.")
  352. logger.exception(msg)
  353. raise e