worker.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. """A GPU worker class."""
  2. import gc
  3. import os
  4. from typing import Dict, List, Optional, Set, Tuple
  5. import torch
  6. import torch.distributed
  7. from loguru import logger
  8. from aphrodite.common.config import (
  9. CacheConfig,
  10. DeviceConfig,
  11. LoRAConfig,
  12. ModelConfig,
  13. ParallelConfig,
  14. SchedulerConfig,
  15. VisionLanguageConfig,
  16. )
  17. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  18. from aphrodite.common.utils import in_wsl
  19. from aphrodite.distributed import (
  20. broadcast_tensor_dict,
  21. ensure_model_parallel_initialized,
  22. init_distributed_environment,
  23. )
  24. from aphrodite.distributed.device_communicators import pynccl_utils
  25. from aphrodite.distributed.device_communicators.custom_all_reduce import (
  26. init_custom_ar, )
  27. from aphrodite.lora.request import LoRARequest
  28. from aphrodite.modeling import set_random_seed
  29. from aphrodite.task_handler.cache_engine import CacheEngine
  30. from aphrodite.task_handler.model_runner import ModelRunner
  31. from aphrodite.task_handler.worker_base import WorkerBase
  32. class Worker(WorkerBase):
  33. """A worker class that executes (a partition of) the model on a GPU.
  34. Each worker is associated with a single GPU. The worker is responsible for
  35. maintaining the KV cache and executing the model on the GPU. In case of
  36. distributed inference, each worker is assigned a partition of the model.
  37. """
  38. def __init__(
  39. self,
  40. model_config: ModelConfig,
  41. parallel_config: ParallelConfig,
  42. scheduler_config: SchedulerConfig,
  43. device_config: DeviceConfig,
  44. cache_config: CacheConfig,
  45. local_rank: int,
  46. rank: int,
  47. distributed_init_method: str,
  48. lora_config: Optional[LoRAConfig] = None,
  49. vision_language_config: Optional[VisionLanguageConfig] = None,
  50. is_driver_worker: bool = False,
  51. ) -> None:
  52. self.model_config = model_config
  53. self.parallel_config = parallel_config
  54. self.scheduler_config = scheduler_config
  55. self.device_config = device_config
  56. self.cache_config = cache_config
  57. self.local_rank = local_rank
  58. self.rank = rank
  59. self.distributed_init_method = distributed_init_method
  60. self.lora_config = lora_config
  61. self.is_driver_worker = is_driver_worker
  62. if self.is_driver_worker:
  63. assert self.rank == 0, "The driver worker must have rank 0."
  64. self.vision_language_config = vision_language_config
  65. if self.vision_language_config:
  66. assert not self.lora_config, (
  67. "To be tested: vision language model with LoRA settings.")
  68. self.model_runner = ModelRunner(
  69. model_config,
  70. parallel_config,
  71. scheduler_config,
  72. device_config,
  73. lora_config=self.lora_config,
  74. kv_cache_dtype=self.cache_config.cache_dtype,
  75. is_driver_worker=is_driver_worker,
  76. # kv_quant_params_path=kv_quant_params_path,
  77. vision_language_config=vision_language_config)
  78. # Uninitialized cache engine. Will be initialized by
  79. # initialize_cache
  80. self.cache_engine = None
  81. self.gpu_cache = None
  82. def init_device(self) -> None:
  83. if self.device_config.device.type == "cuda":
  84. # torch.distributed.all_reduce does not free the input tensor until
  85. # the synchronization point. This causes the memory usage to grow
  86. # as the number of all_reduce calls increases. This env var disables
  87. # this behavior.
  88. # Related issue:
  89. # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
  90. os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
  91. # This env var set by Ray causes exceptions with graph building.
  92. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
  93. # Patch for torch.cuda.is_available() unexpected error in WSL;
  94. # always call torch.cuda.device_count() before initialising device
  95. if in_wsl():
  96. torch.cuda.device_count()
  97. self.device = torch.device(f"cuda:{self.local_rank}")
  98. torch.cuda.set_device(self.device)
  99. _check_if_gpu_supports_dtype(self.model_config.dtype)
  100. torch.cuda.empty_cache()
  101. self.init_gpu_memory = torch.cuda.mem_get_info()[0]
  102. else:
  103. raise RuntimeError(
  104. f"Not support device type: {self.device_config.device}")
  105. # Initialize the distributed environment.
  106. init_worker_distributed_environment(self.parallel_config, self.rank,
  107. self.distributed_init_method,
  108. self.local_rank)
  109. # Set random seed
  110. set_random_seed(self.model_config.seed)
  111. def load_model(self):
  112. self.model_runner.load_model()
  113. @torch.inference_mode()
  114. def determine_num_available_blocks(self) -> Tuple[int, int]:
  115. """Profiles the peak memory usage of the model to determine how many
  116. KV blocks may be allocated without OOMs.
  117. The engine will first conduct a profiling of the existing memory usage.
  118. Then, it calculate the maximum possible number of GPU and CPU blocks
  119. that can be allocated with the remaining free memory.
  120. .. tip::
  121. You may limit the usage of GPU memory
  122. by adjusting the `gpu_memory_utilization` parameter.
  123. """
  124. # Profile the memory usage of the model and get the maximum number of
  125. # cache blocks that can be allocated with the remaining free memory.
  126. torch.cuda.empty_cache()
  127. # Execute a forward pass with dummy inputs to profile the memory usage
  128. # of the model.
  129. self.model_runner.profile_run()
  130. # Calculate the number of blocks that can be allocated with the
  131. # profiled peak memory.
  132. torch.cuda.synchronize()
  133. free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
  134. # NOTE: Here we assume that the other processes using the same
  135. # GPU did not change their memory usage during the profiling.
  136. peak_memory = self.init_gpu_memory - free_gpu_memory
  137. assert peak_memory > 0, (
  138. "Error in memory profiling. This happens when the GPU memory was "
  139. "not properly cleaned up before initializing Aphrodite.")
  140. cache_block_size = self.get_cache_block_size_bytes()
  141. num_gpu_blocks = int(
  142. (total_gpu_memory * self.cache_config.gpu_memory_utilization -
  143. peak_memory) // cache_block_size)
  144. num_cpu_blocks = int(self.cache_config.swap_space_bytes //
  145. cache_block_size)
  146. num_gpu_blocks = max(num_gpu_blocks, 0)
  147. num_cpu_blocks = max(num_cpu_blocks, 0)
  148. if self.model_runner.lora_manager:
  149. self.model_runner.remove_all_loras()
  150. gc.collect()
  151. torch.cuda.empty_cache()
  152. return num_gpu_blocks, num_cpu_blocks
  153. def initialize_cache(self, num_gpu_blocks: int,
  154. num_cpu_blocks: int) -> None:
  155. """Allocate GPU and CPU KV cache with the specified number of blocks.
  156. This also warms up the model, which may record CUDA graphs.
  157. """
  158. raise_if_cache_size_invalid(num_gpu_blocks,
  159. self.cache_config.block_size,
  160. self.model_config.max_model_len)
  161. self.cache_config.num_gpu_blocks = num_gpu_blocks
  162. self.cache_config.num_cpu_blocks = num_cpu_blocks
  163. self._init_cache_engine()
  164. self._warm_up_model()
  165. def _init_cache_engine(self):
  166. assert self.cache_config.num_gpu_blocks is not None
  167. self.cache_engine = CacheEngine(self.cache_config, self.model_config,
  168. self.parallel_config)
  169. self.gpu_cache = self.cache_engine.gpu_cache
  170. self.model_runner.set_block_size(self.cache_engine.block_size)
  171. is_mamba = self.model_config.hf_config.model_type == "jamba"
  172. if is_mamba:
  173. self.model_runner.prepare_contiguous_mamba_cache(
  174. self.cache_engine.dtype)
  175. def _warm_up_model(self) -> None:
  176. if not self.model_config.enforce_eager:
  177. self.model_runner.capture_model(self.gpu_cache)
  178. # Reset the seed to ensure that the random state is not affected by
  179. # the model initialization and profiling.
  180. set_random_seed(self.model_config.seed)
  181. def cache_swap(
  182. self,
  183. blocks_to_swap_in: Dict[int, int],
  184. blocks_to_swap_out: Dict[int, int],
  185. blocks_to_copy: Dict[int, List[int]],
  186. ) -> None:
  187. # Issue cache operations.
  188. # TODO: Profile the overhead of swapping operations and optimize
  189. if blocks_to_swap_in:
  190. self.cache_engine.swap_in(blocks_to_swap_in)
  191. if blocks_to_swap_out:
  192. self.cache_engine.swap_out(blocks_to_swap_out)
  193. if blocks_to_copy:
  194. self.cache_engine.copy(blocks_to_copy)
  195. @torch.inference_mode()
  196. def execute_model(
  197. self,
  198. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
  199. blocks_to_swap_in: Optional[Dict[int, int]] = None,
  200. blocks_to_swap_out: Optional[Dict[int, int]] = None,
  201. blocks_to_copy: Optional[Dict[int, List[int]]] = None,
  202. ) -> Optional[SamplerOutput]:
  203. if self.is_driver_worker:
  204. assert seq_group_metadata_list is not None
  205. num_seq_groups = len(seq_group_metadata_list)
  206. assert blocks_to_swap_in is not None
  207. assert blocks_to_swap_out is not None
  208. assert blocks_to_copy is not None
  209. data = {
  210. "num_seq_groups": num_seq_groups,
  211. "blocks_to_swap_in": blocks_to_swap_in,
  212. "blocks_to_swap_out": blocks_to_swap_out,
  213. "blocks_to_copy": blocks_to_copy,
  214. }
  215. broadcast_tensor_dict(data, src=0)
  216. else:
  217. data = broadcast_tensor_dict(src=0)
  218. num_seq_groups = data["num_seq_groups"]
  219. blocks_to_swap_in = data["blocks_to_swap_in"]
  220. blocks_to_swap_out = data["blocks_to_swap_out"]
  221. blocks_to_copy = data["blocks_to_copy"]
  222. self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
  223. # If there is no input, we don't need to execute the model.
  224. if num_seq_groups == 0:
  225. return {}
  226. output = self.model_runner.execute_model(seq_group_metadata_list,
  227. self.gpu_cache)
  228. return output
  229. def add_lora(self, lora_request: LoRARequest) -> bool:
  230. return self.model_runner.add_lora(lora_request)
  231. def remove_lora(self, lora_id: int) -> bool:
  232. return self.model_runner.remove_lora(lora_id)
  233. def list_loras(self) -> Set[int]:
  234. return self.model_runner.list_loras()
  235. @property
  236. def max_model_len(self) -> int:
  237. return self.model_config.max_model_len
  238. @property
  239. def vocab_size(self) -> int:
  240. return self.model_runner.vocab_size
  241. def get_cache_block_size_bytes(self) -> int:
  242. """Get the size of the KV cache block size in bytes.
  243. """
  244. return CacheEngine.get_cache_block_size(self.cache_config,
  245. self.model_config,
  246. self.parallel_config)
  247. def release_mamba_cache(self, requests_id: List[str]):
  248. self.model_runner.release_mamba_cache(requests_id)
  249. def init_worker_distributed_environment(
  250. parallel_config: ParallelConfig,
  251. rank: int,
  252. distributed_init_method: Optional[str] = None,
  253. local_rank: int = -1,
  254. ) -> None:
  255. """Initialize the distributed environment."""
  256. init_distributed_environment(parallel_config.world_size, rank,
  257. distributed_init_method, local_rank)
  258. if pynccl_utils.is_initialized():
  259. pynccl_world_size = pynccl_utils.get_world_size()
  260. if pynccl_world_size != parallel_config.world_size:
  261. raise RuntimeError(
  262. "pynccl is already initialized but the pynccl world "
  263. "size does not match parallel_config.world_size "
  264. f"({pynccl_world_size} vs. {parallel_config.world_size}).")
  265. elif parallel_config.world_size > 1:
  266. # NOTE: We don't initialize pynccl process group when world size
  267. # is 1.
  268. pynccl_utils.init_process_group(
  269. world_size=parallel_config.world_size,
  270. local_rank=local_rank,
  271. rank=rank,
  272. init_method=distributed_init_method,
  273. )
  274. ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
  275. parallel_config.pipeline_parallel_size)
  276. # Initialize a custom fast all-reduce implementation.
  277. if not parallel_config.disable_custom_all_reduce:
  278. init_custom_ar()
  279. # A small all_reduce for warmup.
  280. torch.distributed.all_reduce(torch.zeros(1).cuda())
  281. if pynccl_utils.is_initialized():
  282. pynccl_utils.all_reduce(torch.zeros(1).cuda())
  283. def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
  284. # Check if the GPU supports the dtype.
  285. if torch_dtype == torch.bfloat16:
  286. compute_capability = torch.cuda.get_device_capability()
  287. if compute_capability[0] < 8:
  288. gpu_name = torch.cuda.get_device_name()
  289. raise ValueError(
  290. "Bfloat16 is only supported on GPUs with compute capability "
  291. f"of at least 8.0. Your {gpu_name} GPU has compute capability "
  292. f"{compute_capability[0]}.{compute_capability[1]}. "
  293. "You can use float16 instead by explicitly setting the"
  294. "`dtype` flag in CLI, for example: --dtype=half.")
  295. def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
  296. max_model_len) -> None:
  297. if num_gpu_blocks <= 0:
  298. raise ValueError("No available memory for the cache blocks. "
  299. "Try increasing `gpu_memory_utilization` when "
  300. "initializing the engine.")
  301. max_seq_len = block_size * num_gpu_blocks
  302. logger.info(f"Maximum sequence length allowed in the cache: "
  303. f"{max_seq_len}")
  304. if max_model_len > max_seq_len:
  305. raise ValueError(
  306. f"The model's max seq len ({max_model_len}) "
  307. "is larger than the maximum number of tokens that can be "
  308. f"stored in KV cache ({max_seq_len}). Try increasing "
  309. "`gpu_memory_utilization` or decreasing `max_model_len` when "
  310. "initializing the engine.")