1
0

worker_base.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. import dataclasses
  2. import importlib
  3. import os
  4. from abc import ABC, abstractmethod
  5. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
  6. import torch
  7. from loguru import logger
  8. from aphrodite.common.sequence import (ExecuteModelRequest,
  9. IntermediateTensors, SamplerOutput)
  10. from aphrodite.common.utils import (enable_trace_function_call_for_thread,
  11. update_environment_variables)
  12. from aphrodite.distributed import (broadcast_tensor_dict, get_pp_group,
  13. get_tp_group)
  14. from aphrodite.lora.request import LoRARequest
  15. from aphrodite.platforms import current_platform
  16. from aphrodite.task_handler.model_runner_base import (ModelRunnerBase,
  17. ModelRunnerInputBase)
  18. class WorkerBase(ABC):
  19. """Worker interface that allows Aphrodite to cleanly separate
  20. implementations for different hardware. Also abstracts control plane
  21. communication, e.g., to communicate request metadata to other workers.
  22. """
  23. @abstractmethod
  24. def init_device(self) -> None:
  25. """Initialize device state, such as loading the model or other on-device
  26. memory allocations.
  27. """
  28. raise NotImplementedError
  29. @abstractmethod
  30. def determine_num_available_blocks(self) -> Tuple[int, int]:
  31. """Determine the number of available blocks for the GPU KV cache and
  32. swappable CPU KV cache.
  33. The implementation may run profiling or other heuristics to determine
  34. the size of caches.
  35. Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
  36. are blocks that are "active" on the device and can be appended to.
  37. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
  38. appended to.
  39. """
  40. raise NotImplementedError
  41. @abstractmethod
  42. def initialize_cache(self, num_gpu_blocks: int,
  43. num_cpu_blocks: int) -> None:
  44. """Initialize the KV cache with the given size in blocks.
  45. """
  46. raise NotImplementedError
  47. @current_platform.inference_mode()
  48. def start_worker_execution_loop(self) -> None:
  49. """Execute model loop in parallel worker.
  50. You can stop the loop by executing a driver worker with an empty output.
  51. See `stop_remote_worker_execution_loop` for more details.
  52. """
  53. while True:
  54. output = self.execute_model(execute_model_req=None)
  55. if output is None:
  56. return None
  57. @abstractmethod
  58. def execute_model(
  59. self,
  60. execute_model_req: Optional[ExecuteModelRequest] = None
  61. ) -> Optional[List[SamplerOutput]]:
  62. raise NotImplementedError
  63. @abstractmethod
  64. def get_cache_block_size_bytes(self) -> int:
  65. """Return the size of a single cache block, in bytes. Used in
  66. speculative decoding.
  67. """
  68. raise NotImplementedError
  69. @abstractmethod
  70. def add_lora(self, lora_request: LoRARequest) -> bool:
  71. raise NotImplementedError
  72. @abstractmethod
  73. def remove_lora(self, lora_id: int) -> bool:
  74. raise NotImplementedError
  75. @abstractmethod
  76. def pin_lora(self, lora_id: int) -> bool:
  77. raise NotImplementedError
  78. @abstractmethod
  79. def list_loras(self) -> Set[int]:
  80. raise NotImplementedError
  81. class LoraNotSupportedWorkerBase(WorkerBase):
  82. """Partial implementation of WorkerBase that raises exceptions when LoRA
  83. methods are invoked.
  84. """
  85. def add_lora(self, lora_request: LoRARequest) -> bool:
  86. raise ValueError(f"{type(self)} does not support LoRA")
  87. def remove_lora(self, lora_id: int) -> bool:
  88. raise ValueError(f"{type(self)} does not support LoRA")
  89. def pin_lora(self, lora_id: int) -> bool:
  90. return ValueError(
  91. f"{type(self)} does not support LoRA") # type: ignore
  92. def list_loras(self) -> Set[int]:
  93. raise ValueError(f"{type(self)} does not support LoRA")
  94. @dataclasses.dataclass(frozen=True)
  95. class WorkerInput:
  96. """Local inputs to each worker. May contain device-specific data. These
  97. fields should be broadcastable to other workers.
  98. """
  99. num_seq_groups: Optional[int] = None
  100. blocks_to_swap_in: Optional[torch.Tensor] = None
  101. blocks_to_swap_out: Optional[torch.Tensor] = None
  102. blocks_to_copy: Optional[torch.Tensor] = None
  103. virtual_engine: int = 0
  104. num_steps: int = 1
  105. @classmethod
  106. def from_broadcasted_tensor_dict(
  107. cls: Type["WorkerInput"],
  108. tensor_dict: Dict[str, Any],
  109. ) -> "WorkerInput":
  110. """
  111. Pop fields from the given tensor_dict and populate a new instance of
  112. WorkerInput.
  113. """
  114. return cls(
  115. num_seq_groups=tensor_dict.pop("num_seq_groups"),
  116. blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
  117. blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
  118. blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
  119. virtual_engine=tensor_dict["virtual_engine"],
  120. num_steps=tensor_dict.pop("num_steps"),
  121. )
  122. def as_broadcastable_tensor_dict(
  123. self) -> Dict[str, Union[int, torch.Tensor]]:
  124. """
  125. Extract broadcastable fields.
  126. """
  127. tensor_dict = {
  128. "num_seq_groups": self.num_seq_groups,
  129. "blocks_to_swap_in": self.blocks_to_swap_in,
  130. "blocks_to_swap_out": self.blocks_to_swap_out,
  131. "blocks_to_copy": self.blocks_to_copy,
  132. "virtual_engine": self.virtual_engine,
  133. "num_steps": self.num_steps,
  134. }
  135. return tensor_dict
  136. class LocalOrDistributedWorkerBase(WorkerBase):
  137. """
  138. Partial implementation of WorkerBase that has a default `execute_model`
  139. definition to perform metadata transfer between workers when in distributed
  140. mode. Subclasses of this interface should use model runners that inherit
  141. from ModelRunnerBase, and should only need to implement worker-local logic.
  142. If custom control plane logic is needed to transfer metadata, or if the
  143. model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
  144. """
  145. is_driver_worker: bool
  146. model_runner: ModelRunnerBase
  147. @property
  148. @abstractmethod
  149. def do_metadata_broadcast(self) -> bool:
  150. """
  151. Used by the default `execute_model` to check whether broadcast is
  152. needed to transfer request inputs from the driver worker to other
  153. workers in the TP group. If WorkerBase subclass only supports
  154. single-worker execution, then this method should return False.
  155. """
  156. raise NotImplementedError
  157. @property
  158. @abstractmethod
  159. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  160. """
  161. Gets the list of kv caches to pass to the worker's model runner. Each
  162. element in the list is a kv cache corresponding to a particular virtual
  163. engine (PP stream). Used by the default `execute_model`. If the worker's
  164. model runner does not follow the ModelRunnerBase interface, then inherit
  165. from WorkerBase instead.
  166. """
  167. raise NotImplementedError
  168. @abstractmethod
  169. def prepare_worker_input(
  170. self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
  171. """
  172. Prepare the inputs to WorkerBase.execute_worker from an execution
  173. request. This method may move data to the worker's local device. It is
  174. not allowed to communicate with other workers or devices.
  175. """
  176. raise NotImplementedError
  177. @abstractmethod
  178. def execute_worker(self, worker_input: WorkerInput) -> None:
  179. """
  180. Process an execution request.
  181. """
  182. raise NotImplementedError
  183. def _get_worker_input_from_broadcast(
  184. self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
  185. """ Get the worker input from the broadcasted tensor dict. """
  186. assert self.do_metadata_broadcast
  187. assert not self.is_driver_worker
  188. broadcast_data = broadcast_tensor_dict(src=0)
  189. if not broadcast_data:
  190. return None
  191. worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
  192. model_input = (
  193. self.model_runner.make_model_input_from_broadcasted_tensor_dict(
  194. broadcast_data))
  195. return model_input, worker_input
  196. def _get_driver_input_and_broadcast(
  197. self, execute_model_req: ExecuteModelRequest
  198. ) -> Tuple[ModelRunnerInputBase, WorkerInput]:
  199. """ Get the driver input and broadcast it to other workers. """
  200. assert self.is_driver_worker
  201. worker_input: WorkerInput = self.prepare_worker_input(
  202. execute_model_req=execute_model_req)
  203. model_input: ModelRunnerInputBase = (
  204. self.model_runner.prepare_model_input(
  205. execute_model_req.seq_group_metadata_list,
  206. execute_model_req.virtual_engine,
  207. execute_model_req.finished_requests_ids))
  208. if self.do_metadata_broadcast:
  209. broadcast_data = worker_input.as_broadcastable_tensor_dict()
  210. broadcast_data.update(model_input.as_broadcastable_tensor_dict())
  211. broadcast_tensor_dict(broadcast_data, src=0)
  212. return model_input, worker_input
  213. def prepare_input(
  214. self,
  215. execute_model_req: Optional[ExecuteModelRequest] = None
  216. ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
  217. """
  218. Prepare the inputs to ModelRunner and workers.
  219. """
  220. if self.is_driver_worker:
  221. if execute_model_req is None:
  222. if self.do_metadata_broadcast:
  223. # This signals that there's no more requests to process for
  224. # now. All workers are running infinite loop with
  225. # broadcast_tensor_dict, and it stops the loop when the
  226. # driver broadcasts an empty input. Send an empty input to
  227. # notify all other workers to stop their execution loop.
  228. broadcast_tensor_dict({}, src=0)
  229. return None
  230. return self._get_driver_input_and_broadcast(execute_model_req)
  231. else:
  232. return self._get_worker_input_from_broadcast()
  233. def execute_model(
  234. self,
  235. execute_model_req: Optional[ExecuteModelRequest] = None
  236. ) -> Optional[List[SamplerOutput]]:
  237. """Executes at least one model step on the given sequences, unless no
  238. sequences are provided."""
  239. inputs = self.prepare_input(execute_model_req)
  240. if inputs is None:
  241. return None
  242. model_input, worker_input = inputs
  243. num_steps = worker_input.num_steps
  244. self.execute_worker(worker_input)
  245. # If there is no input, we don't need to execute the model.
  246. if worker_input.num_seq_groups == 0:
  247. return []
  248. intermediate_tensors = None
  249. if not get_pp_group().is_first_rank:
  250. intermediate_tensors = IntermediateTensors(
  251. get_pp_group().recv_tensor_dict(
  252. all_gather_group=get_tp_group()))
  253. output = self.model_runner.execute_model(
  254. model_input, self.kv_cache[worker_input.virtual_engine]
  255. if self.kv_cache is not None else None, intermediate_tensors,
  256. num_steps)
  257. if not get_pp_group().is_last_rank:
  258. # output is IntermediateTensors
  259. get_pp_group().send_tensor_dict(output.tensors,
  260. all_gather_group=get_tp_group())
  261. return [None]
  262. # output is List[SamplerOutput]
  263. return output
  264. def _execute_model_spmd(
  265. self,
  266. execute_model_req: ExecuteModelRequest,
  267. intermediate_tensors: Optional[IntermediateTensors] = None
  268. ) -> Optional[List[SamplerOutput]]:
  269. """
  270. Execute model in Single Program Multiple Data (SPMD) fashion.
  271. All workers take the same request, prepare the input and
  272. execute the model.
  273. """
  274. assert execute_model_req is not None, (
  275. "_execute_model_spmd() requires each worker to take in an "
  276. "ExecuteModelRequest")
  277. worker_input: WorkerInput = self.prepare_worker_input(
  278. execute_model_req=execute_model_req)
  279. model_input: ModelRunnerInputBase = (
  280. self.model_runner.prepare_model_input(
  281. execute_model_req.seq_group_metadata_list))
  282. self.execute_worker(worker_input)
  283. # If there is no input, we don't need to execute the model.
  284. if worker_input.num_seq_groups == 0:
  285. return []
  286. return self.model_runner.execute_model(
  287. model_input, self.kv_cache[worker_input.virtual_engine]
  288. if self.kv_cache is not None else None, intermediate_tensors)
  289. class WorkerWrapperBase:
  290. """
  291. The whole point of this class is to lazily initialize the worker.
  292. We first instantiate the WorkerWrapper, which remembers the worker module
  293. and class name. Then, when we call `update_environment_variables`, and the
  294. real initialization happens in `init_worker`.
  295. If worker_class_fn is specified, it will be executed to get the worker
  296. class.
  297. Otherwise, the worker class will be obtained by dynamically importing it
  298. using worker_module_name and worker_class_name.
  299. """
  300. def __init__(
  301. self,
  302. worker_module_name: str,
  303. worker_class_name: str,
  304. trust_remote_code: bool = False,
  305. worker_class_fn: Optional[Callable[[],
  306. Type[WorkerBase]]] = None) -> None:
  307. self.worker_module_name = worker_module_name
  308. self.worker_class_name = worker_class_name
  309. self.worker_class_fn = worker_class_fn
  310. self.worker: Optional[WorkerBase] = None
  311. if trust_remote_code:
  312. # note: lazy import to avoid importing torch before initializing
  313. from aphrodite.common.utils import init_cached_hf_modules
  314. init_cached_hf_modules()
  315. @staticmethod
  316. def update_environment_variables(envs: Dict[str, str]) -> None:
  317. key = 'CUDA_VISIBLE_DEVICES'
  318. if key in envs and key in os.environ:
  319. # overwriting CUDA_VISIBLE_DEVICES is desired behavior
  320. # suppress the warning in `update_environment_variables`
  321. del os.environ[key]
  322. update_environment_variables(envs)
  323. def init_worker(self, *args, **kwargs):
  324. """
  325. Here we inject some common logic before initializing the worker.
  326. Arguments are passed to the worker class constructor.
  327. """
  328. enable_trace_function_call_for_thread()
  329. # see https://github.com/NVIDIA/nccl/issues/1234
  330. os.environ['NCCL_CUMEM_ENABLE'] = '0'
  331. from aphrodite.plugins import load_general_plugins
  332. load_general_plugins()
  333. if self.worker_class_fn:
  334. worker_class = self.worker_class_fn()
  335. else:
  336. mod = importlib.import_module(self.worker_module_name)
  337. worker_class = getattr(mod, self.worker_class_name)
  338. self.worker = worker_class(*args, **kwargs)
  339. assert self.worker is not None
  340. def execute_method(self, method, *args, **kwargs):
  341. try:
  342. target = self if self.worker is None else self.worker
  343. executor = getattr(target, method)
  344. return executor(*args, **kwargs)
  345. except Exception as e:
  346. # if the driver worker also execute methods,
  347. # exceptions in the rest worker may cause deadlock in rpc like ray
  348. # print the error and inform the user to solve the error
  349. msg = (f"Error executing method {method}. "
  350. "This might cause deadlock in distributed execution.")
  351. logger.exception(msg)
  352. raise e