1
0

worker.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. """A GPU worker class."""
  2. import gc
  3. import os
  4. from typing import Any, Dict, List, Optional, Set, Tuple, Union
  5. import torch
  6. import torch.distributed
  7. from loguru import logger
  8. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  9. LoRAConfig, ModelConfig, ParallelConfig,
  10. SchedulerConfig, SpeculativeConfig,
  11. VisionLanguageConfig)
  12. from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
  13. SamplerOutput)
  14. from aphrodite.distributed import (broadcast_tensor_dict,
  15. ensure_model_parallel_initialized,
  16. init_distributed_environment,
  17. set_custom_all_reduce)
  18. from aphrodite.lora.request import LoRARequest
  19. from aphrodite.modeling import set_random_seed
  20. from aphrodite.task_handler.cache_engine import CacheEngine
  21. from aphrodite.task_handler.embedding_model_runner import EmbeddingModelRunner
  22. from aphrodite.task_handler.model_runner import ModelRunner
  23. from aphrodite.task_handler.worker_base import WorkerBase
  24. class Worker(WorkerBase):
  25. """A worker class that executes (a partition of) the model on a GPU.
  26. Each worker is associated with a single GPU. The worker is responsible for
  27. maintaining the KV cache and executing the model on the GPU. In case of
  28. distributed inference, each worker is assigned a partition of the model.
  29. """
  30. def __init__(
  31. self,
  32. model_config: ModelConfig,
  33. parallel_config: ParallelConfig,
  34. scheduler_config: SchedulerConfig,
  35. device_config: DeviceConfig,
  36. cache_config: CacheConfig,
  37. load_config: LoadConfig,
  38. local_rank: int,
  39. rank: int,
  40. distributed_init_method: str,
  41. lora_config: Optional[LoRAConfig] = None,
  42. vision_language_config: Optional[VisionLanguageConfig] = None,
  43. speculative_config: Optional[SpeculativeConfig] = None,
  44. is_driver_worker: bool = False,
  45. ) -> None:
  46. self.model_config = model_config
  47. self.parallel_config = parallel_config
  48. self.scheduler_config = scheduler_config
  49. self.device_config = device_config
  50. self.cache_config = cache_config
  51. self.local_rank = local_rank
  52. self.rank = rank
  53. self.distributed_init_method = distributed_init_method
  54. self.lora_config = lora_config
  55. self.load_config = load_config
  56. self.is_driver_worker = is_driver_worker
  57. if self.is_driver_worker:
  58. assert self.rank == 0, "The driver worker must have rank 0."
  59. if self.model_config.trust_remote_code:
  60. # note: lazy import to avoid importing torch before initializing
  61. from aphrodite.common.utils import init_cached_hf_modules
  62. init_cached_hf_modules()
  63. self.vision_language_config = vision_language_config
  64. if self.vision_language_config:
  65. assert not self.lora_config, (
  66. "To be tested: vision language model with LoRA settings.")
  67. ModelRunnerClass = (EmbeddingModelRunner if
  68. self.model_config.embedding_mode else ModelRunner)
  69. self.model_runner = ModelRunnerClass(
  70. model_config,
  71. parallel_config,
  72. scheduler_config,
  73. device_config,
  74. cache_config,
  75. load_config=load_config,
  76. lora_config=self.lora_config,
  77. kv_cache_dtype=self.cache_config.cache_dtype,
  78. is_driver_worker=is_driver_worker,
  79. vision_language_config=vision_language_config,
  80. )
  81. # Uninitialized cache engine. Will be initialized by
  82. # initialize_cache.
  83. self.cache_engine: CacheEngine
  84. # Initialize gpu_cache as embedding models don't initialize kv_caches
  85. self.gpu_cache: Optional[List[torch.tensor]] = None
  86. def init_device(self) -> None:
  87. if self.device_config.device.type == "cuda":
  88. # torch.distributed.all_reduce does not free the input tensor until
  89. # the synchronization point. This causes the memory usage to grow
  90. # as the number of all_reduce calls increases. This env var disables
  91. # this behavior.
  92. # Related issue:
  93. # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
  94. os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
  95. # This env var set by Ray causes exceptions with graph building.
  96. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
  97. self.device = torch.device(f"cuda:{self.local_rank}")
  98. torch.cuda.set_device(self.device)
  99. _check_if_gpu_supports_dtype(self.model_config.dtype)
  100. torch.cuda.empty_cache()
  101. self.init_gpu_memory = torch.cuda.mem_get_info()[0]
  102. else:
  103. raise RuntimeError(
  104. f"Not support device type: {self.device_config.device}")
  105. # Initialize the distributed environment.
  106. init_worker_distributed_environment(self.parallel_config, self.rank,
  107. self.distributed_init_method,
  108. self.local_rank)
  109. # Set random seed.
  110. set_random_seed(self.model_config.seed)
  111. def load_model(self):
  112. self.model_runner.load_model()
  113. def save_sharded_state(
  114. self,
  115. path: str,
  116. pattern: Optional[str] = None,
  117. max_size: Optional[int] = None,
  118. ) -> None:
  119. self.model_runner.save_sharded_state(
  120. path,
  121. pattern=pattern,
  122. max_size=max_size,
  123. )
  124. @torch.inference_mode()
  125. def determine_num_available_blocks(self) -> Tuple[int, int]:
  126. """Profiles the peak memory usage of the model to determine how many
  127. KV blocks may be allocated without OOMs.
  128. The engine will first conduct a profiling of the existing memory usage.
  129. Then, it calculate the maximum possible number of GPU and CPU blocks
  130. that can be allocated with the remaining free memory.
  131. .. tip::
  132. You may limit the usage of GPU memory
  133. by adjusting the `gpu_memory_utilization` parameter.
  134. """
  135. # Profile the memory usage of the model and get the maximum number of
  136. # cache blocks that can be allocated with the remaining free memory.
  137. torch.cuda.empty_cache()
  138. # Execute a forward pass with dummy inputs to profile the memory usage
  139. # of the model.
  140. self.model_runner.profile_run()
  141. # Calculate the number of blocks that can be allocated with the
  142. # profiled peak memory.
  143. torch.cuda.synchronize()
  144. free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
  145. # NOTE: Here we assume that the other processes using the same
  146. # GPU did not change their memory usage during the profiling.
  147. peak_memory = self.init_gpu_memory - free_gpu_memory
  148. assert peak_memory > 0, (
  149. "Error in memory profiling. This happens when the GPU memory was "
  150. "not properly cleaned up before initializing Aphrodite.")
  151. cache_block_size = self.get_cache_block_size_bytes()
  152. num_gpu_blocks = int(
  153. (total_gpu_memory * self.cache_config.gpu_memory_utilization -
  154. peak_memory) // cache_block_size)
  155. num_cpu_blocks = int(self.cache_config.swap_space_bytes //
  156. cache_block_size)
  157. num_gpu_blocks = max(num_gpu_blocks, 0)
  158. num_cpu_blocks = max(num_cpu_blocks, 0)
  159. if self.model_runner.lora_manager:
  160. self.model_runner.remove_all_loras()
  161. gc.collect()
  162. torch.cuda.empty_cache()
  163. return num_gpu_blocks, num_cpu_blocks
  164. def initialize_cache(self, num_gpu_blocks: int,
  165. num_cpu_blocks: int) -> None:
  166. """Allocate GPU and CPU KV cache with the specified number of blocks.
  167. This also warms up the model, which may record CUDA graphs.
  168. """
  169. raise_if_cache_size_invalid(num_gpu_blocks,
  170. self.cache_config.block_size,
  171. self.model_config.max_model_len)
  172. self.cache_config.num_gpu_blocks = num_gpu_blocks
  173. self.cache_config.num_cpu_blocks = num_cpu_blocks
  174. self._init_cache_engine()
  175. self._warm_up_model()
  176. def _init_cache_engine(self):
  177. assert self.cache_config.num_gpu_blocks is not None
  178. self.cache_engine = CacheEngine(self.cache_config, self.model_config,
  179. self.parallel_config)
  180. self.gpu_cache = self.cache_engine.gpu_cache
  181. def _warm_up_model(self) -> None:
  182. if not self.model_config.enforce_eager:
  183. self.model_runner.capture_model(self.gpu_cache)
  184. # Reset the seed to ensure that the random state is not affected by
  185. # the model initialization and profiling.
  186. set_random_seed(self.model_config.seed)
  187. def cache_swap(
  188. self,
  189. blocks_to_swap_in: torch.Tensor,
  190. blocks_to_swap_out: torch.Tensor,
  191. blocks_to_copy: torch.Tensor,
  192. ) -> None:
  193. # Issue cache operations.
  194. if blocks_to_swap_in.numel() > 0:
  195. self.cache_engine.swap_in(blocks_to_swap_in)
  196. if blocks_to_swap_out.numel() > 0:
  197. self.cache_engine.swap_out(blocks_to_swap_out)
  198. if blocks_to_copy.numel() > 0:
  199. self.cache_engine.copy(blocks_to_copy)
  200. @torch.inference_mode()
  201. def execute_model(
  202. self,
  203. execute_model_req: Optional[ExecuteModelRequest] = None
  204. ) -> List[Union[SamplerOutput, PoolerOutput]]:
  205. if not self.is_driver_worker:
  206. self._execute_model_non_driver()
  207. return []
  208. if execute_model_req is None:
  209. # This signals that there's no more requests to process for now.
  210. # All workers are running infinite loop with broadcast_tensor_dict,
  211. # and it stops the loop when the driver broadcasts an empty input.
  212. # Send an empty input to notify all other workers to stop their
  213. # execution loop.
  214. broadcast_tensor_dict({}, src=0)
  215. return []
  216. seq_group_metadata_list = execute_model_req.seq_group_metadata_list
  217. num_seq_groups = len(seq_group_metadata_list)
  218. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
  219. # they contain parameters to launch cudamemcpyasync.
  220. blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
  221. device="cpu",
  222. dtype=torch.int64).view(-1, 2)
  223. blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
  224. device="cpu",
  225. dtype=torch.int64).view(-1, 2)
  226. # `blocks_to_copy` is a gpu tensor. The src and tgt of
  227. # blocks to copy are in the same device, and `blocks_to_copy`
  228. # can be used directly within cuda kernels.
  229. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
  230. device=self.device,
  231. dtype=torch.int64).view(-1, 2)
  232. data: Dict[str, Any] = {
  233. "num_seq_groups": num_seq_groups,
  234. "blocks_to_swap_in": blocks_to_swap_in,
  235. "blocks_to_swap_out": blocks_to_swap_out,
  236. "blocks_to_copy": blocks_to_copy,
  237. }
  238. broadcast_tensor_dict(data, src=0)
  239. self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
  240. # If there is no input, we don't need to execute the model.
  241. if num_seq_groups == 0:
  242. return []
  243. output = self.model_runner.execute_model(seq_group_metadata_list,
  244. self.gpu_cache)
  245. # Worker only supports single-step execution. Wrap the output in a list
  246. # to conform to interface.
  247. return [output]
  248. @torch.inference_mode()
  249. def start_worker_execution_loop(self) -> None:
  250. """Execute model loop in parallel worker.
  251. You can stop the loop by executing a driver worker with an empty output.
  252. See `stop_remote_worker_execution_loop` for more details.
  253. """
  254. while self._execute_model_non_driver():
  255. pass
  256. def _execute_model_non_driver(self) -> bool:
  257. """Execute model in parallel worker.
  258. Returns True iff there are remaining sequences to process.
  259. """
  260. assert not self.is_driver_worker
  261. data = broadcast_tensor_dict(src=0)
  262. if not data:
  263. return False
  264. num_seq_groups = data.get("num_seq_groups", 0)
  265. blocks_to_swap_in = data.get("blocks_to_swap_in")
  266. blocks_to_swap_out = data.get("blocks_to_swap_out")
  267. blocks_to_copy = data.get("blocks_to_copy")
  268. self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
  269. # If there is no input, we don't need to execute the model.
  270. if num_seq_groups == 0:
  271. return False
  272. self.model_runner.execute_model(None, self.gpu_cache)
  273. return True
  274. def add_lora(self, lora_request: LoRARequest) -> bool:
  275. return self.model_runner.add_lora(lora_request)
  276. def remove_lora(self, lora_id: int) -> bool:
  277. return self.model_runner.remove_lora(lora_id)
  278. def list_loras(self) -> Set[int]:
  279. return self.model_runner.list_loras()
  280. @property
  281. def max_model_len(self) -> int:
  282. return self.model_config.max_model_len
  283. @property
  284. def vocab_size(self) -> int:
  285. return self.model_runner.vocab_size
  286. def get_cache_block_size_bytes(self) -> int:
  287. """Get the size of the KV cache block size in bytes.
  288. """
  289. return CacheEngine.get_cache_block_size(self.cache_config,
  290. self.model_config,
  291. self.parallel_config)
  292. def init_worker_distributed_environment(
  293. parallel_config: ParallelConfig,
  294. rank: int,
  295. distributed_init_method: Optional[str] = None,
  296. local_rank: int = -1,
  297. ) -> None:
  298. """Initialize the distributed environment."""
  299. set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
  300. init_distributed_environment(parallel_config.world_size, rank,
  301. distributed_init_method, local_rank)
  302. ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
  303. parallel_config.pipeline_parallel_size)
  304. def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
  305. # Check if the GPU supports the dtype.
  306. if torch_dtype == torch.bfloat16:
  307. compute_capability = torch.cuda.get_device_capability()
  308. if compute_capability[0] < 8:
  309. gpu_name = torch.cuda.get_device_name()
  310. raise ValueError(
  311. "Bfloat16 is only supported on GPUs with compute capability "
  312. f"of at least 8.0. Your {gpu_name} GPU has compute capability "
  313. f"{compute_capability[0]}.{compute_capability[1]}. "
  314. "You can use float16 instead by explicitly setting the"
  315. "`dtype` flag in CLI, for example: --dtype=half.")
  316. def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
  317. max_model_len) -> None:
  318. if num_gpu_blocks <= 0:
  319. raise ValueError("No available memory for the cache blocks. "
  320. "Try increasing `gpu_memory_utilization` when "
  321. "initializing the engine.")
  322. max_seq_len = block_size * num_gpu_blocks
  323. logger.info(f"Maximum sequence length allowed in the cache: "
  324. f"{max_seq_len}")
  325. if max_model_len > max_seq_len:
  326. original_max_model_len = max_model_len
  327. max_model_len = max_seq_len
  328. # raise ValueError(
  329. # f"The model's max seq len ({max_model_len}) "
  330. # "is larger than the maximum number of tokens that can be "
  331. # f"stored in KV cache ({max_seq_len}). Try increasing "
  332. # "`gpu_memory_utilization` or decreasing `max_model_len` when "
  333. # "initializing the engine.")
  334. # set the max_model_len to the max_seq_len, but raise a logger.error
  335. # so the user is made aware of this
  336. logger.error(
  337. f"The model's max seq len ({original_max_model_len}) "
  338. "is larger than the maximum number of tokens that can be "
  339. f"stored in KV cache ({max_seq_len}). "
  340. "Try increasing "
  341. "`gpu_memory_utilization`, setting "
  342. "`--enable-chunked-prefill`, or `--kv-cache-dtype fp8` "
  343. "when initializing the engine. The last two are currently "
  344. "mutually exclusive.\n"
  345. f"Forcing max_model_len to {max_seq_len}.")