worker_base.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. import dataclasses
  2. import importlib
  3. import os
  4. from abc import ABC, abstractmethod
  5. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
  6. import torch
  7. from loguru import logger
  8. from aphrodite.common.sequence import (ExecuteModelRequest,
  9. IntermediateTensors, SamplerOutput)
  10. from aphrodite.common.utils import (enable_trace_function_call_for_thread,
  11. update_environment_variables)
  12. from aphrodite.distributed import broadcast_tensor_dict, get_pp_group
  13. from aphrodite.lora.request import LoRARequest
  14. from aphrodite.platforms import current_platform
  15. from aphrodite.task_handler.model_runner_base import (ModelRunnerBase,
  16. ModelRunnerInputBase)
  17. class WorkerBase(ABC):
  18. """Worker interface that allows Aphrodite to cleanly separate
  19. implementations for different hardware. Also abstracts control plane
  20. communication, e.g., to communicate request metadata to other workers.
  21. """
  22. @abstractmethod
  23. def init_device(self) -> None:
  24. """Initialize device state, such as loading the model or other on-device
  25. memory allocations.
  26. """
  27. raise NotImplementedError
  28. @abstractmethod
  29. def determine_num_available_blocks(self) -> Tuple[int, int]:
  30. """Determine the number of available blocks for the GPU KV cache and
  31. swappable CPU KV cache.
  32. The implementation may run profiling or other heuristics to determine
  33. the size of caches.
  34. Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
  35. are blocks that are "active" on the device and can be appended to.
  36. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
  37. appended to.
  38. """
  39. raise NotImplementedError
  40. @abstractmethod
  41. def initialize_cache(self, num_gpu_blocks: int,
  42. num_cpu_blocks: int) -> None:
  43. """Initialize the KV cache with the given size in blocks.
  44. """
  45. raise NotImplementedError
  46. @current_platform.inference_mode()
  47. def start_worker_execution_loop(self) -> None:
  48. """Execute model loop in parallel worker.
  49. You can stop the loop by executing a driver worker with an empty output.
  50. See `stop_remote_worker_execution_loop` for more details.
  51. """
  52. while True:
  53. output = self.execute_model(execute_model_req=None)
  54. if output is None:
  55. return None
  56. @abstractmethod
  57. def execute_model(
  58. self,
  59. execute_model_req: Optional[ExecuteModelRequest] = None
  60. ) -> Optional[List[SamplerOutput]]:
  61. raise NotImplementedError
  62. @abstractmethod
  63. def get_cache_block_size_bytes(self) -> int:
  64. """Return the size of a single cache block, in bytes. Used in
  65. speculative decoding.
  66. """
  67. raise NotImplementedError
  68. @abstractmethod
  69. def add_lora(self, lora_request: LoRARequest) -> bool:
  70. raise NotImplementedError
  71. @abstractmethod
  72. def remove_lora(self, lora_id: int) -> bool:
  73. raise NotImplementedError
  74. @abstractmethod
  75. def pin_lora(self, lora_id: int) -> bool:
  76. raise NotImplementedError
  77. @abstractmethod
  78. def list_loras(self) -> Set[int]:
  79. raise NotImplementedError
  80. class LoraNotSupportedWorkerBase(WorkerBase):
  81. """Partial implementation of WorkerBase that raises exceptions when LoRA
  82. methods are invoked.
  83. """
  84. def add_lora(self, lora_request: LoRARequest) -> bool:
  85. raise ValueError(f"{type(self)} does not support LoRA")
  86. def remove_lora(self, lora_id: int) -> bool:
  87. raise ValueError(f"{type(self)} does not support LoRA")
  88. def pin_lora(self, lora_id: int) -> bool:
  89. return ValueError(
  90. f"{type(self)} does not support LoRA") # type: ignore
  91. def list_loras(self) -> Set[int]:
  92. raise ValueError(f"{type(self)} does not support LoRA")
  93. @dataclasses.dataclass(frozen=True)
  94. class WorkerInput:
  95. """Local inputs to each worker. May contain device-specific data. These
  96. fields should be broadcastable to other workers.
  97. """
  98. num_seq_groups: Optional[int] = None
  99. blocks_to_swap_in: Optional[torch.Tensor] = None
  100. blocks_to_swap_out: Optional[torch.Tensor] = None
  101. blocks_to_copy: Optional[torch.Tensor] = None
  102. virtual_engine: int = 0
  103. @classmethod
  104. def from_broadcasted_tensor_dict(
  105. cls: Type["WorkerInput"],
  106. tensor_dict: Dict[str, Any],
  107. ) -> "WorkerInput":
  108. """
  109. Pop fields from the given tensor_dict and populate a new instance of
  110. WorkerInput.
  111. """
  112. return cls(
  113. num_seq_groups=tensor_dict.pop("num_seq_groups"),
  114. blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
  115. blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
  116. blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
  117. virtual_engine=tensor_dict["virtual_engine"],
  118. )
  119. def as_broadcastable_tensor_dict(
  120. self) -> Dict[str, Union[int, torch.Tensor]]:
  121. """
  122. Extract broadcastable fields.
  123. """
  124. tensor_dict = {
  125. "num_seq_groups": self.num_seq_groups,
  126. "blocks_to_swap_in": self.blocks_to_swap_in,
  127. "blocks_to_swap_out": self.blocks_to_swap_out,
  128. "blocks_to_copy": self.blocks_to_copy,
  129. "virtual_engine": self.virtual_engine,
  130. }
  131. return tensor_dict
  132. class LocalOrDistributedWorkerBase(WorkerBase):
  133. """
  134. Partial implementation of WorkerBase that has a default `execute_model`
  135. definition to perform metadata transfer between workers when in distributed
  136. mode. Subclasses of this interface should use model runners that inherit
  137. from ModelRunnerBase, and should only need to implement worker-local logic.
  138. If custom control plane logic is needed to transfer metadata, or if the
  139. model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
  140. """
  141. is_driver_worker: bool
  142. model_runner: ModelRunnerBase
  143. @property
  144. @abstractmethod
  145. def do_metadata_broadcast(self) -> bool:
  146. """
  147. Used by the default `execute_model` to check whether broadcast is
  148. needed to transfer request inputs from the driver worker to other
  149. workers in the TP group. If WorkerBase subclass only supports
  150. single-worker execution, then this method should return False.
  151. """
  152. raise NotImplementedError
  153. @property
  154. @abstractmethod
  155. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  156. """
  157. Gets the list of kv caches to pass to the worker's model runner. Each
  158. element in the list is a kv cache corresponding to a particular virtual
  159. engine (PP stream). Used by the default `execute_model`. If the worker's
  160. model runner does not follow the ModelRunnerBase interface, then inherit
  161. from WorkerBase instead.
  162. """
  163. raise NotImplementedError
  164. @abstractmethod
  165. def prepare_worker_input(
  166. self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
  167. """
  168. Prepare the inputs to WorkerBase.execute_worker from an execution
  169. request. This method may move data to the worker's local device. It is
  170. not allowed to communicate with other workers or devices.
  171. """
  172. raise NotImplementedError
  173. @abstractmethod
  174. def execute_worker(self, worker_input: WorkerInput) -> None:
  175. """
  176. Process an execution request.
  177. """
  178. raise NotImplementedError
  179. def execute_model(
  180. self,
  181. execute_model_req: Optional[ExecuteModelRequest] = None
  182. ) -> Optional[List[SamplerOutput]]:
  183. """Executes at least one model step on the given sequences, unless no
  184. sequences are provided."""
  185. if self.is_driver_worker:
  186. if execute_model_req is None:
  187. if self.do_metadata_broadcast:
  188. # This signals that there's no more requests to process for
  189. # now. All workers are running infinite loop with
  190. # broadcast_tensor_dict, and it stops the loop when the
  191. # driver broadcasts an empty input. Send an empty input to
  192. # notify all other workers to stop their execution loop.
  193. broadcast_tensor_dict({}, src=0)
  194. return None
  195. worker_input: WorkerInput = self.prepare_worker_input(
  196. execute_model_req=execute_model_req)
  197. model_input: ModelRunnerInputBase = (
  198. self.model_runner.prepare_model_input(
  199. execute_model_req.seq_group_metadata_list,
  200. execute_model_req.virtual_engine,
  201. execute_model_req.finished_requests_ids))
  202. num_steps = execute_model_req.num_steps
  203. if self.do_metadata_broadcast:
  204. broadcast_data = worker_input.as_broadcastable_tensor_dict()
  205. broadcast_data.update(
  206. model_input.as_broadcastable_tensor_dict())
  207. broadcast_data["num_steps"] = num_steps
  208. broadcast_tensor_dict(broadcast_data, src=0)
  209. else:
  210. assert self.do_metadata_broadcast
  211. broadcast_data = broadcast_tensor_dict(src=0)
  212. if not broadcast_data:
  213. return None
  214. num_steps = broadcast_data.pop("num_steps")
  215. worker_input = WorkerInput.from_broadcasted_tensor_dict(
  216. broadcast_data)
  217. model_input = (
  218. self.model_runner.
  219. make_model_input_from_broadcasted_tensor_dict(broadcast_data))
  220. self.execute_worker(worker_input)
  221. # If there is no input, we don't need to execute the model.
  222. if worker_input.num_seq_groups == 0:
  223. return []
  224. intermediate_tensors = None
  225. if not get_pp_group().is_first_rank:
  226. intermediate_tensors = IntermediateTensors(
  227. get_pp_group().recv_tensor_dict())
  228. output = self.model_runner.execute_model(
  229. model_input, self.kv_cache[worker_input.virtual_engine]
  230. if self.kv_cache is not None else None, intermediate_tensors,
  231. num_steps)
  232. if not get_pp_group().is_last_rank:
  233. # output is IntermediateTensors
  234. get_pp_group().send_tensor_dict(output.tensors)
  235. return [None]
  236. # output is List[SamplerOutput]
  237. return output
  238. def _execute_model_spmd(
  239. self, execute_model_req: ExecuteModelRequest
  240. ) -> Optional[List[SamplerOutput]]:
  241. """
  242. Execute model in Single Program Multiple Data (SPMD) fashion.
  243. All workers take the same request, prepare the input and
  244. execute the model.
  245. """
  246. assert execute_model_req is not None, (
  247. "_execute_model_spmd() requires each worker to take in an "
  248. "ExecuteModelRequest")
  249. worker_input: WorkerInput = self.prepare_worker_input(
  250. execute_model_req=execute_model_req)
  251. model_input: ModelRunnerInputBase = (
  252. self.model_runner.prepare_model_input(
  253. execute_model_req.seq_group_metadata_list))
  254. self.execute_worker(worker_input)
  255. # If there is no input, we don't need to execute the model.
  256. if worker_input.num_seq_groups == 0:
  257. return []
  258. return self.model_runner.execute_model(
  259. model_input, self.kv_cache[worker_input.virtual_engine]
  260. if self.kv_cache is not None else None)
  261. class WorkerWrapperBase:
  262. """
  263. The whole point of this class is to lazily initialize the worker.
  264. We first instantiate the WorkerWrapper, which remembers the worker module
  265. and class name. Then, when we call `update_environment_variables`, and the
  266. real initialization happens in `init_worker`.
  267. If worker_class_fn is specified, it will be executed to get the worker
  268. class.
  269. Otherwise, the worker class will be obtained by dynamically importing it
  270. using worker_module_name and worker_class_name.
  271. """
  272. def __init__(
  273. self,
  274. worker_module_name: str,
  275. worker_class_name: str,
  276. trust_remote_code: bool = False,
  277. worker_class_fn: Optional[Callable[[],
  278. Type[WorkerBase]]] = None) -> None:
  279. self.worker_module_name = worker_module_name
  280. self.worker_class_name = worker_class_name
  281. self.worker_class_fn = worker_class_fn
  282. self.worker: Optional[WorkerBase] = None
  283. if trust_remote_code:
  284. # note: lazy import to avoid importing torch before initializing
  285. from aphrodite.common.utils import init_cached_hf_modules
  286. init_cached_hf_modules()
  287. @staticmethod
  288. def update_environment_variables(envs: Dict[str, str]) -> None:
  289. key = 'CUDA_VISIBLE_DEVICES'
  290. if key in envs and key in os.environ:
  291. # overwriting CUDA_VISIBLE_DEVICES is desired behavior
  292. # suppress the warning in `update_environment_variables`
  293. del os.environ[key]
  294. update_environment_variables(envs)
  295. def init_worker(self, *args, **kwargs):
  296. """
  297. Here we inject some common logic before initializing the worker.
  298. Arguments are passed to the worker class constructor.
  299. """
  300. enable_trace_function_call_for_thread()
  301. # see https://github.com/NVIDIA/nccl/issues/1234
  302. os.environ['NCCL_CUMEM_ENABLE'] = '0'
  303. if self.worker_class_fn:
  304. worker_class = self.worker_class_fn()
  305. else:
  306. mod = importlib.import_module(self.worker_module_name)
  307. worker_class = getattr(mod, self.worker_class_name)
  308. self.worker = worker_class(*args, **kwargs)
  309. assert self.worker is not None
  310. def execute_method(self, method, *args, **kwargs):
  311. try:
  312. target = self if self.worker is None else self.worker
  313. executor = getattr(target, method)
  314. return executor(*args, **kwargs)
  315. except Exception as e:
  316. # if the driver worker also execute methods,
  317. # exceptions in the rest worker may cause deadlock in rpc like ray
  318. # print the error and inform the user to solve the error
  319. msg = (f"Error executing method {method}. "
  320. "This might cause deadlock in distributed execution.")
  321. logger.exception(msg)
  322. raise e