worker.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. """A GPU worker class."""
  2. import gc
  3. import os
  4. from typing import Any, Dict, List, Optional, Set, Tuple, Union
  5. import torch
  6. import torch.distributed
  7. from loguru import logger
  8. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  9. LoRAConfig, ModelConfig, ParallelConfig,
  10. SchedulerConfig, SpeculativeConfig,
  11. VisionLanguageConfig)
  12. from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
  13. SamplerOutput)
  14. from aphrodite.distributed import (broadcast_tensor_dict,
  15. ensure_model_parallel_initialized,
  16. init_distributed_environment,
  17. set_custom_all_reduce)
  18. from aphrodite.lora.request import LoRARequest
  19. from aphrodite.modeling import set_random_seed
  20. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  21. from aphrodite.task_handler.cache_engine import CacheEngine
  22. from aphrodite.task_handler.embedding_model_runner import EmbeddingModelRunner
  23. from aphrodite.task_handler.model_runner import ModelRunner
  24. from aphrodite.task_handler.worker_base import WorkerBase
  25. class Worker(WorkerBase):
  26. """A worker class that executes (a partition of) the model on a GPU.
  27. Each worker is associated with a single GPU. The worker is responsible for
  28. maintaining the KV cache and executing the model on the GPU. In case of
  29. distributed inference, each worker is assigned a partition of the model.
  30. """
  31. def __init__(
  32. self,
  33. model_config: ModelConfig,
  34. parallel_config: ParallelConfig,
  35. scheduler_config: SchedulerConfig,
  36. device_config: DeviceConfig,
  37. cache_config: CacheConfig,
  38. load_config: LoadConfig,
  39. local_rank: int,
  40. rank: int,
  41. distributed_init_method: str,
  42. lora_config: Optional[LoRAConfig] = None,
  43. vision_language_config: Optional[VisionLanguageConfig] = None,
  44. speculative_config: Optional[SpeculativeConfig] = None,
  45. is_driver_worker: bool = False,
  46. ) -> None:
  47. self.model_config = model_config
  48. self.parallel_config = parallel_config
  49. self.scheduler_config = scheduler_config
  50. self.device_config = device_config
  51. self.cache_config = cache_config
  52. self.local_rank = local_rank
  53. self.rank = rank
  54. self.distributed_init_method = distributed_init_method
  55. self.lora_config = lora_config
  56. self.load_config = load_config
  57. self.is_driver_worker = is_driver_worker
  58. if self.is_driver_worker:
  59. assert self.rank == 0, "The driver worker must have rank 0."
  60. if self.model_config.trust_remote_code:
  61. # note: lazy import to avoid importing torch before initializing
  62. from aphrodite.common.utils import init_cached_hf_modules
  63. init_cached_hf_modules()
  64. self.vision_language_config = vision_language_config
  65. if self.vision_language_config:
  66. assert not self.lora_config, (
  67. "To be tested: vision language model with LoRA settings.")
  68. ModelRunnerClass = (EmbeddingModelRunner if
  69. self.model_config.embedding_mode else ModelRunner)
  70. self.model_runner = ModelRunnerClass(
  71. model_config,
  72. parallel_config,
  73. scheduler_config,
  74. device_config,
  75. cache_config,
  76. load_config=load_config,
  77. lora_config=self.lora_config,
  78. kv_cache_dtype=self.cache_config.cache_dtype,
  79. is_driver_worker=is_driver_worker,
  80. vision_language_config=vision_language_config,
  81. )
  82. # Uninitialized cache engine. Will be initialized by
  83. # initialize_cache.
  84. self.cache_engine: CacheEngine
  85. # Initialize gpu_cache as embedding models don't initialize kv_caches
  86. self.gpu_cache: Optional[List[torch.tensor]] = None
  87. def init_device(self) -> None:
  88. if self.device_config.device.type == "cuda":
  89. # torch.distributed.all_reduce does not free the input tensor until
  90. # the synchronization point. This causes the memory usage to grow
  91. # as the number of all_reduce calls increases. This env var disables
  92. # this behavior.
  93. # Related issue:
  94. # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
  95. os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
  96. # This env var set by Ray causes exceptions with graph building.
  97. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
  98. self.device = torch.device(f"cuda:{self.local_rank}")
  99. torch.cuda.set_device(self.device)
  100. _check_if_gpu_supports_dtype(self.model_config.dtype)
  101. torch.cuda.empty_cache()
  102. self.init_gpu_memory = torch.cuda.mem_get_info()[0]
  103. else:
  104. raise RuntimeError(
  105. f"Not support device type: {self.device_config.device}")
  106. # Initialize the distributed environment.
  107. init_worker_distributed_environment(self.parallel_config, self.rank,
  108. self.distributed_init_method,
  109. self.local_rank)
  110. # Set random seed.
  111. set_random_seed(self.model_config.seed)
  112. def load_model(self):
  113. self.model_runner.load_model()
  114. def save_sharded_state(
  115. self,
  116. path: str,
  117. pattern: Optional[str] = None,
  118. max_size: Optional[int] = None,
  119. ) -> None:
  120. self.model_runner.save_sharded_state(
  121. path,
  122. pattern=pattern,
  123. max_size=max_size,
  124. )
  125. def save_tensorized_model(
  126. self,
  127. tensorizer_config: TensorizerConfig,
  128. ) -> None:
  129. self.model_runner.save_tensorized_model(
  130. tensorizer_config=tensorizer_config, )
  131. @torch.inference_mode()
  132. def determine_num_available_blocks(self) -> Tuple[int, int]:
  133. """Profiles the peak memory usage of the model to determine how many
  134. KV blocks may be allocated without OOMs.
  135. The engine will first conduct a profiling of the existing memory usage.
  136. Then, it calculate the maximum possible number of GPU and CPU blocks
  137. that can be allocated with the remaining free memory.
  138. .. tip::
  139. You may limit the usage of GPU memory
  140. by adjusting the `gpu_memory_utilization` parameter.
  141. """
  142. # Profile the memory usage of the model and get the maximum number of
  143. # cache blocks that can be allocated with the remaining free memory.
  144. torch.cuda.empty_cache()
  145. # Execute a forward pass with dummy inputs to profile the memory usage
  146. # of the model.
  147. self.model_runner.profile_run()
  148. # Calculate the number of blocks that can be allocated with the
  149. # profiled peak memory.
  150. torch.cuda.synchronize()
  151. free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
  152. # NOTE: Here we assume that the other processes using the same
  153. # GPU did not change their memory usage during the profiling.
  154. peak_memory = self.init_gpu_memory - free_gpu_memory
  155. assert peak_memory > 0, (
  156. "Error in memory profiling. This happens when the GPU memory was "
  157. "not properly cleaned up before initializing Aphrodite.")
  158. cache_block_size = self.get_cache_block_size_bytes()
  159. num_gpu_blocks = int(
  160. (total_gpu_memory * self.cache_config.gpu_memory_utilization -
  161. peak_memory) // cache_block_size)
  162. num_cpu_blocks = int(self.cache_config.swap_space_bytes //
  163. cache_block_size)
  164. num_gpu_blocks = max(num_gpu_blocks, 0)
  165. num_cpu_blocks = max(num_cpu_blocks, 0)
  166. if self.model_runner.lora_manager:
  167. self.model_runner.remove_all_loras()
  168. gc.collect()
  169. torch.cuda.empty_cache()
  170. return num_gpu_blocks, num_cpu_blocks
  171. def initialize_cache(self, num_gpu_blocks: int,
  172. num_cpu_blocks: int) -> None:
  173. """Allocate GPU and CPU KV cache with the specified number of blocks.
  174. This also warms up the model, which may record CUDA graphs.
  175. """
  176. raise_if_cache_size_invalid(num_gpu_blocks,
  177. self.cache_config.block_size,
  178. self.model_config.max_model_len)
  179. self.cache_config.num_gpu_blocks = num_gpu_blocks
  180. self.cache_config.num_cpu_blocks = num_cpu_blocks
  181. self._init_cache_engine()
  182. self._warm_up_model()
  183. def _init_cache_engine(self):
  184. assert self.cache_config.num_gpu_blocks is not None
  185. self.cache_engine = CacheEngine(self.cache_config, self.model_config,
  186. self.parallel_config,
  187. self.device_config)
  188. self.gpu_cache = self.cache_engine.gpu_cache
  189. def _warm_up_model(self) -> None:
  190. if not self.model_config.enforce_eager:
  191. self.model_runner.capture_model(self.gpu_cache)
  192. # Reset the seed to ensure that the random state is not affected by
  193. # the model initialization and profiling.
  194. set_random_seed(self.model_config.seed)
  195. def cache_swap(
  196. self,
  197. blocks_to_swap_in: torch.Tensor,
  198. blocks_to_swap_out: torch.Tensor,
  199. blocks_to_copy: torch.Tensor,
  200. ) -> None:
  201. # Issue cache operations.
  202. if blocks_to_swap_in.numel() > 0:
  203. self.cache_engine.swap_in(blocks_to_swap_in)
  204. if blocks_to_swap_out.numel() > 0:
  205. self.cache_engine.swap_out(blocks_to_swap_out)
  206. if blocks_to_copy.numel() > 0:
  207. self.cache_engine.copy(blocks_to_copy)
  208. @torch.inference_mode()
  209. def execute_model(
  210. self,
  211. execute_model_req: Optional[ExecuteModelRequest] = None
  212. ) -> List[Union[SamplerOutput, PoolerOutput]]:
  213. if not self.is_driver_worker:
  214. self._execute_model_non_driver()
  215. return []
  216. if execute_model_req is None:
  217. # This signals that there's no more requests to process for now.
  218. # All workers are running infinite loop with broadcast_tensor_dict,
  219. # and it stops the loop when the driver broadcasts an empty input.
  220. # Send an empty input to notify all other workers to stop their
  221. # execution loop.
  222. broadcast_tensor_dict({}, src=0)
  223. return []
  224. seq_group_metadata_list = execute_model_req.seq_group_metadata_list
  225. num_seq_groups = len(seq_group_metadata_list)
  226. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
  227. # they contain parameters to launch cudamemcpyasync.
  228. blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
  229. device="cpu",
  230. dtype=torch.int64).view(-1, 2)
  231. blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
  232. device="cpu",
  233. dtype=torch.int64).view(-1, 2)
  234. # `blocks_to_copy` is a gpu tensor. The src and tgt of
  235. # blocks to copy are in the same device, and `blocks_to_copy`
  236. # can be used directly within cuda kernels.
  237. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
  238. device=self.device,
  239. dtype=torch.int64).view(-1, 2)
  240. data: Dict[str, Any] = {
  241. "num_seq_groups": num_seq_groups,
  242. "blocks_to_swap_in": blocks_to_swap_in,
  243. "blocks_to_swap_out": blocks_to_swap_out,
  244. "blocks_to_copy": blocks_to_copy,
  245. }
  246. broadcast_tensor_dict(data, src=0)
  247. self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
  248. # If there is no input, we don't need to execute the model.
  249. if num_seq_groups == 0:
  250. return []
  251. output = self.model_runner.execute_model(seq_group_metadata_list,
  252. self.gpu_cache)
  253. # Worker only supports single-step execution. Wrap the output in a list
  254. # to conform to interface.
  255. return [output]
  256. @torch.inference_mode()
  257. def start_worker_execution_loop(self) -> None:
  258. """Execute model loop in parallel worker.
  259. You can stop the loop by executing a driver worker with an empty output.
  260. See `stop_remote_worker_execution_loop` for more details.
  261. """
  262. while self._execute_model_non_driver():
  263. pass
  264. def _execute_model_non_driver(self) -> bool:
  265. """Execute model in parallel worker.
  266. Returns True iff there are remaining sequences to process.
  267. """
  268. assert not self.is_driver_worker
  269. data = broadcast_tensor_dict(src=0)
  270. if not data:
  271. return False
  272. num_seq_groups = data.get("num_seq_groups", 0)
  273. blocks_to_swap_in = data.get("blocks_to_swap_in")
  274. blocks_to_swap_out = data.get("blocks_to_swap_out")
  275. blocks_to_copy = data.get("blocks_to_copy")
  276. self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
  277. # If there is no input, we don't need to execute the model.
  278. if num_seq_groups == 0:
  279. return False
  280. self.model_runner.execute_model(None, self.gpu_cache)
  281. return True
  282. def add_lora(self, lora_request: LoRARequest) -> bool:
  283. return self.model_runner.add_lora(lora_request)
  284. def remove_lora(self, lora_id: int) -> bool:
  285. return self.model_runner.remove_lora(lora_id)
  286. def list_loras(self) -> Set[int]:
  287. return self.model_runner.list_loras()
  288. @property
  289. def max_model_len(self) -> int:
  290. return self.model_config.max_model_len
  291. @property
  292. def vocab_size(self) -> int:
  293. return self.model_runner.vocab_size
  294. def get_cache_block_size_bytes(self) -> int:
  295. """Get the size of the KV cache block size in bytes.
  296. """
  297. return CacheEngine.get_cache_block_size(self.cache_config,
  298. self.model_config,
  299. self.parallel_config)
  300. def init_worker_distributed_environment(
  301. parallel_config: ParallelConfig,
  302. rank: int,
  303. distributed_init_method: Optional[str] = None,
  304. local_rank: int = -1,
  305. ) -> None:
  306. """Initialize the distributed environment."""
  307. set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
  308. init_distributed_environment(parallel_config.world_size, rank,
  309. distributed_init_method, local_rank)
  310. ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
  311. parallel_config.pipeline_parallel_size)
  312. def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
  313. # Check if the GPU supports the dtype.
  314. if torch_dtype == torch.bfloat16:
  315. compute_capability = torch.cuda.get_device_capability()
  316. if compute_capability[0] < 8:
  317. gpu_name = torch.cuda.get_device_name()
  318. raise ValueError(
  319. "Bfloat16 is only supported on GPUs with compute capability "
  320. f"of at least 8.0. Your {gpu_name} GPU has compute capability "
  321. f"{compute_capability[0]}.{compute_capability[1]}. "
  322. "You can use float16 instead by explicitly setting the"
  323. "`dtype` flag in CLI, for example: --dtype=half.")
  324. def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
  325. max_model_len) -> None:
  326. if num_gpu_blocks <= 0:
  327. raise ValueError("No available memory for the cache blocks. "
  328. "Try increasing `gpu_memory_utilization` when "
  329. "initializing the engine.")
  330. max_seq_len = block_size * num_gpu_blocks
  331. logger.info(f"Maximum sequence length allowed in the cache: "
  332. f"{max_seq_len}")
  333. if max_model_len > max_seq_len:
  334. original_max_model_len = max_model_len
  335. max_model_len = max_seq_len
  336. # raise ValueError(
  337. # f"The model's max seq len ({max_model_len}) "
  338. # "is larger than the maximum number of tokens that can be "
  339. # f"stored in KV cache ({max_seq_len}). Try increasing "
  340. # "`gpu_memory_utilization` or decreasing `max_model_len` when "
  341. # "initializing the engine.")
  342. # set the max_model_len to the max_seq_len, but raise a logger.error
  343. # so the user is made aware of this
  344. logger.error(
  345. f"The model's max seq len ({original_max_model_len}) "
  346. "is larger than the maximum number of tokens that can be "
  347. f"stored in KV cache ({max_seq_len}). "
  348. "Try increasing "
  349. "`gpu_memory_utilization`, setting "
  350. "`--enable-chunked-prefill`, or `--kv-cache-dtype fp8` "
  351. "when initializing the engine. The last two are currently "
  352. "mutually exclusive.\n"
  353. f"Forcing max_model_len to {max_seq_len}.")